/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_

#include <stdint.h>
#include <sys/types.h>

#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_multithread.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {

// Unoptimized reference ops:
using reference_ops::ArgMax;
using reference_ops::ArgMinMax;
using reference_ops::Broadcast4DSlowGreater;
using reference_ops::Broadcast4DSlowGreaterEqual;
using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
using reference_ops::Broadcast4DSlowGreaterWithScaling;
using reference_ops::Broadcast4DSlowLess;
using reference_ops::Broadcast4DSlowLessEqual;
using reference_ops::Broadcast4DSlowLessEqualWithScaling;
using reference_ops::Broadcast4DSlowLessWithScaling;
using reference_ops::BroadcastAdd4DSlow;
using reference_ops::BroadcastGreater;
using reference_ops::BroadcastGreaterEqual;
using reference_ops::BroadcastLess;
using reference_ops::BroadcastLessEqual;
using reference_ops::BroadcastMul4DSlow;
using reference_ops::BroadcastSubSlow;
using reference_ops::Concatenation;
using reference_ops::ConcatenationWithScaling;
using reference_ops::DepthConcatenation;
using reference_ops::Div;
using reference_ops::FakeQuant;
using reference_ops::Gather;
using reference_ops::Greater;
using reference_ops::GreaterEqual;
using reference_ops::GreaterEqualWithScaling;
using reference_ops::GreaterWithScaling;
using reference_ops::Less;
using reference_ops::LessEqual;
using reference_ops::LessEqualWithScaling;
using reference_ops::LessWithScaling;
using reference_ops::Mean;
using reference_ops::RankOneSelect;
using reference_ops::Relu1;
using reference_ops::Relu6;
using reference_ops::ReluX;
using reference_ops::Select;
using reference_ops::SpaceToBatchND;
using reference_ops::Split;
using reference_ops::TensorFlowSplit;

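// The legacy entry points below take the output shift with the opposite sign
// convention from the current one ("positive means left shift"); multiplying
// by this constant converts between the two.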
static constexpr int kDepthwiseReverseShift = -1;

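// Eigen-based helpers that view a flat buffer described by a legacy Dims<N>
// as a vector, matrix, or array, so the ops below can use Eigen expressions
// without copying data.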
template <typename Scalar, int N>
VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
  const int size = FlatSize(dims);
  return VectorMap<Scalar>(data, size, 1);
}

template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
                                                const Dims<N>& dims) {
  const int rows = dims.sizes[0];
  int cols = 1;
  for (int d = 1; d < N; d++) {
    cols *= dims.sizes[d];
  }
  return MatrixMap<Scalar>(data, rows, cols);
}

template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
                                               const Dims<N>& dims) {
  const int cols = dims.sizes[N - 1];
  int rows = 1;
  for (int d = 0; d < N - 1; d++) {
    rows *= dims.sizes[d];
  }
  return MatrixMap<Scalar>(data, rows, cols);
}

template <typename Scalar, int N>
ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
                                              const Dims<N>& dims) {
  const int rows = dims.sizes[0];
  int cols = 1;
  for (int d = 1; d < N; d++) {
    cols *= dims.sizes[d];
  }
  return ArrayMap<Scalar>(data, rows, cols);
}

// TODO(b/62193649): this function is only needed as long
// as we have the --variable_batch hack.
template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
                                                   const Dims<N>& dims,
                                                   int rows) {
  const int flatsize = FlatSize(dims);
  TFLITE_DCHECK((flatsize % rows) == 0);
  const int cols = flatsize / rows;
  return MatrixMap<Scalar>(data, rows, cols);
}

inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
  for (int i = 0; i < 4; i++) {
    if (dims1.sizes[i] != dims2.sizes[i]) {
      return false;
    }
  }
  return true;
}

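// Legacy float DepthwiseConv entry points taking the old Dims<4> signature.
// They pack the arguments into a DepthwiseParams struct and forward to
// DepthwiseConvImpl over the whole output as a single thread slice.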
inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                          const float* filter_data, const Dims<4>& filter_dims,
                          const float* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height,
                          int dilation_width_factor, int dilation_height_factor,
                          int pad_width, int pad_height, int depth_multiplier,
                          float output_activation_min,
                          float output_activation_max, float* output_data,
                          const Dims<4>& output_dims) {
  tflite::DepthwiseParams op_params;
  // Padding type is ignored, but still set.
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = pad_width;
  op_params.padding_values.height = pad_height;
  op_params.stride_width = stride_width;
  op_params.stride_height = stride_height;
  op_params.dilation_width_factor = dilation_width_factor;
  op_params.dilation_height_factor = dilation_height_factor;
  op_params.depth_multiplier = depth_multiplier;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  const RuntimeShape output_shape = DimsToShape(output_dims);
  const int output_height = output_shape.Dims(1);

  DepthwiseConvImpl(op_params, DimsToShape(input_dims), input_data,
                    DimsToShape(filter_dims), filter_data,
                    DimsToShape(bias_dims), bias_data, output_shape,
                    output_data, CpuFlags(), /*thread_start=*/0,
                    /*thread_end=*/output_height, /*thread_dim=*/1);
}

inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                          const float* filter_data, const Dims<4>& filter_dims,
                          const float* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height, int pad_width,
                          int pad_height, int depth_multiplier,
                          float output_activation_min,
                          float output_activation_max, float* output_data,
                          const Dims<4>& output_dims) {
  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
                bias_dims, stride_width, stride_height, 1, 1, pad_width,
                pad_height, depth_multiplier, output_activation_min,
                output_activation_max, output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int depth_multiplier, float* output_data,
                   const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
                bias_dims, stride_width, stride_height, pad_width, pad_height,
                depth_multiplier, output_activation_min, output_activation_max,
                output_data, output_dims);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   float* output_data, const Dims<4>& output_dims) {
  DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
                    bias_dims, stride, stride, pad_width, pad_height,
                    depth_multiplier, output_data, output_dims);
}

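// Quantized (uint8) depthwise convolution with an explicit output-rounding
// mode. On aarch64 (outside GOOGLE_L4T builds) it dispatches to the fast 3x3
// filter kernel when the parameters allow it, otherwise it falls back to the
// general kernel.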
template <DepthwiseConvOutputRounding kOutputRounding>
inline void LegacyDepthwiseConvWithRounding(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, int thread_start, int thread_end, int thread_dim) {
  ruy::profiler::ScopeLabel label("DepthwiseConv/8bit");
  const int depth_multiplier = params.depth_multiplier;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  const int dilation_width_factor = params.dilation_width_factor;
  const int dilation_height_factor = params.dilation_height_factor;
  TFLITE_DCHECK_GE(dilation_width_factor, 1);
  TFLITE_DCHECK_GE(dilation_height_factor, 1);
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
  const int input_depth = input_shape.Dims(3);
  TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);

// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
// Jetson TX-2. This compiler does not support the offsetof() macro.
#if defined(__aarch64__) && !defined(GOOGLE_L4T)
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int pad_width = params.padding_values.width;
  const int pad_height = params.padding_values.height;
  const int output_shift = params.output_shift;

  // Call kernel optimized for depthwise convolutions using 3x3 filters if
  // parameters are supported.
  if (depthwise_conv::Fast3x3FilterKernelSupported(
          input_shape, filter_shape, stride_width, stride_height,
          dilation_width_factor, dilation_height_factor, pad_width, pad_height,
          depth_multiplier, output_shape, output_shift)) {
    ruy::profiler::ScopeLabel specialized_label("DepthwiseConv/8bit/3x3");
    depthwise_conv::DepthwiseConv3x3Filter<kOutputRounding>(
        params, input_shape, input_data, filter_shape, filter_data, bias_shape,
        bias_data, output_shape, output_data, thread_start, thread_end,
        thread_dim);
    return;
  }
#endif

  ruy::profiler::ScopeLabel specialized_label("DepthwiseConv/8bit/General");
  depthwise_conv::DepthwiseConvGeneral(params, input_shape, input_data,
                                       filter_shape, filter_data, bias_shape,
                                       bias_data, output_shape, output_data,
                                       thread_start, thread_end, thread_dim);
}

inline void LegacyDepthwiseConvImpl(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, int thread_start, int thread_end, int thread_dim) {
  return LegacyDepthwiseConvWithRounding<
      DepthwiseConvOutputRounding::kAwayFromZero>(
      params, input_shape, input_data, filter_shape, filter_data, bias_shape,
      bias_data, output_shape, output_data, thread_start, thread_end,
      thread_dim);
}

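// Legacy quantized (uint8) DepthwiseConv entry points taking the old Dims<4>
// signature with separate offset/multiplier/shift arguments. The output shift
// is negated via kDepthwiseReverseShift to match the current
// positive-means-left-shift convention.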
inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                          int32 input_offset, const uint8* filter_data,
                          const Dims<4>& filter_dims, int32 filter_offset,
                          const int32* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height,
                          int dilation_width_factor, int dilation_height_factor,
                          int pad_width, int pad_height, int depth_multiplier,
                          int32 output_offset, int32 output_multiplier,
                          int output_shift, int32 output_activation_min,
                          int32 output_activation_max, uint8* output_data,
                          const Dims<4>& output_dims) {
  tflite::DepthwiseParams op_params;
  // Padding type is ignored, but still set.
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = pad_width;
  op_params.padding_values.height = pad_height;
  op_params.stride_width = stride_width;
  op_params.stride_height = stride_height;
  op_params.dilation_width_factor = dilation_width_factor;
  op_params.dilation_height_factor = dilation_height_factor;
  op_params.depth_multiplier = depth_multiplier;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
  op_params.output_shift = kDepthwiseReverseShift * output_shift;

  const RuntimeShape output_shape = DimsToShape(output_dims);
  const int output_height = output_shape.Dims(1);

  LegacyDepthwiseConvImpl(
      op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
      filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
      output_data, /*thread_start=*/0,
      /*thread_end=*/output_height, /*thread_dim=*/1);
}

inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                          int32 input_offset, const uint8* filter_data,
                          const Dims<4>& filter_dims, int32 filter_offset,
                          const int32* bias_data, const Dims<4>& bias_dims,
                          int stride_width, int stride_height, int pad_width,
                          int pad_height, int depth_multiplier,
                          int32 output_offset, int32 output_multiplier,
                          int output_shift, int32 output_activation_min,
                          int32 output_activation_max, uint8* output_data,
                          const Dims<4>& output_dims) {
  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                filter_offset, bias_data, bias_dims, stride_width,
                stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max, output_data,
                output_dims);
}

// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int depth_multiplier, int32 output_offset,
                   int32 output_multiplier, int output_shift,
                   int32 output_activation_min, int32 output_activation_max,
                   uint8* output_data, const Dims<4>& output_dims) {
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                filter_offset, bias_data, bias_dims, stride_width,
                stride_height, pad_width, pad_height, depth_multiplier,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max, output_data,
                output_dims);
}

// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   int32 output_offset, int32 output_multiplier,
                   int output_shift, int32 output_activation_min,
                   int32 output_activation_max, uint8* output_data,
                   const Dims<4>& output_dims) {
  DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
                    filter_dims, filter_offset, bias_data, bias_dims, stride,
                    stride, pad_width, pad_height, depth_multiplier,
                    output_offset, output_multiplier, output_shift,
                    output_activation_min, output_activation_max, output_data,
                    output_dims);
}

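// gemmlowp::Task that runs LegacyDepthwiseConvImpl on one slice
// [thread_start, thread_end) of the chosen thread dimension, so a workers
// pool can split a depthwise convolution across threads.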
template <typename T, typename TS>
struct LegacyDepthwiseConvWorkerTask : public gemmlowp::Task {
  LegacyDepthwiseConvWorkerTask(
      const DepthwiseParams& params, const RuntimeShape& input_shape,
      const T* input_data, const RuntimeShape& filter_shape,
      const T* filter_data, const RuntimeShape& bias_shape, const TS* bias_data,
      const RuntimeShape& output_shape, T* output_data, int thread_start,
      int thread_end, int thread_dim)
      : params_(params),
        input_shape_(input_shape),
        input_data_(input_data),
        filter_shape_(filter_shape),
        filter_data_(filter_data),
        bias_shape_(bias_shape),
        bias_data_(bias_data),
        output_shape_(output_shape),
        output_data_(output_data),
        thread_start_(thread_start),
        thread_end_(thread_end),
        thread_dim_(thread_dim) {}

  void Run() override {
    LegacyDepthwiseConvImpl(params_, input_shape_, input_data_, filter_shape_,
                            filter_data_, bias_shape_, bias_data_,
                            output_shape_, output_data_, thread_start_,
                            thread_end_, thread_dim_);
  }

 private:
  const DepthwiseParams& params_;
  const RuntimeShape& input_shape_;
  const T* input_data_;
  const RuntimeShape& filter_shape_;
  const T* filter_data_;
  const RuntimeShape& bias_shape_;
  const TS* bias_data_;
  const RuntimeShape& output_shape_;
  T* output_data_;
  int thread_start_;
  int thread_end_;
  int thread_dim_;
};

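// Heuristic for how many threads are worth spawning along `thread_dim`:
// sizes the per-thread slice so that each thread gets roughly at least
// kMinMulPerThread multiply-accumulates, based on the output size and the
// filter footprint.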
inline int HowManyConvThreads(const RuntimeShape& output_shape,
                              const RuntimeShape& filter_shape,
                              int thread_dim) {
  constexpr int kMinMulPerThread = 8;
  const int output_units = output_shape.Dims(thread_dim);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
  const int num_mul_per_unit =
      FlatSizeSkipDim(output_shape, thread_dim) * filter_height * filter_width;
  const int min_units_per_thread = kMinMulPerThread / num_mul_per_unit + 1;
  int thread_count = output_units / min_units_per_thread;
  return thread_count;
}

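// Multithreaded uint8 DepthwiseConv. Splits work across either the batch or
// the row dimension (whichever yields more threads), then either runs
// single-threaded or fans out LegacyDepthwiseConvWorkerTasks on the gemmlowp
// workers pool.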
inline void DepthwiseConv(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, gemmlowp::GemmContext* gemmlowp_context = nullptr) {
  ruy::profiler::ScopeLabel label("DepthwiseConv");

  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  const int output_batches = output_shape.Dims(0);
  const int output_rows = output_shape.Dims(1);
  int thread_count_batch = HowManyConvThreads(output_shape, filter_shape, 0);
  int thread_count_row = HowManyConvThreads(output_shape, filter_shape, 1);
  int thread_dim, thread_count, thread_dim_size;
  if (thread_count_batch > thread_count_row) {
    thread_dim = 0;
    thread_dim_size = output_batches;
    thread_count = thread_count_batch;
  } else {
    thread_dim = 1;
    thread_dim_size = output_rows;
    thread_count = thread_count_row;
  }

  const int max_threads =
      gemmlowp_context ? gemmlowp_context->max_num_threads() : 1;
  thread_count = std::max(1, std::min(thread_count, max_threads));

  if (thread_count == 1) {
    LegacyDepthwiseConvImpl(params, input_shape, input_data, filter_shape,
                            filter_data, bias_shape, bias_data, output_shape,
                            output_data, /*thread_start=*/0,
                            /*thread_end=*/output_rows, /*thread_dim=*/1);
  } else {
    std::vector<gemmlowp::Task*> tasks(thread_count);
    int thread_start = 0;
    for (int i = 0; i < thread_count; ++i) {
      int thread_end =
          thread_start + (thread_dim_size - thread_start) / (thread_count - i);
      tasks[i] = new LegacyDepthwiseConvWorkerTask<uint8, int32>(
          params, input_shape, input_data, filter_shape, filter_data,
          bias_shape, bias_data, output_shape, output_data, thread_start,
          thread_end, thread_dim);
      thread_start = thread_end;
    }
    gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
  }
}

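// Per-channel (int8) counterpart of LegacyDepthwiseConvWorkerTask: runs the
// integer-ops DepthwiseConvImpl, which takes per-channel output multipliers
// and shifts, on one slice of the thread dimension.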
template <typename T, typename TS>
struct LegacyPerChannelDepthwiseConvWorkerTask : public gemmlowp::Task {
  LegacyPerChannelDepthwiseConvWorkerTask(
      const DepthwiseParams& params, const int32* output_multiplier,
      const int32* output_shift, const RuntimeShape& input_shape,
      const T* input_data, const RuntimeShape& filter_shape,
      const T* filter_data, const RuntimeShape& bias_shape, const TS* bias_data,
      const RuntimeShape& output_shape, T* output_data, int thread_start,
      int thread_end, int thread_dim)
      : params_(params),
        output_multiplier_(output_multiplier),
        output_shift_(output_shift),
        input_shape_(input_shape),
        input_data_(input_data),
        filter_shape_(filter_shape),
        filter_data_(filter_data),
        bias_shape_(bias_shape),
        bias_data_(bias_data),
        output_shape_(output_shape),
        output_data_(output_data),
        thread_start_(thread_start),
        thread_end_(thread_end),
        thread_dim_(thread_dim) {}

  void Run() override {
    CpuBackendContext backend_context;
    optimized_integer_ops::DepthwiseConvImpl(
        params_, output_multiplier_, output_shift_, input_shape_, input_data_,
        filter_shape_, filter_data_, bias_shape_, bias_data_, output_shape_,
        output_data_, thread_start_, thread_end_, thread_dim_, backend_context);
  }

 private:
  const DepthwiseParams& params_;
  const int32* output_multiplier_;
  const int32* output_shift_;
  const RuntimeShape& input_shape_;
  const T* input_data_;
  const RuntimeShape& filter_shape_;
  const T* filter_data_;
  const RuntimeShape& bias_shape_;
  const TS* bias_data_;
  const RuntimeShape& output_shape_;
  T* output_data_;
  int thread_start_;
  int thread_end_;
  int thread_dim_;
};

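// Multithreaded per-channel int8 depthwise convolution. Thread selection
// mirrors the uint8 DepthwiseConv above; the per-channel quantization
// parameters are passed through as arrays of output multipliers and shifts.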
inline void DepthwiseConvPerChannel(
    const DepthwiseParams& params, const int32* output_multiplier,
    const int32* output_shift, const RuntimeShape& input_shape,
    const int8* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
    gemmlowp::GemmContext* gemmlowp_context = nullptr) {
  ruy::profiler::ScopeLabel label("DepthwiseConvInt8");

  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  const int output_batches = output_shape.Dims(0);
  const int output_rows = output_shape.Dims(1);
  int thread_count_batch = HowManyConvThreads(output_shape, filter_shape, 0);
  int thread_count_row = HowManyConvThreads(output_shape, filter_shape, 1);
  int thread_dim, thread_count, thread_dim_size;
  if (thread_count_batch > thread_count_row) {
    thread_dim = 0;
    thread_dim_size = output_batches;
    thread_count = thread_count_batch;
  } else {
    thread_dim = 1;
    thread_dim_size = output_rows;
    thread_count = thread_count_row;
  }

  const int max_threads =
      gemmlowp_context ? gemmlowp_context->max_num_threads() : 1;
  thread_count = std::max(1, std::min(thread_count, max_threads));

  if (thread_count == 1) {
    CpuBackendContext backend_context;
    optimized_integer_ops::DepthwiseConvImpl(
        params, output_multiplier, output_shift, input_shape, input_data,
        filter_shape, filter_data, bias_shape, bias_data, output_shape,
        output_data, /*thread_start=*/0,
        /*thread_end=*/output_rows, /*thread_dim=*/1, backend_context);
  } else {
    std::vector<gemmlowp::Task*> tasks(thread_count);
    int thread_start = 0;
    for (int i = 0; i < thread_count; ++i) {
      int thread_end =
          thread_start + (thread_dim_size - thread_start) / (thread_count - i);
      tasks[i] = new LegacyPerChannelDepthwiseConvWorkerTask<int8, int32>(
          params, output_multiplier, output_shift, input_shape, input_data,
          filter_shape, filter_data, bias_shape, bias_data, output_shape,
          output_data, thread_start, thread_end, thread_dim);
      thread_start = thread_end;
    }
    gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
  }
}

inline void DepthwiseConv(
    const DepthwiseParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& filter_shape,
    const float* filter_data, const RuntimeShape& bias_shape,
    const float* bias_data, const RuntimeShape& output_shape,
    float* output_data) {
  DepthwiseConvImpl(params, input_shape, input_data, filter_shape, filter_data,
                    bias_shape, bias_data, output_shape, output_data,
                    CpuFlags(),
                    /*thread_start=*/0,
                    /*thread_end=*/output_shape.Dims(1), /*thread_dim=*/1);
}

inline void AddBiasAndEvalActivationFunction(const float* bias_data,
                                             const Dims<4>& bias_dims,
                                             float* array_data,
                                             const Dims<4>& array_dims,
                                             float output_activation_min,
                                             float output_activation_max) {
  AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
                                   DimsToShape(bias_dims), bias_data,
                                   DimsToShape(array_dims), array_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void AddBiasAndEvalActivationFunction(const float* bias_data,
                                      const Dims<4>& bias_dims,
                                      float* array_data,
                                      const Dims<4>& array_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
                                   output_activation_min,
                                   output_activation_max);
}

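// Thin Eigen-based GEMM helper: when the right-hand side has a single column
// the product is computed as a GEMV, otherwise as a full matrix multiply.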
template <typename Lhs, typename Rhs, typename Result>
void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
          Eigen::MatrixBase<Result>* result) {
  if (rhs.cols() == 1) {
    ruy::profiler::ScopeLabel label("GEMV");
    result->col(0).noalias() = lhs * rhs.col(0);
  } else {
    ruy::profiler::ScopeLabel label("GEMM");
    result->noalias() = lhs * rhs;
  }
}

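// Float FullyConnected: maps the weights and activations as Eigen matrices,
// computes output = weights^T * input via Gemm(), then applies the optional
// bias and the fused activation clamp.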
inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const float* input_data, const RuntimeShape& weights_shape,
    const float* weights_data, const RuntimeShape& bias_shape,
    const float* optional_bias_data, const RuntimeShape& output_shape,
    float* output_data) {
  ruy::profiler::ScopeLabel label("FullyConnected");
  const float output_activation_min = params.float_activation_min;
  const float output_activation_max = params.float_activation_max;

  // TODO(b/62193649): this convoluted shape computation (determining
  // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
  // is because the current --variable_batch hack consists in overwriting the
  // 3rd dimension with the runtime batch size, as we don't keep track for each
  // array of which dimension is the batch dimension in it.
  // When that is fixed, this should become:
  // const auto input_matrix_map =
  //     MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
  const int dims_count = weights_shape.DimensionsCount();
  const int input_rows = weights_shape.Dims(dims_count - 1);
  const auto input_matrix_map =
      MapAsMatrixWithGivenNumberOfRows(input_data, input_shape, input_rows);
  const auto filter_matrix_map =
      MapAsMatrixWithLastDimAsRows(weights_data, weights_shape);
  auto output_matrix_map =
      MapAsMatrixWithLastDimAsRows(output_data, output_shape);

  Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);

  if (optional_bias_data != nullptr) {
    AddBiasAndEvalActivationFunction(
        output_activation_min, output_activation_max, bias_shape,
        optional_bias_data, output_shape, output_data);
  } else {
    const int flat_size = output_shape.FlatSize();
    for (int i = 0; i < flat_size; ++i) {
      output_data[i] = ActivationFunctionWithMinMax(
          output_data[i], output_activation_min, output_activation_max);
    }
  }
}

inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
                           const float* weights_data,
                           const Dims<4>& weights_dims, const float* bias_data,
                           const Dims<4>& bias_dims,
                           float output_activation_min,
                           float output_activation_max, float* output_data,
                           const Dims<4>& output_dims) {
  tflite::FullyConnectedParams op_params;
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;

  FullyConnected(op_params, DimsToShape(input_dims), input_data,
                 DimsToShape(weights_dims), weights_data,
                 DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
                 output_data);
}

// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void FullyConnected(const float* input_data, const Dims<4>& input_dims,
                    const float* weights_data, const Dims<4>& weights_dims,
                    const float* bias_data, const Dims<4>& bias_dims,
                    float* output_data, const Dims<4>& output_dims) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
                 bias_dims, output_activation_min, output_activation_max,
                 output_data, output_dims);
}

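// gemmlowp output pipelines used by the quantized ops below: bias addition,
// fixed-point requantization (scale by multiplier, shift by exponent, add the
// output offset), clamping to the activation range, and a saturating cast to
// uint8 or int8 respectively.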
struct GemmlowpOutputPipeline {
  typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
      ColVectorMap;
  typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
                     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
                     gemmlowp::OutputStageClamp,
                     gemmlowp::OutputStageSaturatingCastToUint8>
      Pipeline;
  static Pipeline MakeExp(const int32* bias_data, int output_rows,
                          int32 output_offset, int32 output_multiplier,
                          int output_left_shift, int32 output_activation_min,
                          int32 output_activation_max) {
    ColVectorMap bias_vector(bias_data, output_rows);
    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
    bias_addition_stage.bias_vector = bias_vector;
    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
    quantize_down_stage.result_offset_after_shift = output_offset;
    quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
    quantize_down_stage.result_exponent = output_left_shift;
    gemmlowp::OutputStageClamp clamp_stage;
    clamp_stage.min = output_activation_min;
    clamp_stage.max = output_activation_max;
    gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
    return std::make_tuple(bias_addition_stage, quantize_down_stage,
                           clamp_stage, saturating_cast_stage);
  }
};

struct GemmlowpOutputPipelineInt8 {
  typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
      ColVectorMap;
  typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
                     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
                     gemmlowp::OutputStageClamp,
                     gemmlowp::OutputStageSaturatingCastToInt8>
      Pipeline;
  static Pipeline MakeExp(const int32* bias_data, int output_rows,
                          int32 output_offset, int32 output_multiplier,
                          int output_left_shift, int32 output_activation_min,
                          int32 output_activation_max) {
    ColVectorMap bias_vector(bias_data, output_rows);
    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
    bias_addition_stage.bias_vector = bias_vector;
    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
    quantize_down_stage.result_offset_after_shift = output_offset;
    quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
    quantize_down_stage.result_exponent = output_left_shift;
    gemmlowp::OutputStageClamp clamp_stage;
    clamp_stage.min = output_activation_min;
    clamp_stage.max = output_activation_max;
    gemmlowp::OutputStageSaturatingCastToInt8 saturating_cast_stage;
    return std::make_tuple(bias_addition_stage, quantize_down_stage,
                           clamp_stage, saturating_cast_stage);
  }
};

#ifdef USE_NEON
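// NEON kernel for the batch-size-1 uint8 fully-connected case, treated as a
// matrix*vector product. Processes kPeel (4) output rows at a time, widening
// uint8 values to int16 and accumulating into int32 lanes, then requantizes,
// clamps, and stores 4 uint8 outputs per iteration.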
inline void LegacyFullyConnectedAsGEMVWorkerImpl(
    const RuntimeShape& input_shape, const uint8* input_data,
    int32 input_offset, const RuntimeShape& filter_shape,
    const uint8* filter_data, int32 filter_offset,
    const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
    int32 output_multiplier, int output_shift, int32 output_activation_min,
    int32 output_activation_max, const RuntimeShape& output_shape,
    uint8* output_data, int row_start, int row_end) {
  ruy::profiler::ScopeLabel label("FullyConnectedAsGEMV/8bit");
  TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
  const int output_dim_count = output_shape.DimensionsCount();
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
  const int input_size = FlatSizeSkipDim(input_shape, 0);
  static constexpr int kPeel = 4;
  const bool shift_left = (output_shift > 0);
  for (int k = 0; k < input_size; k += 64) {
    optimized_ops_preload_l1_stream(input_data + k);
  }
  for (int k = 0; k < kPeel * input_size; k += 64) {
    optimized_ops_preload_l1_stream(filter_data + k);
  }

  TFLITE_DCHECK_GE(row_end - row_start, kPeel);

  for (int out = row_start; out < row_end; out += kPeel) {
    out = std::min(out, row_end - kPeel);
    int32x4_t acc0 = vdupq_n_s32(0);
    int32x4_t acc1 = acc0;
    int32x4_t acc2 = acc0;
    int32x4_t acc3 = acc0;
    const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
    const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
    int in = 0;
    for (; in <= input_size - 16; in += 16) {
      const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
      const uint8* filter_ptr = filter_data + in + out * input_size;
      uint8x16_t filter_val_u8_0 = vld1q_u8(filter_ptr);
      optimized_ops_preload_l1_stream(filter_ptr + 64);
      filter_ptr += input_size;
      uint8x16_t filter_val_u8_1 = vld1q_u8(filter_ptr);
      optimized_ops_preload_l1_stream(filter_ptr + 64);
      filter_ptr += input_size;
      uint8x16_t filter_val_u8_2 = vld1q_u8(filter_ptr);
      optimized_ops_preload_l1_stream(filter_ptr + 64);
      filter_ptr += input_size;
      uint8x16_t filter_val_u8_3 = vld1q_u8(filter_ptr);
      optimized_ops_preload_l1_stream(filter_ptr + 64);
      int16x8_t input_val_0, input_val_1;
      uint8x8_t low = vget_low_u8(input_val_u8);
      uint8x8_t high = vget_high_u8(input_val_u8);
      input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
      input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
      low = vget_low_u8(filter_val_u8_0);
      high = vget_high_u8(filter_val_u8_0);
      int16x8_t filter_val_0_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      int16x8_t filter_val_0_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      filter_val_0_0 = vaddq_s16(filter_val_0_0, filter_offset_vec);
      filter_val_0_1 = vaddq_s16(filter_val_0_1, filter_offset_vec);
      low = vget_low_u8(filter_val_u8_1);
      high = vget_high_u8(filter_val_u8_1);
      int16x8_t filter_val_1_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      int16x8_t filter_val_1_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      filter_val_1_0 = vaddq_s16(filter_val_1_0, filter_offset_vec);
      filter_val_1_1 = vaddq_s16(filter_val_1_1, filter_offset_vec);
      low = vget_low_u8(filter_val_u8_2);
      high = vget_high_u8(filter_val_u8_2);
      int16x8_t filter_val_2_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      int16x8_t filter_val_2_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      filter_val_2_0 = vaddq_s16(filter_val_2_0, filter_offset_vec);
      filter_val_2_1 = vaddq_s16(filter_val_2_1, filter_offset_vec);
      low = vget_low_u8(filter_val_u8_3);
      high = vget_high_u8(filter_val_u8_3);
      int16x8_t filter_val_3_0 = vreinterpretq_s16_u16(vmovl_u8(low));
      int16x8_t filter_val_3_1 = vreinterpretq_s16_u16(vmovl_u8(high));
      filter_val_3_0 = vaddq_s16(filter_val_3_0, filter_offset_vec);
      filter_val_3_1 = vaddq_s16(filter_val_3_1, filter_offset_vec);
      acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_0),
                       vget_low_s16(input_val_0));
      acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_0),
                       vget_low_s16(input_val_0));
      acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_0),
                       vget_low_s16(input_val_0));
      acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_0),
                       vget_low_s16(input_val_0));
      acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_1),
                       vget_low_s16(input_val_1));
      acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_1),
                       vget_low_s16(input_val_1));
      acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_1),
                       vget_low_s16(input_val_1));
      acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_1),
                       vget_low_s16(input_val_1));
      acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_0),
                       vget_high_s16(input_val_0));
      acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_0),
                       vget_high_s16(input_val_0));
      acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_0),
                       vget_high_s16(input_val_0));
      acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_0),
                       vget_high_s16(input_val_0));
      acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_1),
                       vget_high_s16(input_val_1));
      acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_1),
                       vget_high_s16(input_val_1));
      acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_1),
                       vget_high_s16(input_val_1));
      acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_1),
                       vget_high_s16(input_val_1));
    }
    for (; in <= input_size - 8; in += 8) {
      const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
      const uint8* filter_ptr = filter_data + in + out * input_size;
      uint8x8_t filter_val_u8_0 = vld1_u8(filter_ptr);
      filter_ptr += input_size;
      uint8x8_t filter_val_u8_1 = vld1_u8(filter_ptr);
      filter_ptr += input_size;
      uint8x8_t filter_val_u8_2 = vld1_u8(filter_ptr);
      filter_ptr += input_size;
      uint8x8_t filter_val_u8_3 = vld1_u8(filter_ptr);
      int16x8_t input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
      input_val = vaddq_s16(input_val, input_offset_vec);
      int16x8_t filter_val_0 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_0));
      filter_val_0 = vaddq_s16(filter_val_0, filter_offset_vec);
      int16x8_t filter_val_1 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_1));
      filter_val_1 = vaddq_s16(filter_val_1, filter_offset_vec);
      int16x8_t filter_val_2 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_2));
      filter_val_2 = vaddq_s16(filter_val_2, filter_offset_vec);
      int16x8_t filter_val_3 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_3));
      filter_val_3 = vaddq_s16(filter_val_3, filter_offset_vec);
      acc0 =
          vmlal_s16(acc0, vget_low_s16(filter_val_0), vget_low_s16(input_val));
      acc1 =
          vmlal_s16(acc1, vget_low_s16(filter_val_1), vget_low_s16(input_val));
      acc2 =
          vmlal_s16(acc2, vget_low_s16(filter_val_2), vget_low_s16(input_val));
      acc3 =
          vmlal_s16(acc3, vget_low_s16(filter_val_3), vget_low_s16(input_val));
      acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
                       vget_high_s16(input_val));
      acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
                       vget_high_s16(input_val));
      acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
                       vget_high_s16(input_val));
      acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
                       vget_high_s16(input_val));
    }
    if (in < input_size) {
      int32 buf[16];
      vst1q_s32(buf + 0, acc0);
      vst1q_s32(buf + 4, acc1);
      vst1q_s32(buf + 8, acc2);
      vst1q_s32(buf + 12, acc3);
      for (; in < input_size; in++) {
        int lane = (in + 8 - input_size) % 4;
        const int32 input_val = input_data[in] + input_offset;
        for (int k = 0; k < kPeel; k++) {
          int32 filter_val =
              filter_data[in + (out + k) * input_size] + filter_offset;
          buf[lane + 4 * k] += filter_val * input_val;
        }
      }
      acc0 = vld1q_s32(buf + 0);
      acc1 = vld1q_s32(buf + 4);
      acc2 = vld1q_s32(buf + 8);
      acc3 = vld1q_s32(buf + 12);
    }

    // Horizontally reduce accumulators
    int32x2_t pairwise_reduced_acc_0 =
        vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
    int32x2_t pairwise_reduced_acc_1 =
        vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
    int32x2_t pairwise_reduced_acc_2 =
        vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
    int32x2_t pairwise_reduced_acc_3 =
        vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
    const int32x2_t reduced_lo =
        vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
    const int32x2_t reduced_hi =
        vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
    int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
    // Add bias values.
    int32x4_t bias_vec = vld1q_s32(bias_data + out);
    reduced = vaddq_s32(reduced, bias_vec);
    if (shift_left) {
      const int32 multiplier_power_of_two = 1 << output_shift;
      reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
      reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
    } else {
      // Multiply by the fixed-point multiplier.
      reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
      // Rounding-shift-right.
      using gemmlowp::RoundingDivideByPOT;
      reduced = RoundingDivideByPOT(reduced, -output_shift);
    }
    // Add the output offset.
    const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
    reduced = vaddq_s32(reduced, output_offset_vec);
    // Narrow values down to 16 bit signed.
    const int16x4_t res16 = vqmovn_s32(reduced);
    // Narrow values down to 8 bit unsigned, saturating.
    uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16));
    // Apply the clamping from the activation function
    res8 = vmax_u8(res8, vdup_n_u8(output_activation_min));
    res8 = vmin_u8(res8, vdup_n_u8(output_activation_max));
    // Store results to destination.
    vst1_lane_u8(output_data + out + 0, res8, 0);
    vst1_lane_u8(output_data + out + 1, res8, 1);
    vst1_lane_u8(output_data + out + 2, res8, 2);
    vst1_lane_u8(output_data + out + 3, res8, 3);
  }
}

struct LegacyFullyConnectedAsGEMVWorkerTask : public gemmlowp::Task {
  LegacyFullyConnectedAsGEMVWorkerTask(
      const RuntimeShape& input_shape, const uint8* input_data,
      int32 input_offset, const RuntimeShape& filter_shape,
      const uint8* filter_data, int32 filter_offset,
      const RuntimeShape& bias_shape, const int32* bias_data,
      int32 output_offset, int32 output_multiplier, int output_shift,
      int32 output_activation_min, int32 output_activation_max,
      const RuntimeShape& output_shape, uint8* output_data, int row_start,
      int row_end)
      : input_shape_(input_shape),
        input_data_(input_data),
        input_offset_(input_offset),
        filter_shape_(filter_shape),
        filter_data_(filter_data),
        filter_offset_(filter_offset),
        bias_shape_(bias_shape),
        bias_data_(bias_data),
        output_offset_(output_offset),
        output_multiplier_(output_multiplier),
        output_shift_(output_shift),
        output_activation_min_(output_activation_min),
        output_activation_max_(output_activation_max),
        output_shape_(output_shape),
        output_data_(output_data),
        row_start_(row_start),
        row_end_(row_end) {}

  void Run() override {
    LegacyFullyConnectedAsGEMVWorkerImpl(
        input_shape_, input_data_, input_offset_, filter_shape_, filter_data_,
        filter_offset_, bias_shape_, bias_data_, output_offset_,
        output_multiplier_, output_shift_, output_activation_min_,
        output_activation_max_, output_shape_, output_data_, row_start_,
        row_end_);
  }

  const RuntimeShape& input_shape_;
  const uint8* input_data_;
  int32 input_offset_;
  const RuntimeShape& filter_shape_;
  const uint8* filter_data_;
  int32 filter_offset_;
  const RuntimeShape& bias_shape_;
  const int32* bias_data_;
  int32 output_offset_;
  int32 output_multiplier_;
  int output_shift_;
  int32 output_activation_min_;
  int32 output_activation_max_;
  const RuntimeShape& output_shape_;
  uint8* output_data_;
  int row_start_;
  int row_end_;
};

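// Batch-size-1 uint8 FullyConnected as a GEMV. Uses the gemmlowp thread-count
// heuristic to decide whether to run the worker kernel directly on the
// current thread or to split the output rows across
// LegacyFullyConnectedAsGEMVWorkerTasks.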
inline void FullyConnectedAsGEMV(
    const RuntimeShape& input_shape, const uint8* input_data,
    int32 input_offset, const RuntimeShape& filter_shape,
    const uint8* filter_data, int32 filter_offset,
    const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
    int32 output_multiplier, int output_shift, int32 output_activation_min,
    int32 output_activation_max, const RuntimeShape& output_shape,
    uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
  const int output_dim_count = output_shape.DimensionsCount();
  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
  const int output_rows = output_shape.Dims(output_dim_count - 1);
  const int input_size = FlatSizeSkipDim(input_shape, 0);
  static constexpr int kKernelRows = 4;
  const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
      gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
  if (thread_count == 1) {
    // Single-thread case: do the computation on the current thread, don't
    // use a threadpool
    LegacyFullyConnectedAsGEMVWorkerImpl(
        input_shape, input_data, input_offset, filter_shape, filter_data,
        filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
        output_shift, output_activation_min, output_activation_max,
        output_shape, output_data, 0, output_rows);
    return;
  }

  // Multi-threaded case: use the gemmlowp context's threadpool.
  TFLITE_DCHECK_GT(thread_count, 1);
  std::vector<gemmlowp::Task*> tasks(thread_count);
  const int kRowsPerWorker = gemmlowp::RoundUp<kKernelRows>(
      gemmlowp::CeilQuotient(output_rows, thread_count));
  int row_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    int row_end = std::min(output_rows, row_start + kRowsPerWorker);
    tasks[i] = new LegacyFullyConnectedAsGEMVWorkerTask(
        input_shape, input_data, input_offset, filter_shape, filter_data,
        filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
        output_shift, output_activation_min, output_activation_max,
        output_shape, output_data, row_start, row_end);
    row_start = row_end;
  }
  TFLITE_DCHECK_EQ(row_start, output_rows);
  gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
}
#endif  // USE_NEON

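// Quantized uint8 FullyConnected using gemmlowp. On NEON builds, batch size 1
// with at least 4 output rows takes the GEMV fast path above; otherwise it
// runs a gemmlowp GEMM with the bias/requantize/clamp output pipeline.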
FullyConnected(const FullyConnectedParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & filter_shape,const uint8 * filter_data,const RuntimeShape & bias_shape,const int32 * bias_data,const RuntimeShape & output_shape,uint8 * output_data,gemmlowp::GemmContext * gemmlowp_context)1090 inline void FullyConnected(
1091 const FullyConnectedParams& params, const RuntimeShape& input_shape,
1092 const uint8* input_data, const RuntimeShape& filter_shape,
1093 const uint8* filter_data, const RuntimeShape& bias_shape,
1094 const int32* bias_data, const RuntimeShape& output_shape,
1095 uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) {
1096 ruy::profiler::ScopeLabel label("FullyConnected/8bit");
1097 const int32 input_offset = params.input_offset;
1098 const int32 filter_offset = params.weights_offset;
1099 const int32 output_offset = params.output_offset;
1100 const int32 output_multiplier = params.output_multiplier;
1101 const int output_shift = params.output_shift;
1102 const int32 output_activation_min = params.quantized_activation_min;
1103 const int32 output_activation_max = params.quantized_activation_max;
1104 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1105 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1106 // TODO(b/62193649): This really should be:
1107 // const int batches = ArraySize(output_dims, 1);
1108 // but the current --variable_batch hack consists in overwriting the 3rd
1109 // dimension with the runtime batch size, as we don't keep track for each
1110 // array of which dimension is the batch dimension in it.
1111 const int output_dim_count = output_shape.DimensionsCount();
1112 const int filter_dim_count = filter_shape.DimensionsCount();
1113 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1114 #ifdef USE_NEON
1115 if (batches == 1) {
1116 const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
1117 output_shape, output_dim_count - 1);
1118 if (output_size >= 4) {
1119 return FullyConnectedAsGEMV(
1120 input_shape, input_data, input_offset, filter_shape, filter_data,
1121 filter_offset, bias_shape, bias_data, output_offset,
1122 output_multiplier, output_shift, output_activation_min,
1123 output_activation_max, output_shape, output_data, gemmlowp_context);
1124 }
1125 }
1126 #endif // USE_NEON
1127 const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
1128 const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
1129 TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
1130 const int output_rows = output_shape.Dims(output_dim_count - 1);
1131 TFLITE_DCHECK_EQ(output_rows, filter_rows);
1132 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1133
1134 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
1135 filter_data, output_rows, filter_cols, filter_cols);
1136 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
1137 input_data, filter_cols, batches, filter_cols);
1138 gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
1139 output_data, output_rows, batches, output_rows);
1140 const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
1141 bias_data, output_rows, output_offset, output_multiplier, output_shift,
1142 output_activation_min, output_activation_max);
1143 gemmlowp::GemmWithOutputPipeline<uint8, uint8,
1144 gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
1145 gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
1146 filter_offset, input_offset, output_pipeline);
1147 }
1148
1149 #ifdef GEMMLOWP_NEON
1150 // In the common case of batch size 1, a fully-connected node degenerates
1151 // to a matrix*vector product. LSTM cells contain a fully-connected node;
1152 // when quantized, this becomes a special type of GEMV operation in which
1153 // the output is 16-bit quantized and thus needs its own special path.
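// Roughly, and ignoring the vectorization details below, the kernel computes,
// for each output row r:
//   acc(r)    = sum_i (weights[r, i] - weights_zero_point) * (input[i] - 128)
//   output[r] = sat_int16(round((acc(r) + bias[r]) * accum_multiplier
//                               * 2^accum_shift / 2^31))
// i.e. a plain integer GEMV followed by the usual fixed-point rescale.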
1154 inline void GEMVForLstmCell(const RuntimeShape& input_shape,
1155 const uint8* input_data,
1156 const RuntimeShape& weights_shape,
1157 const uint8* weights_data, uint8 weights_zero_point,
1158 const RuntimeShape& bias_shape,
1159 const int32* bias_data, int32 accum_multiplier,
1160 int accum_shift, const RuntimeShape& output_shape,
1161 int16* output_data) {
1162 ruy::profiler::ScopeLabel label("GEMVForLstmCell");
1163 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
1164 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
1165 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1166 const int output_dim_count = output_shape.DimensionsCount();
1167 const int weights_dim_count = weights_shape.DimensionsCount();
1168 TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
1169 const int input_size = FlatSizeSkipDim(input_shape, 0);
1170 const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
1171 output_shape, output_dim_count - 1);
1172   // This special fast path for quantized LSTM cells does not try to support
1173   // odd sizes that we haven't encountered in any LSTM cell; those would
1174   // require special code (which would go untested until some LSTM cell
1175   // exercises it). We just guard our assumptions about size evenness with
1176   // the following assertions.
1177 TFLITE_DCHECK(!(output_size % 4));
1178 TFLITE_DCHECK(!(input_size % 8));
1179 const int32* bias_ptr = bias_data;
1180 int16* output_ptr = output_data;
1181 for (int out = 0; out < output_size; out += 4) {
1182 int32x4_t acc_0 = vdupq_n_s32(0);
1183 int32x4_t acc_1 = vdupq_n_s32(0);
1184 int32x4_t acc_2 = vdupq_n_s32(0);
1185 int32x4_t acc_3 = vdupq_n_s32(0);
1186 const int16x8_t input_offset_vec = vdupq_n_s16(-128);
1187 const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point);
1188 int in = 0;
1189 // Handle 16 levels of depth at a time.
1190 for (; in <= input_size - 16; in += 16) {
1191 const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
1192 const uint8* weights_ptr = weights_data + in + out * input_size;
1193 uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size);
1194 uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size);
1195 uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size);
1196 uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size);
1197 int16x8_t input_val_0, input_val_1;
1198 const uint8x8_t low = vget_low_u8(input_val_u8);
1199 const uint8x8_t high = vget_high_u8(input_val_u8);
1200 input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
1201 input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
1202 input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
1203 input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
1204 int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0,
1205 weights_val_3_0;
1206 int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1,
1207 weights_val_3_1;
1208 weights_val_0_0 = vaddq_s16(
1209 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))),
1210 weights_offset_vec);
1211 weights_val_0_1 = vaddq_s16(
1212 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))),
1213 weights_offset_vec);
1214 weights_val_1_0 = vaddq_s16(
1215 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))),
1216 weights_offset_vec);
1217 weights_val_1_1 = vaddq_s16(
1218 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))),
1219 weights_offset_vec);
1220 weights_val_2_0 = vaddq_s16(
1221 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))),
1222 weights_offset_vec);
1223 weights_val_2_1 = vaddq_s16(
1224 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))),
1225 weights_offset_vec);
1226 weights_val_3_0 = vaddq_s16(
1227 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))),
1228 weights_offset_vec);
1229 weights_val_3_1 = vaddq_s16(
1230 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))),
1231 weights_offset_vec);
1232 acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0),
1233 vget_low_s16(input_val_0));
1234 acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0),
1235 vget_low_s16(input_val_0));
1236 acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0),
1237 vget_low_s16(input_val_0));
1238 acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0),
1239 vget_low_s16(input_val_0));
1240 acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0),
1241 vget_high_s16(input_val_0));
1242 acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0),
1243 vget_high_s16(input_val_0));
1244 acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0),
1245 vget_high_s16(input_val_0));
1246 acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0),
1247 vget_high_s16(input_val_0));
1248 acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1),
1249 vget_low_s16(input_val_1));
1250 acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1),
1251 vget_low_s16(input_val_1));
1252 acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1),
1253 vget_low_s16(input_val_1));
1254 acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1),
1255 vget_low_s16(input_val_1));
1256 acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1),
1257 vget_high_s16(input_val_1));
1258 acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1),
1259 vget_high_s16(input_val_1));
1260 acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1),
1261 vget_high_s16(input_val_1));
1262 acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1),
1263 vget_high_s16(input_val_1));
1264 }
1265 // Handle 8 levels of depth at a time.
1266 for (; in < input_size; in += 8) {
1267 const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
1268 const uint8* weights_ptr = weights_data + in + out * input_size;
1269 uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size);
1270 uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size);
1271 uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size);
1272 uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size);
1273 int16x8_t input_val;
1274 input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
1275 input_val = vaddq_s16(input_val, input_offset_vec);
1276 int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3;
1277 weights_val_0 =
1278 vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)),
1279 weights_offset_vec);
1280 weights_val_1 =
1281 vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)),
1282 weights_offset_vec);
1283 weights_val_2 =
1284 vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)),
1285 weights_offset_vec);
1286 weights_val_3 =
1287 vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)),
1288 weights_offset_vec);
1289 acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0),
1290 vget_low_s16(input_val));
1291 acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1),
1292 vget_low_s16(input_val));
1293 acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2),
1294 vget_low_s16(input_val));
1295 acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3),
1296 vget_low_s16(input_val));
1297 acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0),
1298 vget_high_s16(input_val));
1299 acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1),
1300 vget_high_s16(input_val));
1301 acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2),
1302 vget_high_s16(input_val));
1303 acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3),
1304 vget_high_s16(input_val));
1305 }
1306 // Horizontally reduce accumulators
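// For instance, if acc_0 holds lanes [a0, a1, a2, a3], the first vpadd below
// yields [a0 + a1, a2 + a3] and the second stage completes the row sum, so
// `reduced` ends up holding one 32-bit accumulator per output row.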
1307 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
1308 pairwise_reduced_acc_2, pairwise_reduced_acc_3;
1309 pairwise_reduced_acc_0 =
1310 vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
1311 pairwise_reduced_acc_1 =
1312 vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
1313 pairwise_reduced_acc_2 =
1314 vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
1315 pairwise_reduced_acc_3 =
1316 vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
1317 const int32x2_t reduced_lo =
1318 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
1319 const int32x2_t reduced_hi =
1320 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1321 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1322 // Add bias values.
1323 int32x4_t bias_vec = vld1q_s32(bias_ptr);
1324 bias_ptr += 4;
1325 reduced = vaddq_s32(reduced, bias_vec);
1326 int left_shift = accum_shift > 0 ? accum_shift : 0;
1327 int right_shift = accum_shift > 0 ? 0 : -accum_shift;
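// The shift/multiply/shift sequence below amounts, approximately, to
//   reduced = round(reduced * accum_multiplier * 2^accum_shift / 2^31).
// For example (hypothetical value), accum_shift = -3 gives left_shift = 0 and
// right_shift = 3: scale by accum_multiplier / 2^31, then rounding-shift
// right by 3.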
1328 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
1329 // Multiply by the fixed-point multiplier.
1330 reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
1331 // Rounding-shift-right.
1332 using gemmlowp::RoundingDivideByPOT;
1333 reduced = RoundingDivideByPOT(reduced, right_shift);
1334 // Narrow values down to 16 bit signed.
1335 const int16x4_t res16 = vqmovn_s32(reduced);
1336 vst1_s16(output_ptr, res16);
1337 output_ptr += 4;
1338 }
1339 }
1340 #endif
1341
1342 #ifdef GEMMLOWP_NEON
1343 inline void GEMVForLstmCellWithSymmetricRange(
1344 const RuntimeShape& input_shape, const uint8* input_data,
1345 const RuntimeShape& weights_shape, const uint8* weights_data,
1346 const RuntimeShape& bias_shape, const int32* bias_data,
1347 int32 accum_multiplier, int accum_shift, const RuntimeShape& output_shape,
1348 int16* output_data) {
1349 ruy::profiler::ScopeLabel label("GEMVForLstmCellWithSymmetricRange");
1350 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
1351 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
1352 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1353 const int output_dim_count = output_shape.DimensionsCount();
1354 const int weights_dim_count = weights_shape.DimensionsCount();
1355 TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
1356 const int input_size = FlatSizeSkipDim(input_shape, 0);
1357 const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
1358 output_shape, output_dim_count - 1);
1359   // This special fast path for quantized LSTM cells does not try to support
1360   // odd sizes that we haven't encountered in any LSTM cell; those would
1361   // require special code (which would go untested until some LSTM cell
1362   // exercises it). We just guard our assumptions about size evenness with
1363   // the following assertions.
1364 TFLITE_DCHECK(!(output_size % 4));
1365 TFLITE_DCHECK(!(input_size % 64));
1366 const int32* bias_ptr = bias_data;
1367 int16* output_ptr = output_data;
1368 const uint8x16_t signbit = vdupq_n_u8(0x80);
1369 for (int in = 0; in < input_size; in += 32) {
1370 optimized_ops_preload_l1_keep(input_data + in);
1371 }
1372 const int left_shift = accum_shift > 0 ? accum_shift : 0;
1373 const int right_shift = accum_shift > 0 ? 0 : -accum_shift;
1374 for (int out = 0; out < output_size; out += 4) {
1375 // Load the bias values
1376 int32x4_t bias_vec = vld1q_s32(bias_ptr);
1377 bias_ptr += 4;
1378
1379 // Clear accumulators. We use 2 accumulator registers per row,
1380 // for 4 rows. row_accumRN is the N-th accumulator for row R.
1381 int32x4_t row_accum00 = vdupq_n_s32(0);
1382 int32x4_t row_accum01 = vdupq_n_s32(0);
1383 int32x4_t row_accum10 = vdupq_n_s32(0);
1384 int32x4_t row_accum11 = vdupq_n_s32(0);
1385 int32x4_t row_accum20 = vdupq_n_s32(0);
1386 int32x4_t row_accum21 = vdupq_n_s32(0);
1387 int32x4_t row_accum30 = vdupq_n_s32(0);
1388 int32x4_t row_accum31 = vdupq_n_s32(0);
1389
1390 // kReadAhead parametrizes how far ahead we prefetch weights into L1 cache.
1391 const int kReadAhead = 512;
1392 // Prefetch the first weights values.
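// With kReadAhead == 512 and 64-byte steps, this issues 8 prefetches per row,
// i.e. the first 512 bytes of each of the 4 rows processed below.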
1393 for (int k = 0; k < kReadAhead; k += 64) {
1394 optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
1395 k);
1396 optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
1397 k);
1398 optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
1399 k);
1400 optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
1401 k);
1402 }
1403     // Loop along the rows, handling 64 bytes per iteration because that's the
1404     // cache line size on most current ARM-architecture CPUs.
1405 for (int in = 0; in < input_size; in += 64) {
1406 // Prefetch some future weights values.
1407 optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
1408 in + kReadAhead);
1409 optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
1410 in + kReadAhead);
1411 optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
1412 in + kReadAhead);
1413 optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
1414 in + kReadAhead);
1415
1416 // We will use 2 local 16-bit accumulators per row, for 2 rows.
1417 // See below (*) for the rationale of processing only 2 rows at a time.
1418 // local_accumRN is the N-th local accumulator for row R.
1419 int16x8_t local_accum00;
1420 int16x8_t local_accum01;
1421 int16x8_t local_accum10;
1422 int16x8_t local_accum11;
1423
1424 // Load 64 bytes of input activations values. Convert to signed int8
1425 // by flipping the sign bit (i.e. subtracting 128, the required
1426 // zero_point value).
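// For example, the uint8 value 200 (0xC8) becomes 0xC8 ^ 0x80 = 0x48, which
// reinterpreted as int8 is 72 == 200 - 128.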
1427 int8x16_t input0 = vreinterpretq_s8_u8(
1428 veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 0)));
1429 int8x16_t input1 = vreinterpretq_s8_u8(
1430 veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 1)));
1431 int8x16_t input2 = vreinterpretq_s8_u8(
1432 veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 2)));
1433 int8x16_t input3 = vreinterpretq_s8_u8(
1434 veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 3)));
1435
1436 // Beginning of the core accumulation. Notice how while we have 4
1437 // rows to process, this code is taking care of only 2 rows at a time,
1438 // thus being divided into two parts looking similar ("Rows 0 and 1" and
1439 // "Rows 2 and 3").
1440 //
1441 // (*) The rationale for handling only 2 rows at a time is to avoid
1442 // cache aliasing issues on 4-way set-associative L1-cache CPUs, such
1443 // as Cortex-A53. With sufficiently large, power-of-two matrix dimensions,
1444 // we may find ourselves in a situation where rows alias each other in
1445 // the L1 cache, and moreover may also mutually alias with the input
1446 // activations. If we try to load 4 rows at a time, together with the
1447 // input activations, that may be 5 mutually-aliasing vectors, resulting
1448 // in constant mutual eviction from L1 cache. Handling 2 rows at a time
1449 // here largely mitigates these issues, and seems at least to be very
1450 // effective on Cortex-A53:
1451 // Before After
1452 // big (Cortex-A73) 2.85 ms 2.85 ms
1453 // little (Cortex-A53) 11.0 ms 5.16 ms
1454
1455 // Rows 0 and 1:
1456 // Load 64 bytes of weights values from each row. Convert to signed int8
1457 // by flipping the sign bit (i.e. subtracting 128, the required
1458 // zero_point value).
1459 int8x16_t weights00 = vreinterpretq_s8_u8(veorq_u8(
1460 signbit,
1461 vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 0)));
1462 int8x16_t weights01 = vreinterpretq_s8_u8(veorq_u8(
1463 signbit,
1464 vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 1)));
1465 int8x16_t weights02 = vreinterpretq_s8_u8(veorq_u8(
1466 signbit,
1467 vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 2)));
1468 int8x16_t weights03 = vreinterpretq_s8_u8(veorq_u8(
1469 signbit,
1470 vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 3)));
1471 int8x16_t weights10 = vreinterpretq_s8_u8(veorq_u8(
1472 signbit,
1473 vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 0)));
1474 int8x16_t weights11 = vreinterpretq_s8_u8(veorq_u8(
1475 signbit,
1476 vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 1)));
1477 int8x16_t weights12 = vreinterpretq_s8_u8(veorq_u8(
1478 signbit,
1479 vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 2)));
1480 int8x16_t weights13 = vreinterpretq_s8_u8(veorq_u8(
1481 signbit,
1482 vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 3)));
1483 // Multiply-accumulate into local 16-bit accumulators.
1484 // We can accumulate two products without overflow because weights are
1485 // required to never be -128, so each product is at most 127^2 in absolute
1486 // value.
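// (Concretely: with weights in [-127, 127] and inputs in [-128, 127], the
// worst case is 2 * 127 * 128 = 32512, which still fits in int16.)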
1487 local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
1488 local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
1489 local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
1490 local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
1491 local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
1492 vget_high_s8(input0));
1493 local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
1494 vget_high_s8(input1));
1495 local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
1496 vget_high_s8(input0));
1497 local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
1498 vget_high_s8(input1));
1499 // Pairwise add and accumulate into 32-bit accumulators
1500 row_accum00 = vpadalq_s16(row_accum00, local_accum00);
1501 row_accum01 = vpadalq_s16(row_accum01, local_accum01);
1502 row_accum10 = vpadalq_s16(row_accum10, local_accum10);
1503 row_accum11 = vpadalq_s16(row_accum11, local_accum11);
1504 // Multiply-accumulate into local 16-bit accumulators.
1505 // We can accumulate two products without overflow because weights are
1506 // required to never be -128, so each product is at most 127^2 in absolute
1507 // value.
1508 local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
1509 local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
1510 local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
1511 local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
1512 local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
1513 vget_high_s8(input2));
1514 local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
1515 vget_high_s8(input3));
1516 local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
1517 vget_high_s8(input2));
1518 local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
1519 vget_high_s8(input3));
1520 // Pairwise add and accumulate into 32-bit accumulators
1521 row_accum00 = vpadalq_s16(row_accum00, local_accum00);
1522 row_accum01 = vpadalq_s16(row_accum01, local_accum01);
1523 row_accum10 = vpadalq_s16(row_accum10, local_accum10);
1524 row_accum11 = vpadalq_s16(row_accum11, local_accum11);
1525
1526 // Rows 2 and 3:
1527 // Load 64 bytes of weights values from each row. Convert to signed int8
1528 // by flipping the sign bit (i.e. subtracting 128, the required
1529 // zero_point value).
1530 weights00 = vreinterpretq_s8_u8(veorq_u8(
1531 signbit,
1532 vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 0)));
1533 weights01 = vreinterpretq_s8_u8(veorq_u8(
1534 signbit,
1535 vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 1)));
1536 weights02 = vreinterpretq_s8_u8(veorq_u8(
1537 signbit,
1538 vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 2)));
1539 weights03 = vreinterpretq_s8_u8(veorq_u8(
1540 signbit,
1541 vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 3)));
1542 weights10 = vreinterpretq_s8_u8(veorq_u8(
1543 signbit,
1544 vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 0)));
1545 weights11 = vreinterpretq_s8_u8(veorq_u8(
1546 signbit,
1547 vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 1)));
1548 weights12 = vreinterpretq_s8_u8(veorq_u8(
1549 signbit,
1550 vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 2)));
1551 weights13 = vreinterpretq_s8_u8(veorq_u8(
1552 signbit,
1553 vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 3)));
1554 // Multiply-accumulate into local 16-bit accumulators.
1555 // We can accumulate two products without overflow because weights are
1556 // required to never be -128, so each product is at most 127^2 in absolute
1557 // value.
1558 local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
1559 local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
1560 local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
1561 local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
1562 local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
1563 vget_high_s8(input0));
1564 local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
1565 vget_high_s8(input1));
1566 local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
1567 vget_high_s8(input0));
1568 local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
1569 vget_high_s8(input1));
1570 // Pairwise add and accumulate into 32-bit accumulators
1571 row_accum20 = vpadalq_s16(row_accum20, local_accum00);
1572 row_accum21 = vpadalq_s16(row_accum21, local_accum01);
1573 row_accum30 = vpadalq_s16(row_accum30, local_accum10);
1574 row_accum31 = vpadalq_s16(row_accum31, local_accum11);
1575 // Multiply-accumulate into local 16-bit accumulators.
1576 // We can accumulate two products without overflow because weights are
1577 // required to never be -128, so each product is at most 127^2 in absolute
1578 // value.
1579 local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
1580 local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
1581 local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
1582 local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
1583 local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
1584 vget_high_s8(input2));
1585 local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
1586 vget_high_s8(input3));
1587 local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
1588 vget_high_s8(input2));
1589 local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
1590 vget_high_s8(input3));
1591 // Pairwise add and accumulate into 32-bit accumulators
1592 row_accum20 = vpadalq_s16(row_accum20, local_accum00);
1593 row_accum21 = vpadalq_s16(row_accum21, local_accum01);
1594 row_accum30 = vpadalq_s16(row_accum30, local_accum10);
1595 row_accum31 = vpadalq_s16(row_accum31, local_accum11);
1596 }
1597
1598 row_accum00 = vaddq_s32(row_accum00, row_accum01);
1599 row_accum10 = vaddq_s32(row_accum10, row_accum11);
1600 row_accum20 = vaddq_s32(row_accum20, row_accum21);
1601 row_accum30 = vaddq_s32(row_accum30, row_accum31);
1602 // Horizontally reduce accumulators
1603 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
1604 pairwise_reduced_acc_2, pairwise_reduced_acc_3;
1605 pairwise_reduced_acc_0 =
1606 vpadd_s32(vget_low_s32(row_accum00), vget_high_s32(row_accum00));
1607 pairwise_reduced_acc_1 =
1608 vpadd_s32(vget_low_s32(row_accum10), vget_high_s32(row_accum10));
1609 pairwise_reduced_acc_2 =
1610 vpadd_s32(vget_low_s32(row_accum20), vget_high_s32(row_accum20));
1611 pairwise_reduced_acc_3 =
1612 vpadd_s32(vget_low_s32(row_accum30), vget_high_s32(row_accum30));
1613 const int32x2_t reduced_lo =
1614 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
1615 const int32x2_t reduced_hi =
1616 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1617 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1618 // Add bias values.
1619 reduced = vaddq_s32(reduced, bias_vec);
1620 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
1621 // Multiply by the fixed-point multiplier.
1622 reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
1623 // Rounding-shift-right.
1624 using gemmlowp::RoundingDivideByPOT;
1625 reduced = RoundingDivideByPOT(reduced, right_shift);
1626 // Narrow values down to 16 bit signed.
1627 const int16x4_t res16 = vqmovn_s32(reduced);
1628 vst1_s16(output_ptr, res16);
1629 output_ptr += 4;
1630 }
1631 }
1632 #endif
1633
1634 inline void FullyConnected(
1635 const FullyConnectedParams& params, const RuntimeShape& input_shape,
1636 const uint8* input_data, const RuntimeShape& filter_shape,
1637 const uint8* filter_data, const RuntimeShape& bias_shape,
1638 const int32* bias_data_int32, const RuntimeShape& output_shape,
1639 int16* output_data, gemmlowp::GemmContext* gemmlowp_context) {
1640 ruy::profiler::ScopeLabel label("FullyConnected/Uint8Int16");
1641 const int32 input_offset = params.input_offset;
1642 const int32 filter_offset = params.weights_offset;
1643 const int32 output_offset = params.output_offset;
1644 const int32 output_multiplier = params.output_multiplier;
1645 const int output_shift = params.output_shift;
1646 const int32 output_activation_min = params.quantized_activation_min;
1647 const int32 output_activation_max = params.quantized_activation_max;
1648 // This is a copy of the reference implementation. We do not currently have a
1649 // properly optimized version.
1650 (void)gemmlowp_context; // only used in properly optimized code.
1651 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1652 TFLITE_DCHECK_EQ(output_offset, 0);
1653 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1654 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1655
1656 // TODO(b/62193649): This really should be:
1657 // const int batches = ArraySize(output_dims, 1);
1658 // but the current --variable_batch hack consists in overwriting the 3rd
1659 // dimension with the runtime batch size, as we don't keep track for each
1660 // array of which dimension is the batch dimension in it.
1661 const int output_dim_count = output_shape.DimensionsCount();
1662 const int filter_dim_count = filter_shape.DimensionsCount();
1663 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1664 const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
1665 output_shape, output_dim_count - 1);
1666 const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
1667
1668 // Implementation of the fully connected node suited to the inside of an LSTM
1669 // cell. The operands are 8-bit integers, the accumulators are internally
1670 // 32-bit integers, and the output is 16-bit fixed-point with 3 integer bits,
1671 // so the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
1672 // is explained in the function comment above.
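// Concretely: 1 sign bit, 3 integer bits and 12 fractional bits, so the int16
// value 4096 represents 4096 / 2^12 = 1.0 and 32767 represents just under 8.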
1673 #ifdef GEMMLOWP_NEON
1674 if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
1675 output_activation_max == 32767) {
1676 if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
1677 GEMVForLstmCellWithSymmetricRange(
1678 input_shape, input_data, filter_shape, filter_data, bias_shape,
1679 bias_data_int32, output_multiplier, output_shift, output_shape,
1680 output_data);
1681 return;
1682 }
1683 if (!(output_depth % 4) && !(accum_depth % 8)) {
1684 GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
1685 filter_offset, bias_shape, bias_data_int32,
1686 output_multiplier, output_shift, output_shape,
1687 output_data);
1688 return;
1689 }
1690 }
1691 #endif
1692 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> weights_matrix(
1693 filter_data, output_depth, accum_depth);
1694 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
1695 input_data, accum_depth, batches);
1696 gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
1697 output_data, output_depth, batches);
1698 typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
1699 ColVectorMap;
1700 ColVectorMap bias_vector(bias_data_int32, output_depth);
1701 gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
1702 bias_addition_stage.bias_vector = bias_vector;
1703 gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
1704 scale_stage.result_offset_after_shift = 0;
1705 scale_stage.result_fixedpoint_multiplier = output_multiplier;
1706 // Note that this shift is negated wrt ordinary FC.
1707 scale_stage.result_exponent = output_shift;
1708 gemmlowp::OutputStageClamp clamp_stage;
1709 clamp_stage.min = output_activation_min;
1710 clamp_stage.max = output_activation_max;
1711 gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
1712 auto output_pipeline =
1713 std::make_tuple(bias_addition_stage, scale_stage, clamp_stage,
1714 saturating_cast_int16_stage);
1715 gemmlowp::GemmWithOutputPipeline<uint8, int16,
1716 gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
1717 gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
1718 filter_offset, input_offset, output_pipeline);
1719 }
1720
1721 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
1722 int32 input_offset, const uint8* filter_data,
1723 const Dims<4>& filter_dims, int32 filter_offset,
1724 const int32* bias_data, const Dims<4>& bias_dims,
1725 int32 output_offset, int32 output_multiplier,
1726 int output_shift, int32 output_activation_min,
1727 int32 output_activation_max, uint8* output_data,
1728 const Dims<4>& output_dims,
1729 gemmlowp::GemmContext* gemmlowp_context) {
1730 tflite::FullyConnectedParams op_params;
1731 op_params.input_offset = input_offset;
1732 op_params.weights_offset = filter_offset;
1733 op_params.output_offset = output_offset;
1734 op_params.output_multiplier = output_multiplier;
1735 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
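// kReverseShift flips the sign, so e.g. a legacy right shift of 3 is stored
// here as output_shift = -3.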
1736 op_params.output_shift = kReverseShift * output_shift;
1737 op_params.quantized_activation_min = output_activation_min;
1738 op_params.quantized_activation_max = output_activation_max;
1739
1740 FullyConnected(op_params, DimsToShape(input_dims), input_data,
1741 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
1742 bias_data, DimsToShape(output_dims), output_data,
1743 gemmlowp_context);
1744 }
1745
1746 inline void FullyConnected(
1747 const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
1748 const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
1749 const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
1750 int32 output_multiplier, int output_shift, int32 output_activation_min,
1751 int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
1752 gemmlowp::GemmContext* gemmlowp_context) {
1753 tflite::FullyConnectedParams op_params;
1754 op_params.input_offset = input_offset;
1755 op_params.weights_offset = filter_offset;
1756 op_params.output_offset = output_offset;
1757 op_params.output_multiplier = output_multiplier;
1758 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1759 op_params.output_shift = kReverseShift * output_shift;
1760 op_params.quantized_activation_min = output_activation_min;
1761 op_params.quantized_activation_max = output_activation_max;
1762
1763 FullyConnected(op_params, DimsToShape(input_dims), input_data,
1764 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
1765 bias_data_int32, DimsToShape(output_dims), output_data,
1766 gemmlowp_context);
1767 }
1768
1769 // legacy, for compatibility with old checked-in code
1770 template <FusedActivationFunctionType Ac>
1771 void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
1772 int32 input_offset, const uint8* filter_data,
1773 const Dims<4>& filter_dims, int32 filter_offset,
1774 const int32* bias_data, const Dims<4>& bias_dims,
1775 int32 output_offset, int32 output_multiplier,
1776 int output_shift, int32 output_activation_min,
1777 int32 output_activation_max, uint8* output_data,
1778 const Dims<4>& output_dims,
1779 gemmlowp::GemmContext* gemmlowp_context) {
1780 static_assert(Ac == FusedActivationFunctionType::kNone ||
1781 Ac == FusedActivationFunctionType::kRelu ||
1782 Ac == FusedActivationFunctionType::kRelu6 ||
1783 Ac == FusedActivationFunctionType::kRelu1,
1784 "");
1785 FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
1786 filter_offset, bias_data, bias_dims, output_offset,
1787 output_multiplier, output_shift, output_activation_min,
1788 output_activation_max, output_data, output_dims,
1789 gemmlowp_context);
1790 }
1791
1792 #ifdef USE_NEON
1793 inline void LegacyInt8FullyConnectedAsGEMVWorkerImpl(
1794 const RuntimeShape& input_shape, const int8_t* input_data,
1795 int32 input_offset, const RuntimeShape& filter_shape,
1796 const int8_t* filter_data, int32 filter_offset,
1797 const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
1798 int32 output_multiplier, int output_shift, int32 output_activation_min,
1799 int32 output_activation_max, const RuntimeShape& output_shape,
1800 int8_t* output_data, int row_start, int row_end) {
1801 ruy::profiler::ScopeLabel label("FullyConnectedAsGEMVInt8/8bit");
1802 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
1803 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1804 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1805 const int output_dim_count = output_shape.DimensionsCount();
1806 TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
1807 const int input_size = FlatSizeSkipDim(input_shape, 0);
1808 static constexpr int kPeel = 4;
1809 const bool shift_left = (output_shift > 0);
1810 TFLITE_DCHECK_GE(row_end - row_start, kPeel);
1811
1812 for (int out = row_start; out < row_end; out += kPeel) {
1813 out = std::min(out, row_end - kPeel);
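// The std::min above makes the last iteration overlap the previous one when
// (row_end - row_start) is not a multiple of kPeel: e.g. with rows [0, 10)
// the loop handles rows 0-3, 4-7, then 6-9, recomputing rows 6 and 7 rather
// than running past row_end.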
1814 int32x4_t acc0 = vdupq_n_s32(0);
1815 int32x4_t acc1 = acc0;
1816 int32x4_t acc2 = acc0;
1817 int32x4_t acc3 = acc0;
1818 const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
1819 const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
1820 int in = 0;
1821 for (; in <= input_size - 16; in += 16) {
1822 const int8x16_t input_val_s8 = vld1q_s8(input_data + in);
1823 const int8_t* filter_ptr = filter_data + in + out * input_size;
1824 int8x16_t filter_val_s8_0 = vld1q_s8(filter_ptr);
1825 filter_ptr += input_size;
1826 int8x16_t filter_val_s8_1 = vld1q_s8(filter_ptr);
1827 filter_ptr += input_size;
1828 int8x16_t filter_val_s8_2 = vld1q_s8(filter_ptr);
1829 filter_ptr += input_size;
1830 int8x16_t filter_val_s8_3 = vld1q_s8(filter_ptr);
1831 int16x8_t input_val_0, input_val_1;
1832 int8x8_t low = vget_low_s8(input_val_s8);
1833 int8x8_t high = vget_high_s8(input_val_s8);
1834 input_val_0 = vmovl_s8(low);
1835 input_val_1 = vmovl_s8(high);
1836 input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
1837 input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
1838 low = vget_low_s8(filter_val_s8_0);
1839 high = vget_high_s8(filter_val_s8_0);
1840 int16x8_t filter_val_0_0 = vmovl_s8(low);
1841 int16x8_t filter_val_0_1 = vmovl_s8(high);
1842 filter_val_0_0 = vaddq_s16(filter_val_0_0, filter_offset_vec);
1843 filter_val_0_1 = vaddq_s16(filter_val_0_1, filter_offset_vec);
1844 low = vget_low_s8(filter_val_s8_1);
1845 high = vget_high_s8(filter_val_s8_1);
1846 int16x8_t filter_val_1_0 = vmovl_s8(low);
1847 int16x8_t filter_val_1_1 = vmovl_s8(high);
1848 filter_val_1_0 = vaddq_s16(filter_val_1_0, filter_offset_vec);
1849 filter_val_1_1 = vaddq_s16(filter_val_1_1, filter_offset_vec);
1850 low = vget_low_s8(filter_val_s8_2);
1851 high = vget_high_s8(filter_val_s8_2);
1852 int16x8_t filter_val_2_0 = vmovl_s8(low);
1853 int16x8_t filter_val_2_1 = vmovl_s8(high);
1854 filter_val_2_0 = vaddq_s16(filter_val_2_0, filter_offset_vec);
1855 filter_val_2_1 = vaddq_s16(filter_val_2_1, filter_offset_vec);
1856 low = vget_low_s8(filter_val_s8_3);
1857 high = vget_high_s8(filter_val_s8_3);
1858 int16x8_t filter_val_3_0 = vmovl_s8(low);
1859 int16x8_t filter_val_3_1 = vmovl_s8(high);
1860 filter_val_3_0 = vaddq_s16(filter_val_3_0, filter_offset_vec);
1861 filter_val_3_1 = vaddq_s16(filter_val_3_1, filter_offset_vec);
1862 acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_0),
1863 vget_low_s16(input_val_0));
1864 acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_0),
1865 vget_low_s16(input_val_0));
1866 acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_0),
1867 vget_low_s16(input_val_0));
1868 acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_0),
1869 vget_low_s16(input_val_0));
1870 acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_1),
1871 vget_low_s16(input_val_1));
1872 acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_1),
1873 vget_low_s16(input_val_1));
1874 acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_1),
1875 vget_low_s16(input_val_1));
1876 acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_1),
1877 vget_low_s16(input_val_1));
1878 acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_0),
1879 vget_high_s16(input_val_0));
1880 acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_0),
1881 vget_high_s16(input_val_0));
1882 acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_0),
1883 vget_high_s16(input_val_0));
1884 acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_0),
1885 vget_high_s16(input_val_0));
1886 acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_1),
1887 vget_high_s16(input_val_1));
1888 acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_1),
1889 vget_high_s16(input_val_1));
1890 acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_1),
1891 vget_high_s16(input_val_1));
1892 acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_1),
1893 vget_high_s16(input_val_1));
1894 }
1895 for (; in <= input_size - 8; in += 8) {
1896 const int8x8_t input_val_s8 = vld1_s8(input_data + in);
1897 const int8_t* filter_ptr = filter_data + in + out * input_size;
1898 int8x8_t filter_val_s8_0 = vld1_s8(filter_ptr);
1899 filter_ptr += input_size;
1900 int8x8_t filter_val_s8_1 = vld1_s8(filter_ptr);
1901 filter_ptr += input_size;
1902 int8x8_t filter_val_s8_2 = vld1_s8(filter_ptr);
1903 filter_ptr += input_size;
1904 int8x8_t filter_val_s8_3 = vld1_s8(filter_ptr);
1905 int16x8_t input_val = vmovl_s8(input_val_s8);
1906 input_val = vaddq_s16(input_val, input_offset_vec);
1907 int16x8_t filter_val_0 = vmovl_s8(filter_val_s8_0);
1908 filter_val_0 = vaddq_s16(filter_val_0, filter_offset_vec);
1909 int16x8_t filter_val_1 = vmovl_s8(filter_val_s8_1);
1910 filter_val_1 = vaddq_s16(filter_val_1, filter_offset_vec);
1911 int16x8_t filter_val_2 = vmovl_s8(filter_val_s8_2);
1912 filter_val_2 = vaddq_s16(filter_val_2, filter_offset_vec);
1913 int16x8_t filter_val_3 = vmovl_s8(filter_val_s8_3);
1914 filter_val_3 = vaddq_s16(filter_val_3, filter_offset_vec);
1915 acc0 =
1916 vmlal_s16(acc0, vget_low_s16(filter_val_0), vget_low_s16(input_val));
1917 acc1 =
1918 vmlal_s16(acc1, vget_low_s16(filter_val_1), vget_low_s16(input_val));
1919 acc2 =
1920 vmlal_s16(acc2, vget_low_s16(filter_val_2), vget_low_s16(input_val));
1921 acc3 =
1922 vmlal_s16(acc3, vget_low_s16(filter_val_3), vget_low_s16(input_val));
1923 acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
1924 vget_high_s16(input_val));
1925 acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
1926 vget_high_s16(input_val));
1927 acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
1928 vget_high_s16(input_val));
1929 acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
1930 vget_high_s16(input_val));
1931 }
1932 if (in < input_size) {
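// Fewer than 8 depth values remain: spill the vector accumulators to a
// scratch buffer and finish in scalar code. Which buf lane receives each
// leftover product does not matter, since every lane of each accumulator is
// summed by the horizontal reduction below.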
1933 int32 buf[16];
1934 vst1q_s32(buf + 0, acc0);
1935 vst1q_s32(buf + 4, acc1);
1936 vst1q_s32(buf + 8, acc2);
1937 vst1q_s32(buf + 12, acc3);
1938 for (; in < input_size; in++) {
1939 int lane = (in + 8 - input_size) % 4;
1940 const int32 input_val = input_data[in] + input_offset;
1941 for (int k = 0; k < kPeel; k++) {
1942 int32 filter_val =
1943 filter_data[in + (out + k) * input_size] + filter_offset;
1944 buf[lane + 4 * k] += filter_val * input_val;
1945 }
1946 }
1947 acc0 = vld1q_s32(buf + 0);
1948 acc1 = vld1q_s32(buf + 4);
1949 acc2 = vld1q_s32(buf + 8);
1950 acc3 = vld1q_s32(buf + 12);
1951 }
1952
1953 // Horizontally reduce accumulators
1954 int32x2_t pairwise_reduced_acc_0 =
1955 vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
1956 int32x2_t pairwise_reduced_acc_1 =
1957 vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
1958 int32x2_t pairwise_reduced_acc_2 =
1959 vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
1960 int32x2_t pairwise_reduced_acc_3 =
1961 vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
1962 const int32x2_t reduced_lo =
1963 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
1964 const int32x2_t reduced_hi =
1965 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1966 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1967 // Add bias values.
1968 int32x4_t bias_vec = vld1q_s32(bias_data + out);
1969 reduced = vaddq_s32(reduced, bias_vec);
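// Both branches below implement the usual quantized-multiplier rescale,
// approximately reduced * output_multiplier * 2^output_shift / 2^31 with
// rounding; they differ only in whether the power-of-two factor is applied
// as a plain multiply (shift_left) or as a rounding right shift.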
1970 if (shift_left) {
1971 const int32 multiplier_power_of_two = 1 << output_shift;
1972 reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
1973 reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1974 } else {
1975 // Multiply by the fixed-point multiplier.
1976 reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1977 // Rounding-shift-right.
1978 using gemmlowp::RoundingDivideByPOT;
1979 reduced = RoundingDivideByPOT(reduced, -output_shift);
1980 }
1981 // Add the output offset.
1982 const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
1983 reduced = vaddq_s32(reduced, output_offset_vec);
1984 // Narrow values down to 16 bit signed.
1985 const int16x4_t res16 = vqmovn_s32(reduced);
1986 // Narrow values down to 8 bit signed, saturating.
1987 int8x8_t res8 = vqmovn_s16(vcombine_s16(res16, res16));
1988 // Apply the clamping from the activation function
1989 res8 = vmax_s8(res8, vdup_n_s8(output_activation_min));
1990 res8 = vmin_s8(res8, vdup_n_s8(output_activation_max));
1991 // Store results to destination.
1992 vst1_lane_s8(output_data + out + 0, res8, 0);
1993 vst1_lane_s8(output_data + out + 1, res8, 1);
1994 vst1_lane_s8(output_data + out + 2, res8, 2);
1995 vst1_lane_s8(output_data + out + 3, res8, 3);
1996 }
1997 }
1998
1999 struct LegacyInt8FullyConnectedAsGEMVWorkerTask : public gemmlowp::Task {
2000   LegacyInt8FullyConnectedAsGEMVWorkerTask(
2001 const RuntimeShape& input_shape, const int8_t* input_data,
2002 int32 input_offset, const RuntimeShape& filter_shape,
2003 const int8_t* filter_data, int32 filter_offset,
2004 const RuntimeShape& bias_shape, const int32* bias_data,
2005 int32 output_offset, int32 output_multiplier, int output_shift,
2006 int32 output_activation_min, int32 output_activation_max,
2007 const RuntimeShape& output_shape, int8_t* output_data, int row_start,
2008 int row_end)
2009 : input_shape_(input_shape),
2010 input_data_(input_data),
2011 input_offset_(input_offset),
2012 filter_shape_(filter_shape),
2013 filter_data_(filter_data),
2014 filter_offset_(filter_offset),
2015 bias_shape_(bias_shape),
2016 bias_data_(bias_data),
2017 output_offset_(output_offset),
2018 output_multiplier_(output_multiplier),
2019 output_shift_(output_shift),
2020 output_activation_min_(output_activation_min),
2021 output_activation_max_(output_activation_max),
2022 output_shape_(output_shape),
2023 output_data_(output_data),
2024 row_start_(row_start),
2025 row_end_(row_end) {}
2026
2027   void Run() override {
2028 LegacyInt8FullyConnectedAsGEMVWorkerImpl(
2029 input_shape_, input_data_, input_offset_, filter_shape_, filter_data_,
2030 filter_offset_, bias_shape_, bias_data_, output_offset_,
2031 output_multiplier_, output_shift_, output_activation_min_,
2032 output_activation_max_, output_shape_, output_data_, row_start_,
2033 row_end_);
2034 }
2035
2036 const RuntimeShape& input_shape_;
2037 const int8_t* input_data_;
2038 int32 input_offset_;
2039 const RuntimeShape& filter_shape_;
2040 const int8_t* filter_data_;
2041 int32 filter_offset_;
2042 const RuntimeShape& bias_shape_;
2043 const int32* bias_data_;
2044 int32 output_offset_;
2045 int32 output_multiplier_;
2046 int output_shift_;
2047 int32 output_activation_min_;
2048 int32 output_activation_max_;
2049 const RuntimeShape& output_shape_;
2050 int8_t* output_data_;
2051 int row_start_;
2052 int row_end_;
2053 };
2054
2055 inline void LegacyInt8FullyConnectedAsGEMV(
2056 const RuntimeShape& input_shape, const int8_t* input_data,
2057 int32 input_offset, const RuntimeShape& filter_shape,
2058 const int8_t* filter_data, int32 filter_offset,
2059 const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
2060 int32 output_multiplier, int output_shift, int32 output_activation_min,
2061 int32 output_activation_max, const RuntimeShape& output_shape,
2062 int8_t* output_data, gemmlowp::GemmContext* gemmlowp_context) {
2063 const int output_dim_count = output_shape.DimensionsCount();
2064 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
2065 const int output_rows = output_shape.Dims(output_dim_count - 1);
2066 const int input_size = FlatSizeSkipDim(input_shape, 0);
2067 static constexpr int kKernelRows = 4;
2068 const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
2069 gemmlowp_context->max_num_threads(), output_rows, batches, input_size);
2070 if (thread_count == 1) {
2071    // Single-thread case: do the computation on the current thread; don't
2072    // use a threadpool.
2073 LegacyInt8FullyConnectedAsGEMVWorkerImpl(
2074 input_shape, input_data, input_offset, filter_shape, filter_data,
2075 filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
2076 output_shift, output_activation_min, output_activation_max,
2077 output_shape, output_data, 0, output_rows);
2078 return;
2079 }
2080
2081 // Multi-threaded case: use the gemmlowp context's threadpool.
2082 TFLITE_DCHECK_GT(thread_count, 1);
2083 std::vector<LegacyInt8FullyConnectedAsGEMVWorkerTask> tasks;
2084 // TODO(b/131746020) don't create new heap allocations every time.
2085 // At least we make it a single heap allocation by using reserve().
2086 tasks.reserve(thread_count);
2087 const int kRowsPerWorker = gemmlowp::RoundUp<kKernelRows>(
2088 gemmlowp::CeilQuotient(output_rows, thread_count));
2089 int row_start = 0;
2090 for (int i = 0; i < thread_count; ++i) {
2091 int row_end = std::min(output_rows, row_start + kRowsPerWorker);
2092 tasks.emplace_back(input_shape, input_data, input_offset, filter_shape,
2093 filter_data, filter_offset, bias_shape, bias_data,
2094 output_offset, output_multiplier, output_shift,
2095 output_activation_min, output_activation_max,
2096 output_shape, output_data, row_start, row_end);
2097 row_start = row_end;
2098 }
2099 TFLITE_DCHECK_EQ(row_start, output_rows);
2100 gemmlowp_context->workers_pool()->Execute(tasks.size(), tasks.data());
2101 }
2102 #endif // USE_NEON
2103
2104 inline void FullyConnected(
2105 const FullyConnectedParams& params, const RuntimeShape& input_shape,
2106 const int8* input_data, const RuntimeShape& filter_shape,
2107 const int8* filter_data, const RuntimeShape& bias_shape,
2108 const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
2109 gemmlowp::GemmContext* gemmlowp_context) {
2110 ruy::profiler::ScopeLabel label("FullyConnectedInt8/8bit");
2111
2112 #ifdef USE_NEON
2113 const int32 input_offset = params.input_offset;
2114 const int32 filter_offset = params.weights_offset;
2115 const int32 output_offset = params.output_offset;
2116 const int32 output_multiplier = params.output_multiplier;
2117 const int output_shift = params.output_shift;
2118 const int32 output_activation_min = params.quantized_activation_min;
2119 const int32 output_activation_max = params.quantized_activation_max;
2120 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
2121 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
2122 // TODO(b/62193649): This really should be:
2123 // const int batches = ArraySize(output_dims, 1);
2124 // but the current --variable_batch hack consists in overwriting the 3rd
2125 // dimension with the runtime batch size, as we don't keep track for each
2126 // array of which dimension is the batch dimension in it.
2127 const int output_dim_count = output_shape.DimensionsCount();
2128 const int filter_dim_count = filter_shape.DimensionsCount();
2129 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
2130 if (batches == 1) {
2131 const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
2132 output_shape, output_dim_count - 1);
2133 if (output_size >= 4) {
2134 return LegacyInt8FullyConnectedAsGEMV(
2135 input_shape, input_data, input_offset, filter_shape, filter_data,
2136 filter_offset, bias_shape, bias_data, output_offset,
2137 output_multiplier, output_shift, output_activation_min,
2138 output_activation_max, output_shape, output_data, gemmlowp_context);
2139 }
2140 }
2141 #endif // USE_NEON
2142
2143 #ifdef GEMMLOWP_NEON
2144 const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
2145 const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
2146 TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
2147 const int output_rows = output_shape.Dims(output_dim_count - 1);
2148 TFLITE_DCHECK_EQ(output_rows, filter_rows);
2149 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
2150
2151 gemmlowp::MatrixMap<const int8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2152 filter_data, output_rows, filter_cols, filter_cols);
2153 gemmlowp::MatrixMap<const int8, gemmlowp::MapOrder::ColMajor> input_matrix(
2154 input_data, filter_cols, batches, filter_cols);
2155 gemmlowp::MatrixMap<int8, gemmlowp::MapOrder::ColMajor> output_matrix(
2156 output_data, output_rows, batches, output_rows);
2157 const auto& output_pipeline = GemmlowpOutputPipelineInt8::MakeExp(
2158 bias_data, output_rows, output_offset, output_multiplier, output_shift,
2159 output_activation_min, output_activation_max);
2160
2161 gemmlowp::GemmWithOutputPipeline<
2162 int8, int8, gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>(
2163 gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
2164 filter_offset, input_offset, output_pipeline);
2165 return;
2166 #endif // GEMMLOWP_NEON
2167
2168   // If both the USE_NEON and GEMMLOWP_NEON paths above are skipped, fall back
2169   // to the reference implementation.
2170 reference_integer_ops::FullyConnected(params, input_shape, input_data,
2171 filter_shape, filter_data, bias_shape,
2172 bias_data, output_shape, output_data);
2173 }
2174
2175 struct LegacyShuffledFullyConnectedWorkerTask : gemmlowp::Task {
2176   LegacyShuffledFullyConnectedWorkerTask(const uint8* input_data,
2177 const int8* shuffled_weights_data,
2178 int batches, int output_depth,
2179 int output_stride, int accum_depth,
2180 const int32* bias_data,
2181 int32 output_multiplier,
2182 int output_shift, int16* output_data)
2183 : input_data_(input_data),
2184 shuffled_weights_data_(shuffled_weights_data),
2185 batches_(batches),
2186 output_depth_(output_depth),
2187 output_stride_(output_stride),
2188 accum_depth_(accum_depth),
2189 bias_data_(bias_data),
2190 output_multiplier_(output_multiplier),
2191 output_shift_(output_shift),
2192 output_data_(output_data) {}
2193
2194   void Run() override {
2195 ShuffledFullyConnectedWorkerImpl(
2196 input_data_, shuffled_weights_data_, batches_, output_depth_,
2197 output_stride_, accum_depth_, bias_data_, output_multiplier_,
2198 output_shift_, output_data_);
2199 }
2200
2201 const uint8* input_data_;
2202 const int8* shuffled_weights_data_;
2203 int batches_;
2204 int output_depth_;
2205 int output_stride_;
2206 int accum_depth_;
2207 const int32* bias_data_;
2208 int32 output_multiplier_;
2209 int output_shift_;
2210 int16* output_data_;
2211 };
2212
2213 inline void ShuffledFullyConnected(
2214 const FullyConnectedParams& params, const RuntimeShape& input_shape,
2215 const uint8* input_data, const RuntimeShape& weights_shape,
2216 const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
2217 const int32* bias_data, const RuntimeShape& output_shape,
2218 int16* output_data, uint8* shuffled_input_workspace_data,
2219 gemmlowp::GemmContext* gemmlowp_context) {
2220 ruy::profiler::ScopeLabel label("ShuffledFullyConnected/8bit");
2221 const int32 output_multiplier = params.output_multiplier;
2222 const int output_shift = params.output_shift;
2223 const int32 output_activation_min = params.quantized_activation_min;
2224 const int32 output_activation_max = params.quantized_activation_max;
2225 (void)gemmlowp_context; // only used in optimized code.
2226 TFLITE_DCHECK_EQ(output_activation_min, -32768);
2227 TFLITE_DCHECK_EQ(output_activation_max, 32767);
2228 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
2229 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2230 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
2231 // TODO(b/62193649): This really should be:
2232 // const int batches = ArraySize(output_dims, 1);
2233   // but the current --variable_batch hack consists of overwriting the 3rd
2234   // dimension with the runtime batch size, as we don't keep track, for each
2235   // array, of which of its dimensions is the batch dimension.
2236 const int output_dim_count = output_shape.DimensionsCount();
2237 const int weights_dim_count = weights_shape.DimensionsCount();
2238 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
2239 const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
2240 output_shape, output_dim_count - 1);
2241 const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
2242 TFLITE_DCHECK((accum_depth % 16) == 0);
2243 TFLITE_DCHECK((output_depth % 4) == 0);
2244 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
2245 // so that just reinterpreting them as int8 values is equivalent to
2246 // subtracting 128 from them, thus implementing for free the subtraction of
2247 // the zero_point value 128.
2248 const int8* int8_shuffled_weights_data =
2249 reinterpret_cast<const int8*>(shuffled_weights_data);
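  // For illustration: a uint8 value of 0 becomes 0 ^ 0x80 == 0x80, which
  // reinterpreted as int8 is -128 == 0 - 128; likewise 255 ^ 0x80 == 0x7F,
  // i.e. 127 == 255 - 128.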
2250
2251 // Shuffling and xoring of input activations into the workspace buffer
2252 if (batches == 1) {
2253 #ifdef USE_NEON
2254 const uint8x16_t signbit = vdupq_n_u8(0x80);
2255 for (int i = 0; i < accum_depth; i += 16) {
2256 uint8x16_t val = vld1q_u8(input_data + i);
2257 val = veorq_u8(val, signbit);
2258 vst1q_u8(shuffled_input_workspace_data + i, val);
2259 }
2260 #else
2261 for (int i = 0; i < accum_depth; i++) {
2262 shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
2263 }
2264 #endif
2265 } else if (batches == 4) {
2266 uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
2267 int c = 0;
2268 #ifdef USE_NEON
2269 const uint8x16_t signbit = vdupq_n_u8(0x80);
2270 for (c = 0; c < accum_depth; c += 16) {
2271 const uint8* src_data_ptr = input_data + c;
2272 uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
2273 uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
2274 uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
2275 uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
2276 val0 = veorq_u8(val0, signbit);
2277 val1 = veorq_u8(val1, signbit);
2278 val2 = veorq_u8(val2, signbit);
2279 val3 = veorq_u8(val3, signbit);
2280 vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
2281 vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
2282 vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
2283 vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
2284 shuffled_input_workspace_ptr += 64;
2285 }
2286 #else
2287 for (c = 0; c < accum_depth; c += 16) {
2288 for (int b = 0; b < 4; b++) {
2289 const uint8* src_data_ptr = input_data + b * accum_depth + c;
2290 for (int j = 0; j < 16; j++) {
2291 uint8 src_val = *src_data_ptr++;
2292 // Flip the sign bit, so that the kernel will only need to
2293 // reinterpret these uint8 values as int8, getting for free the
2294 // subtraction of the zero_point value 128.
2295 uint8 dst_val = src_val ^ 0x80;
2296 *shuffled_input_workspace_ptr++ = dst_val;
2297 }
2298 }
2299 }
2300 #endif
2301 } else {
2302 TFLITE_DCHECK(false);
2303 return;
2304 }
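  // Note: only batch sizes 1 and 4 are handled by the shuffled kernels above;
  // any other batch size trips the DCHECK and is a caller error.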
2305
2306 static constexpr int kKernelRows = 4;
2307 const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
2308 gemmlowp_context->max_num_threads(), output_depth, batches, accum_depth);
2309 if (thread_count == 1) {
2310     // Single-thread case: do the computation on the current thread rather
2311     // than dispatching to a threadpool.
2312 ShuffledFullyConnectedWorkerImpl(
2313 shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
2314 output_depth, output_depth, accum_depth, bias_data, output_multiplier,
2315 output_shift, output_data);
2316 return;
2317 }
2318
2319 // Multi-threaded case: use the gemmlowp context's threadpool.
2320 TFLITE_DCHECK_GT(thread_count, 1);
2321 std::vector<gemmlowp::Task*> tasks(thread_count);
2322 const int kRowsPerWorker = gemmlowp::RoundUp<kKernelRows>(
2323 gemmlowp::CeilQuotient(output_depth, thread_count));
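  // For illustration (hypothetical sizes): with output_depth == 1000 and
  // thread_count == 4, CeilQuotient yields 250 rows per worker, rounded up to
  // 252 so that each worker's block is a multiple of kKernelRows (4); the last
  // worker then covers the remaining 244 rows.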
2324 int row_start = 0;
2325 for (int i = 0; i < thread_count; i++) {
2326 int row_end = std::min(output_depth, row_start + kRowsPerWorker);
2327 tasks[i] = new LegacyShuffledFullyConnectedWorkerTask(
2328 shuffled_input_workspace_data,
2329 int8_shuffled_weights_data + row_start * accum_depth, batches,
2330 row_end - row_start, output_depth, accum_depth, bias_data + row_start,
2331 output_multiplier, output_shift, output_data + row_start);
2332 row_start = row_end;
2333 }
2334 TFLITE_DCHECK_EQ(row_start, output_depth);
2335 gemmlowp_context->workers_pool()->LegacyExecuteAndDestroyTasks(tasks);
2336 }
2337
2338 inline void ShuffledFullyConnected(
2339 const uint8* input_data, const Dims<4>& input_dims,
2340 const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
2341 const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
2342 int output_shift, int32 output_activation_min, int32 output_activation_max,
2343 int16* output_data, const Dims<4>& output_dims,
2344 uint8* shuffled_input_workspace_data,
2345 gemmlowp::GemmContext* gemmlowp_context) {
2346 tflite::FullyConnectedParams op_params;
2347 op_params.output_multiplier = output_multiplier;
2348 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
2349 op_params.output_shift = kReverseShift * output_shift;
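  // For example, a legacy right shift of 3 becomes output_shift == -3 here,
  // since kReverseShift == -1.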
2350 op_params.quantized_activation_min = output_activation_min;
2351 op_params.quantized_activation_max = output_activation_max;
2352
2353 ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
2354 DimsToShape(weights_dims), shuffled_weights_data,
2355 DimsToShape(bias_dims), bias_data,
2356 DimsToShape(output_dims), output_data,
2357 shuffled_input_workspace_data, gemmlowp_context);
2358 }
2359
2360 template <typename T>
2361 inline void ExtractPatchIntoBufferColumn(
2362 const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
2363 int stride_width, int stride_height, int pad_width, int pad_height,
2364 int in_width, int in_height, int in_depth, int single_buffer_length,
2365 int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
2366 ExtractPatchIntoBufferColumn(
2367 DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
2368 stride_height, pad_width, pad_height, in_width, in_height, in_depth,
2369 single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
2370 }
2371
2372 template <typename T>
2373 void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
2374 const Dims<4>& filter_dims, int stride_width,
2375 int stride_height, int dilation_width_factor,
2376 int dilation_height_factor, int pad_width, int pad_height,
2377 const Dims<4>& output_dims, uint8 zero_byte,
2378 T* im2col_data) {
2379 tflite::ConvParams op_params;
2380 // Padding type is ignored, but still set.
2381 op_params.padding_type = PaddingType::kSame;
2382 op_params.padding_values.width = pad_width;
2383 op_params.padding_values.height = pad_height;
2384 op_params.stride_width = stride_width;
2385 op_params.stride_height = stride_height;
2386 op_params.dilation_width_factor = dilation_width_factor;
2387 op_params.dilation_height_factor = dilation_height_factor;
2388
2389 DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
2390 DimsToShape(filter_dims), DimsToShape(output_dims),
2391 im2col_data);
2392 }
2393
2394 template <typename T>
2395 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
2396 int stride_height, int pad_width, int pad_height, int kheight,
2397 int kwidth, uint8 zero_byte, T* output_data,
2398 const Dims<4>& output_dims) {
2399 tflite::ConvParams op_params;
2400 // Padding type is ignored, but still set.
2401 op_params.padding_type = PaddingType::kSame;
2402 op_params.padding_values.width = pad_width;
2403 op_params.padding_values.height = pad_height;
2404 op_params.stride_width = stride_width;
2405 op_params.stride_height = stride_height;
2406 op_params.dilation_width_factor = 1;
2407 op_params.dilation_height_factor = 1;
2408
2409 Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
2410 input_data, DimsToShape(output_dims), output_data);
2411 }
2412
2413 // legacy, for compatibility with old checked-in code
2414 template <typename T>
2415 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
2416 int pad_width, int pad_height, int kheight, int kwidth,
2417 uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
2418 Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
2419 kwidth, zero_byte, output_data, output_dims);
2420 }
2421
2422 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
2423 const float* input_data, const RuntimeShape& filter_shape,
2424 const float* filter_data, const RuntimeShape& bias_shape,
2425 const float* bias_data, const RuntimeShape& output_shape,
2426 float* output_data, const RuntimeShape& im2col_shape,
2427 float* im2col_data) {
2428 const int stride_width = params.stride_width;
2429 const int stride_height = params.stride_height;
2430 const int dilation_width_factor = params.dilation_width_factor;
2431 const int dilation_height_factor = params.dilation_height_factor;
2432 const float output_activation_min = params.float_activation_min;
2433 const float output_activation_max = params.float_activation_max;
2434 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2435 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2436 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2437
2438 (void)im2col_data;
2439 (void)im2col_shape;
2440 ruy::profiler::ScopeLabel label("Conv");
2441
2442 // NB: the float 0.0f value is represented by all zero bytes.
2443 const uint8 float_zero_byte = 0x00;
2444 const float* gemm_input_data = nullptr;
2445 const RuntimeShape* gemm_input_shape = nullptr;
2446 const int filter_width = filter_shape.Dims(2);
2447 const int filter_height = filter_shape.Dims(1);
2448 const bool need_dilated_im2col =
2449 dilation_width_factor != 1 || dilation_height_factor != 1;
2450 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2451 filter_width != 1 || filter_height != 1;
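  // A 1x1 filter with unit strides is already a plain matrix multiplication
  // over the channel dimension, so in that case the input tensor is fed to the
  // GEMM directly without materializing an im2col buffer.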
2452 if (need_dilated_im2col) {
2453 DilatedIm2col(params, float_zero_byte, input_shape, input_data,
2454 filter_shape, output_shape, im2col_data);
2455 gemm_input_data = im2col_data;
2456 gemm_input_shape = &im2col_shape;
2457 } else if (need_im2col) {
2458 TFLITE_DCHECK(im2col_data);
2459 Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
2460 input_data, im2col_shape, im2col_data);
2461 gemm_input_data = im2col_data;
2462 gemm_input_shape = &im2col_shape;
2463 } else {
2464 // TODO(aselle): We need to make sure to not send im2col if it is not
2465 // needed.
2466 TFLITE_DCHECK(!im2col_data);
2467 gemm_input_data = input_data;
2468 gemm_input_shape = &input_shape;
2469 }
2470
2471   // The following code computes matrix multiplication c = a * transpose(b)
2472 // with CBLAS, where:
2473 // * `a` is a matrix with dimensions (m, k).
2474 // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
2475 // * `c` is a matrix with dimensions (m, n).
2476   // The variable naming is aligned with the CBLAS specification here.
2477 const float* a = gemm_input_data;
2478 const float* b = filter_data;
2479 float* c = output_data;
2480 const int gemm_input_dims = gemm_input_shape->DimensionsCount();
2481 int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
2482 int n = output_shape.Dims(3);
2483 int k = gemm_input_shape->Dims(gemm_input_dims - 1);
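  // Concretely (a sketch of how the shapes map onto the GEMM): m is the number
  // of output spatial positions times batches, n is the output depth, and k is
  // filter_height * filter_width * input_depth when an im2col buffer is used
  // (or just input_depth in the 1x1 case).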
2484
2485 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
2486 // The stride of matrix a, b and c respectively.
2487 int stride_a = k;
2488 int stride_b = k;
2489 int stride_c = n;
2490
2491 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
2492 stride_a, b, stride_b, 0.0f, c, stride_c);
2493 #else
2494 // When an optimized CBLAS implementation is not available, fall back
2495 // to using Eigen.
2496 typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
2497 Matrix;
2498 typedef Eigen::Map<Matrix> MatrixRef;
2499 typedef Eigen::Map<const Matrix> ConstMatrixRef;
2500
2501 MatrixRef matrix_c(c, m, n);
2502 ConstMatrixRef matrix_a(a, m, k);
2503 ConstMatrixRef matrix_b(b, n, k);
2504
2505 // The following special casing for when a or b is a vector is required
2506   // as Eigen seems to fail to make this optimization on its own.
2507 if (n == 1) {
2508 ruy::profiler::ScopeLabel label("GEMV");
2509 matrix_c.col(0).noalias() = matrix_a * matrix_b.row(0).transpose();
2510 } else if (m == 1) {
2511 ruy::profiler::ScopeLabel label("GEMV");
2512 matrix_c.row(0).noalias() = matrix_a.row(0) * matrix_b.transpose();
2513 } else {
2514 ruy::profiler::ScopeLabel label("GEMM");
2515 matrix_c.noalias() = matrix_a * matrix_b.transpose();
2516 }
2517
2518 #endif // defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
2519
2520 optimized_ops::AddBiasAndEvalActivationFunction(
2521 output_activation_min, output_activation_max, bias_shape, bias_data,
2522 output_shape, output_data);
2523 }
2524
2525 inline void Conv(const float* input_data, const Dims<4>& input_dims,
2526 const float* filter_data, const Dims<4>& filter_dims,
2527 const float* bias_data, const Dims<4>& bias_dims,
2528 int stride_width, int stride_height, int dilation_width_factor,
2529 int dilation_height_factor, int pad_width, int pad_height,
2530 float output_activation_min, float output_activation_max,
2531 float* output_data, const Dims<4>& output_dims,
2532 float* im2col_data, const Dims<4>& im2col_dims) {
2533 tflite::ConvParams op_params;
2534 // Padding type is ignored, but still set.
2535 op_params.padding_type = PaddingType::kSame;
2536 op_params.padding_values.width = pad_width;
2537 op_params.padding_values.height = pad_height;
2538 op_params.stride_width = stride_width;
2539 op_params.stride_height = stride_height;
2540 op_params.dilation_width_factor = dilation_width_factor;
2541 op_params.dilation_height_factor = dilation_height_factor;
2542 op_params.float_activation_min = output_activation_min;
2543 op_params.float_activation_max = output_activation_max;
2544
2545 Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
2546 filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
2547 output_data, DimsToShape(im2col_dims), im2col_data);
2548 }
2549
2550 inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
2551 const int8_t* filter_data, const Dims<4>& filter_dims,
2552 const float* bias_data, const Dims<4>& bias_dims,
2553 int stride_width, int stride_height, int pad_width,
2554 int pad_height, float* scaling_factors_ptr,
2555 float output_activation_min, float output_activation_max,
2556 int32_t* scratch_data, const Dims<4>& scratch_dims,
2557 float* output_data, const Dims<4>& output_dims,
2558 int8_t* im2col_data, const Dims<4>& im2col_dims,
2559 CpuBackendContext* context) {
2560 tflite::ConvParams op_params;
2561 // Padding type is ignored, but still set.
2562 op_params.padding_type = PaddingType::kSame;
2563 op_params.padding_values.width = pad_width;
2564 op_params.padding_values.height = pad_height;
2565 op_params.stride_width = stride_width;
2566 op_params.stride_height = stride_height;
2567 op_params.float_activation_min = output_activation_min;
2568 op_params.float_activation_max = output_activation_max;
2569
2570 HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
2571 input_data, DimsToShape(filter_dims), filter_data,
2572 DimsToShape(bias_dims), bias_data, DimsToShape(scratch_dims),
2573 scratch_data, DimsToShape(output_dims), output_data,
2574 DimsToShape(im2col_dims), im2col_data, context);
2575 }
2576
2577 template <FusedActivationFunctionType Ac>
2578 void Conv(const float* input_data, const Dims<4>& input_dims,
2579 const float* filter_data, const Dims<4>& filter_dims,
2580 const float* bias_data, const Dims<4>& bias_dims, int stride_width,
2581 int stride_height, int dilation_width_factor,
2582 int dilation_height_factor, int pad_width, int pad_height,
2583 float* output_data, const Dims<4>& output_dims, float* im2col_data,
2584 const Dims<4>& im2col_dims) {
2585 float output_activation_min, output_activation_max;
2586 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
2587 Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
2588 stride_width, stride_height, dilation_width_factor,
2589 dilation_height_factor, pad_width, pad_height, output_activation_min,
2590 output_activation_max, output_data, output_dims, im2col_data,
2591 im2col_dims);
2592 }
2593
2594 // legacy, for compatibility with old checked-in code
2595 template <FusedActivationFunctionType Ac>
2596 void Conv(const float* input_data, const Dims<4>& input_dims,
2597 const float* filter_data, const Dims<4>& filter_dims,
2598 const float* bias_data, const Dims<4>& bias_dims, int stride_width,
2599 int stride_height, int pad_width, int pad_height, float* output_data,
2600 const Dims<4>& output_dims, float* im2col_data,
2601 const Dims<4>& im2col_dims) {
2602 float output_activation_min, output_activation_max;
2603 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
2604 Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
2605 stride_width, stride_height, 1, 1, pad_width, pad_height,
2606 output_activation_min, output_activation_max, output_data, output_dims,
2607 im2col_data, im2col_dims);
2608 }
2609
2610 // legacy, for compatibility with old checked-in code
2611 template <FusedActivationFunctionType Ac>
2612 void Conv(const float* input_data, const Dims<4>& input_dims,
2613 const float* filter_data, const Dims<4>& filter_dims,
2614 const float* bias_data, const Dims<4>& bias_dims, int stride,
2615 int pad_width, int pad_height, float* output_data,
2616 const Dims<4>& output_dims, float* im2col_data,
2617 const Dims<4>& im2col_dims) {
2618 Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
2619 bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
2620 output_dims, im2col_data, im2col_dims);
2621 }
2622
2623 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
2624 const uint8* input_data, const RuntimeShape& filter_shape,
2625 const uint8* filter_data, const RuntimeShape& bias_shape,
2626 const int32* bias_data, const RuntimeShape& output_shape,
2627 uint8* output_data, const RuntimeShape& im2col_shape,
2628 uint8* im2col_data, gemmlowp::GemmContext* gemmlowp_context) {
2629 ruy::profiler::ScopeLabel label("Conv/8bit");
2630 const int stride_width = params.stride_width;
2631 const int stride_height = params.stride_height;
2632 const int dilation_width_factor = params.dilation_width_factor;
2633 const int dilation_height_factor = params.dilation_height_factor;
2634 const int32 input_offset = params.input_offset;
2635 const int32 filter_offset = params.weights_offset;
2636 const int32 output_offset = params.output_offset;
2637 const int32 output_multiplier = params.output_multiplier;
2638 const int output_shift = params.output_shift;
2639 const int32 output_activation_min = params.quantized_activation_min;
2640 const int32 output_activation_max = params.quantized_activation_max;
2641 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2642 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2643 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2644
2645 const uint8* gemm_input_data = nullptr;
2646 const RuntimeShape* gemm_input_shape = nullptr;
2647 const int filter_width = filter_shape.Dims(2);
2648 const int filter_height = filter_shape.Dims(1);
2649 const bool need_dilated_im2col =
2650 dilation_width_factor != 1 || dilation_height_factor != 1;
2651 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2652 filter_width != 1 || filter_height != 1;
2653 if (need_dilated_im2col) {
2654 TFLITE_DCHECK(im2col_data);
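    // By TFLite convention input_offset is the negated input zero point, so
    // -input_offset recovers the zero point; it must fit in uint8 since it is
    // used as the padding byte written by im2col.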
2655 const int input_zero_point = -input_offset;
2656 TFLITE_DCHECK_GE(input_zero_point, 0);
2657 TFLITE_DCHECK_LE(input_zero_point, 255);
2658 DilatedIm2col(params, input_zero_point, input_shape, input_data,
2659 filter_shape, output_shape, im2col_data);
2660 gemm_input_data = im2col_data;
2661 gemm_input_shape = &im2col_shape;
2662 } else if (need_im2col) {
2663 TFLITE_DCHECK(im2col_data);
2664 const int input_zero_point = -input_offset;
2665 TFLITE_DCHECK_GE(input_zero_point, 0);
2666 TFLITE_DCHECK_LE(input_zero_point, 255);
2667 Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
2668 input_data, im2col_shape, im2col_data);
2669 gemm_input_data = im2col_data;
2670 gemm_input_shape = &im2col_shape;
2671 } else {
2672 TFLITE_DCHECK(!im2col_data);
2673 gemm_input_data = input_data;
2674 gemm_input_shape = &input_shape;
2675 }
2676
2677 const int gemm_input_rows = gemm_input_shape->Dims(3);
2678   // Using FlatSizeSkipDim causes a segfault in some contexts (see b/79927784),
2679   // though the root cause has not yet been identified. The same applies below
2680   // to the other commented-out calls. This is a partial rollback of cl/196819423.
2681 // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
2682 const int gemm_input_cols = gemm_input_shape->Dims(0) *
2683 gemm_input_shape->Dims(1) *
2684 gemm_input_shape->Dims(2);
2685 const int filter_rows = filter_shape.Dims(0);
2686 // See b/79927784.
2687 // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
2688 const int filter_cols =
2689 filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
2690 const int output_rows = output_shape.Dims(3);
2691 // See b/79927784.
2692 // const int output_cols = FlatSizeSkipDim(output_shape, 3);
2693 const int output_cols =
2694 output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
2695 TFLITE_DCHECK_EQ(output_rows, filter_rows);
2696 TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
2697 TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
2698 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
2699
2700 #ifdef USE_NEON
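  // A single gemm_input column means the product degenerates to a
  // matrix*vector multiplication, for which the hand-written GEMV path below
  // is faster than the general gemmlowp GEMM.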
2701 if (gemm_input_cols == 1 && output_rows >= 4) {
2702 RuntimeShape fc_filter_shape{
2703 filter_shape.Dims(0),
2704 filter_shape.Dims(filter_shape.DimensionsCount() - 1)};
2705
2706 return FullyConnectedAsGEMV(
2707 *gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape,
2708 filter_data, filter_offset, bias_shape, bias_data, output_offset,
2709 output_multiplier, output_shift, output_activation_min,
2710 output_activation_max, output_shape, output_data, gemmlowp_context);
2711 }
2712 #endif
2713
2714 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2715 filter_data, filter_rows, filter_cols);
2716 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
2717 gemm_input_data, gemm_input_rows, gemm_input_cols);
2718 gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
2719 output_data, output_rows, output_cols);
2720 const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
2721 bias_data, output_rows, output_offset, output_multiplier, output_shift,
2722 output_activation_min, output_activation_max);
2723 gemmlowp::GemmWithOutputPipeline<uint8, uint8,
2724 gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
2725 gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
2726 filter_offset, input_offset, output_pipeline);
2727 }
2728
2729 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
2730 int32 input_offset, const uint8* filter_data,
2731 const Dims<4>& filter_dims, int32 filter_offset,
2732 const int32* bias_data, const Dims<4>& bias_dims,
2733 int stride_width, int stride_height, int dilation_width_factor,
2734 int dilation_height_factor, int pad_width, int pad_height,
2735 int32 output_offset, int32 output_multiplier, int output_shift,
2736 int32 output_activation_min, int32 output_activation_max,
2737 uint8* output_data, const Dims<4>& output_dims,
2738 uint8* im2col_data, const Dims<4>& im2col_dims,
2739 gemmlowp::GemmContext* gemmlowp_context) {
2740 tflite::ConvParams op_params;
2741 // Padding type is ignored, but still set.
2742 op_params.padding_type = PaddingType::kSame;
2743 op_params.padding_values.width = pad_width;
2744 op_params.padding_values.height = pad_height;
2745 op_params.stride_width = stride_width;
2746 op_params.stride_height = stride_height;
2747 op_params.dilation_width_factor = dilation_width_factor;
2748 op_params.dilation_height_factor = dilation_height_factor;
2749 op_params.input_offset = input_offset;
2750 op_params.weights_offset = filter_offset;
2751 op_params.output_offset = output_offset;
2752 op_params.output_multiplier = output_multiplier;
2753 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
2754 op_params.output_shift = kReverseShift * output_shift;
2755 op_params.quantized_activation_min = output_activation_min;
2756 op_params.quantized_activation_max = output_activation_max;
2757
2758 Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
2759 filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
2760 output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context);
2761 }
2762
2763 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
2764 int32 input_offset, const uint8* filter_data,
2765 const Dims<4>& filter_dims, int32 filter_offset,
2766 const int32* bias_data, const Dims<4>& bias_dims,
2767 int stride_width, int stride_height, int pad_width,
2768 int pad_height, int32 output_offset, int32 output_multiplier,
2769 int output_shift, int32 output_activation_min,
2770 int32 output_activation_max, uint8* output_data,
2771 const Dims<4>& output_dims, uint8* im2col_data,
2772 const Dims<4>& im2col_dims,
2773 gemmlowp::GemmContext* gemmlowp_context) {
2774 Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
2775 filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
2776 pad_width, pad_height, output_offset, output_multiplier, output_shift,
2777 output_activation_min, output_activation_max, output_data, output_dims,
2778 im2col_data, im2col_dims, gemmlowp_context);
2779 }
2780
2781 // legacy, for compatibility with old checked-in code
2782 template <FusedActivationFunctionType Ac>
2783 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
2784 int32 input_offset, const uint8* filter_data,
2785 const Dims<4>& filter_dims, int32 filter_offset,
2786 const int32* bias_data, const Dims<4>& bias_dims,
2787 int stride_width, int stride_height, int pad_width,
2788 int pad_height, int32 output_offset, int32 output_multiplier,
2789 int output_shift, int32 output_activation_min,
2790 int32 output_activation_max, uint8* output_data,
2791 const Dims<4>& output_dims, uint8* im2col_data,
2792 const Dims<4>& im2col_dims,
2793 gemmlowp::GemmContext* gemmlowp_context) {
2794 static_assert(Ac == FusedActivationFunctionType::kNone ||
2795 Ac == FusedActivationFunctionType::kRelu ||
2796 Ac == FusedActivationFunctionType::kRelu6 ||
2797 Ac == FusedActivationFunctionType::kRelu1,
2798 "");
2799 if (Ac == FusedActivationFunctionType::kNone) {
2800 TFLITE_DCHECK_EQ(output_activation_min, 0);
2801 TFLITE_DCHECK_EQ(output_activation_max, 255);
2802 }
2803 Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
2804 filter_offset, bias_data, bias_dims, stride_width, stride_height,
2805 pad_width, pad_height, output_offset, output_multiplier, output_shift,
2806 output_activation_min, output_activation_max, output_data, output_dims,
2807 im2col_data, im2col_dims, gemmlowp_context);
2808 }
2809
2810 // legacy, for compatibility with old checked-in code
2811 template <FusedActivationFunctionType Ac>
2812 void Conv(const uint8* input_data, const Dims<4>& input_dims,
2813 int32 input_offset, const uint8* filter_data,
2814 const Dims<4>& filter_dims, int32 filter_offset,
2815 const int32* bias_data, const Dims<4>& bias_dims, int stride,
2816 int pad_width, int pad_height, int32 output_offset,
2817 int32 output_multiplier, int output_shift,
2818 int32 output_activation_min, int32 output_activation_max,
2819 uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
2820 const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) {
2821 static_assert(Ac == FusedActivationFunctionType::kNone ||
2822 Ac == FusedActivationFunctionType::kRelu ||
2823 Ac == FusedActivationFunctionType::kRelu6 ||
2824 Ac == FusedActivationFunctionType::kRelu1,
2825 "");
2826 Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
2827 filter_offset, bias_data, bias_dims, stride, stride, pad_width,
2828 pad_height, output_offset, output_multiplier, output_shift,
2829 output_activation_min, output_activation_max, output_data, output_dims,
2830 im2col_data, im2col_dims, gemmlowp_context);
2831 }
2832
2833 // legacy, for compatibility with old checked-in code
2834 template <FusedActivationFunctionType Ac, typename T>
2835 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
2836 int pad_width, int pad_height, int kheight, int kwidth,
2837 uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
2838 Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
2839 kwidth, zero_byte, output_data, output_dims);
2840 }
2841
2842 // legacy, for compatibility with old checked-in code
2843 template <FusedActivationFunctionType Ac>
2844 void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
2845 const float* filter_data, const Dims<4>& filter_dims,
2846 const float* bias_data, const Dims<4>& bias_dims,
2847 float* output_data, const Dims<4>& output_dims) {
2848 ruy::profiler::ScopeLabel label("ConvAsGemm");
2849
2850 const auto input_matrix_map =
2851 MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
2852 const auto filter_matrix_map =
2853 MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
2854 auto output_matrix_map =
2855 MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
2856
2857 Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
2858
2859 AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
2860 output_dims);
2861 }
2862
2863 // legacy, for compatibility with old checked-in code
2864 template <FusedActivationFunctionType Ac>
2865 void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
2866 int32 input_offset, const uint8* filter_data,
2867 const Dims<4>& filter_dims, int32 filter_offset,
2868 const int32* bias_data, const Dims<4>& bias_dims,
2869 int32 output_offset, int32 output_multiplier, int output_shift,
2870 int32 output_activation_min, int32 output_activation_max,
2871 uint8* output_data, const Dims<4>& output_dims,
2872 gemmlowp::GemmContext* gemmlowp_context) {
2873 ruy::profiler::ScopeLabel label("ConvAsGemm/8bit");
2874 static_assert(Ac == FusedActivationFunctionType::kNone ||
2875 Ac == FusedActivationFunctionType::kRelu ||
2876 Ac == FusedActivationFunctionType::kRelu6 ||
2877 Ac == FusedActivationFunctionType::kRelu1,
2878 "");
2879 const int input_rows = input_dims.sizes[0];
2880 const int input_cols = FlatSizeSkipDim(input_dims, 0);
2881 const int filter_rows = filter_dims.sizes[3];
2882 const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
2883 const int output_rows = output_dims.sizes[0];
2884 const int output_cols = FlatSizeSkipDim(output_dims, 0);
2885 TFLITE_DCHECK_EQ(output_rows, filter_rows);
2886 TFLITE_DCHECK_EQ(output_cols, input_cols);
2887 TFLITE_DCHECK_EQ(filter_cols, input_rows);
2888 TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
2889 TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
2890 TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
2891 TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
2892 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2893 filter_data, output_rows, filter_cols, filter_cols);
2894 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
2895 input_data, filter_cols, output_cols, filter_cols);
2896 gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
2897 output_data, output_rows, output_cols, output_rows);
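  // Note the negated output_shift below: this legacy entry point still takes a
  // right-shift amount, while MakeExp expects the positive-means-left-shift
  // convention used elsewhere in this file.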
2898 const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
2899 bias_data, output_rows, output_offset, output_multiplier, -output_shift,
2900 output_activation_min, output_activation_max);
2901 gemmlowp::GemmWithOutputPipeline<uint8, uint8,
2902 gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
2903 gemmlowp_context, filter_matrix, input_matrix, &output_matrix,
2904 filter_offset, input_offset, output_pipeline);
2905 }
2906
2907 inline void TransposeConv(
2908 const ConvParams& params, const RuntimeShape& input_shape,
2909 const float* input_data, const RuntimeShape& filter_shape,
2910 const float* filter_data, const RuntimeShape& output_shape,
2911 float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
2912 ruy::profiler::ScopeLabel label("TransposeConv");
2913   // Note: we could use transposed weights with a forward conv for unstrided
2914   // cases, but we are already getting good performance with this code as-is.
2915 TFLITE_DCHECK(im2col_data);
2916 TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
2917 output_shape, im2col_data);
2918
2919 const auto im2col_matrix_map =
2920 MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
2921 const auto filter_matrix_map =
2922 MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
2923 auto output_matrix_map =
2924 MapAsMatrixWithLastDimAsRows(output_data, output_shape);
2925
2926 Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
2927 }
2928
2929 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
2930 const float* filter_data, const Dims<4>& filter_dims,
2931 int stride_width, int stride_height, int pad_width,
2932 int pad_height, float* output_data,
2933 const Dims<4>& output_dims, float* im2col_data,
2934 const Dims<4>& im2col_dims) {
2935 tflite::ConvParams op_params;
2936 // Padding type is ignored, but still set.
2937 op_params.padding_type = PaddingType::kSame;
2938 op_params.padding_values.width = pad_width;
2939 op_params.padding_values.height = pad_height;
2940 op_params.stride_width = stride_width;
2941 op_params.stride_height = stride_height;
2942
2943 TransposeConv(op_params, DimsToShape(input_dims), input_data,
2944 DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
2945 output_data, DimsToShape(im2col_dims), im2col_data);
2946 }
2947
2948 inline void TransposeConvV2(
2949 const ConvParams& params, const RuntimeShape& input_shape,
2950 const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
2951 const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
2952 float* output_data, const RuntimeShape& col2im_shape, float* col2im_data,
2953 CpuBackendContext* cpu_backend_context) {
2954 TransposeConvV2(params, input_shape, input_data, hwoi_ordered_filter_shape,
2955 hwoi_ordered_filter_data, /*bias_shape*/ RuntimeShape(),
2956 /*bias_data*/ nullptr, output_shape, output_data,
2957 col2im_shape, col2im_data, cpu_backend_context);
2958 }
2959
2960 template <typename T>
2961 void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
2962 const Dims<4>& filter_dims, int stride_width,
2963 int stride_height, int pad_width, int pad_height,
2964 const Dims<4>& output_dims, uint8 zero_byte,
2965 T* im2col_data) {
2966 tflite::ConvParams op_params;
2967 // Padding type is ignored, but still set.
2968 op_params.padding_type = PaddingType::kSame;
2969 op_params.padding_values.width = pad_width;
2970 op_params.padding_values.height = pad_height;
2971 op_params.stride_width = stride_width;
2972 op_params.stride_height = stride_height;
2973
2974 TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
2975 DimsToShape(filter_dims), DimsToShape(output_dims),
2976 im2col_data);
2977 }
2978
2979 inline void LstmCell(
2980 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2981 const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
2982 const float* prev_activ_data, const RuntimeShape& weights_shape,
2983 const float* weights_data, const RuntimeShape& unextended_bias_shape,
2984 const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
2985 const float* prev_state_data,
2986 const RuntimeShape& unextended_output_state_shape, float* output_state_data,
2987 const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
2988 const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
2989 const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
2990 ruy::profiler::ScopeLabel label("LstmCell");
2991 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2992 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2993 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2994 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2995 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2996 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2997 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2998 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
2999 const RuntimeShape input_shape =
3000 RuntimeShape::ExtendedShape(4, unextended_input_shape);
3001 const RuntimeShape prev_activ_shape =
3002 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
3003 const RuntimeShape bias_shape =
3004 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
3005 const RuntimeShape prev_state_shape =
3006 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
3007 const RuntimeShape output_state_shape =
3008 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
3009 const RuntimeShape output_activ_shape =
3010 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
3011 const RuntimeShape concat_temp_shape =
3012 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
3013 const RuntimeShape activ_temp_shape =
3014 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
3015 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
3016
3017 const int weights_dim_count = weights_shape.DimensionsCount();
3018 MatchingDim( // batches
3019 input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
3020 output_state_shape, 0, output_activ_shape, 0);
3021 MatchingDim( // height
3022 input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
3023 output_state_shape, 1, output_activ_shape, 1);
3024 MatchingDim( // width
3025 input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
3026 output_state_shape, 2, output_activ_shape, 2);
3027 const int input_depth = input_shape.Dims(3);
3028 const int prev_activ_depth = prev_activ_shape.Dims(3);
3029 const int total_input_depth = prev_activ_depth + input_depth;
3030 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
3031 total_input_depth);
3032 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
3033 const int intern_activ_depth =
3034 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
3035 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
3036 intern_activ_depth * total_input_depth);
3037 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
3038 const int output_depth =
3039 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
3040 3, output_activ_shape, 3);
3041 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
3042
3043 // Concatenate prev_activ and input data together
3044 std::vector<float const*> concat_input_arrays_data;
3045 std::vector<RuntimeShape const*> concat_input_arrays_shapes;
3046 concat_input_arrays_data.push_back(input_data);
3047 concat_input_arrays_data.push_back(prev_activ_data);
3048 concat_input_arrays_shapes.push_back(&input_shape);
3049 concat_input_arrays_shapes.push_back(&prev_activ_shape);
3050 tflite::ConcatenationParams concat_params;
3051 concat_params.axis = 3;
3052 concat_params.inputs_count = concat_input_arrays_data.size();
3053 Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
3054 &(concat_input_arrays_data[0]), concat_temp_shape,
3055 concat_temp_data);
3056
3057 // Fully connected
3058 tflite::FullyConnectedParams fc_params;
3059 fc_params.float_activation_min = std::numeric_limits<float>::lowest();
3060 fc_params.float_activation_max = std::numeric_limits<float>::max();
3061 FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
3062 weights_data, bias_shape, bias_data, activ_temp_shape,
3063 activ_temp_data);
3064
3065 // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
3066 // operations.
3067 ArrayMap<float> activ_temp_map =
3068 MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
3069 auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
3070 activ_temp_map.cols());
3071 auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
3072 activ_temp_map.cols());
3073 auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
3074 activ_temp_map.cols());
3075 auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
3076 activ_temp_map.cols());
3077 ArrayMap<const float> prev_state_map =
3078 MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
3079 ArrayMap<float> output_state_map =
3080 MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
3081 ArrayMap<float> output_activ_map =
3082 MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
3083
3084 // Combined memory state and final output calculation
3085 ruy::profiler::ScopeLabel label2("MemoryStateAndFinalOutput");
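  // In standard LSTM notation, the code below computes
  //   c_t = sigmoid(i) * tanh(g) + sigmoid(f) * c_{t-1}
  //   h_t = sigmoid(o) * tanh(c_t)
  // where i, g, f, o are the four gate slices extracted from activ_temp above.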
3086 output_state_map =
3087 input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3088 new_input_sm.tanh() +
3089 forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3090 prev_state_map;
3091 output_activ_map =
3092 output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3093 output_state_map.tanh();
3094 }
3095
3096 inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
3097 const float* prev_activ_data,
3098 const Dims<4>& prev_activ_dims, const float* weights_data,
3099 const Dims<4>& weights_dims, const float* bias_data,
3100 const Dims<4>& bias_dims, const float* prev_state_data,
3101 const Dims<4>& prev_state_dims, float* output_state_data,
3102 const Dims<4>& output_state_dims, float* output_activ_data,
3103 const Dims<4>& output_activ_dims, float* concat_temp_data,
3104 const Dims<4>& concat_temp_dims, float* activ_temp_data,
3105 const Dims<4>& activ_temp_dims) {
3106 tflite::LstmCellParams op_params;
3107 // Float LSTM cell does not need parameters to be set: leave untouched.
3108
3109 LstmCell(op_params, DimsToShape(input_dims), input_data,
3110 DimsToShape(prev_activ_dims), prev_activ_data,
3111 DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
3112 bias_data, DimsToShape(prev_state_dims), prev_state_data,
3113 DimsToShape(output_state_dims), output_state_data,
3114 DimsToShape(output_activ_dims), output_activ_data,
3115 DimsToShape(concat_temp_dims), concat_temp_data,
3116 DimsToShape(activ_temp_dims), activ_temp_data);
3117 }
3118
3119 template <int StateIntegerBits>
3120 inline void LstmCell(
3121 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
3122 const uint8* input_data_uint8,
3123 const RuntimeShape& unextended_prev_activ_shape,
3124 const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
3125 const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
3126 const int32* bias_data_int32,
3127 const RuntimeShape& unextended_prev_state_shape,
3128 const int16* prev_state_data_int16,
3129 const RuntimeShape& unextended_output_state_shape,
3130 int16* output_state_data_int16,
3131 const RuntimeShape& unextended_output_activ_shape,
3132 uint8* output_activ_data_uint8,
3133 const RuntimeShape& unextended_concat_temp_shape,
3134 uint8* concat_temp_data_uint8,
3135 const RuntimeShape& unextended_activ_temp_shape,
3136 int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) {
3137 ruy::profiler::ScopeLabel label(
3138 "LstmCell/quantized (8bit external, 16bit internal)");
3139 int32 weights_zero_point = params.weights_zero_point;
3140 int32 accum_multiplier = params.accum_multiplier;
3141 int accum_shift = params.accum_shift;
3142 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
3143 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
3144 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
3145 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
3146 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
3147 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
3148 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
3149 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
3150 const RuntimeShape input_shape =
3151 RuntimeShape::ExtendedShape(4, unextended_input_shape);
3152 const RuntimeShape prev_activ_shape =
3153 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
3154 const RuntimeShape bias_shape =
3155 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
3156 const RuntimeShape prev_state_shape =
3157 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
3158 const RuntimeShape output_state_shape =
3159 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
3160 const RuntimeShape output_activ_shape =
3161 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
3162 const RuntimeShape concat_temp_shape =
3163 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
3164 const RuntimeShape activ_temp_shape =
3165 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
3166 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
3167
3168   // Gather dimension information and perform consistency checks.
3169 const int weights_dim_count = weights_shape.DimensionsCount();
3170 const int outer_size = MatchingFlatSizeSkipDim(
3171 input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
3172 output_activ_shape);
3173 const int input_depth = input_shape.Dims(3);
3174 const int prev_activ_depth = prev_activ_shape.Dims(3);
3175 const int total_input_depth = prev_activ_depth + input_depth;
3176 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
3177 total_input_depth);
3178 const int intern_activ_depth =
3179 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
3180 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
3181 intern_activ_depth * total_input_depth);
3182 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
3183 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
3184 const int output_depth =
3185 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
3186 3, output_activ_shape, 3);
3187 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
3188 const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
3189 const int fc_output_depth =
3190 MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
3191 const int fc_accum_depth = total_input_depth;
3192 TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
3193
3194 // Depth-concatenate prev_activ and input data together.
3195 uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
3196 prev_activ_data_uint8};
3197 const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
3198 &prev_activ_shape};
3199 tflite::ConcatenationParams concat_params;
3200 concat_params.axis = 3;
3201 concat_params.inputs_count = 2;
3202 Concatenation(concat_params, concat_input_arrays_shapes,
3203 concat_input_arrays_data, concat_temp_shape,
3204 concat_temp_data_uint8);
3205
3206 // Implementation of the fully connected node inside the LSTM cell.
3207 // The operands are 8-bit integers, the accumulators are internally 32bit
3208 // integers, and the output is 16-bit fixed-point with 3 integer bits so
3209 // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
3210 // is explained in the function comment above.
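  // (For reference: in this Q3.12 format a raw int16 value v encodes v / 2^12,
  // so raw 4096 represents 1.0 and the representable range is about [-8, +8).)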
3211 bool gemm_already_performed = false;
3212 #ifdef GEMMLOWP_NEON
3213 if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) {
3214 GEMVForLstmCell(concat_temp_shape, concat_temp_data_uint8, weights_shape,
3215 weights_data_uint8, weights_zero_point, bias_shape,
3216 bias_data_int32, accum_multiplier, accum_shift,
3217 activ_temp_shape, activ_temp_data_int16);
3218 gemm_already_performed = true;
3219 }
3220 #endif
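  // Generic fallback: when the NEON GEMV fast path above does not apply, run a
  // full gemmlowp GEMM whose output pipeline adds the bias, rescales by the
  // fixed-point accum_multiplier / accum_shift pair, and saturating-casts the
  // result to int16.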
3221 if (!gemm_already_performed) {
3222 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor>
3223 weights_matrix(weights_data_uint8, fc_output_depth, fc_accum_depth);
3224 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
3225 concat_temp_data_uint8, fc_accum_depth, fc_batches);
3226 gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
3227 activ_temp_data_int16, fc_output_depth, fc_batches);
3228 typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
3229 ColVectorMap;
3230 ColVectorMap bias_vector(bias_data_int32, fc_output_depth);
3231 gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
3232 bias_addition_stage.bias_vector = bias_vector;
3233 gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
3234 scale_stage.result_offset_after_shift = 0;
3235 scale_stage.result_fixedpoint_multiplier = accum_multiplier;
3236 scale_stage.result_exponent = accum_shift;
3237 gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
3238 auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
3239 saturating_cast_int16_stage);
3240 gemmlowp::GemmWithOutputPipeline<
3241 uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
3242 gemmlowp_context, weights_matrix, input_matrix, &output_matrix,
3243 -weights_zero_point, -128, output_pipeline);
3244 }
3245
3246 // Rest of the LSTM cell: tanh and logistic math functions, and some adds
3247 // and muls, all done in 16-bit fixed-point.
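  // activ_temp is laid out, per batch, as four consecutive blocks of
  // output_depth values: input gate, input modulation gate, forget gate and
  // output gate pre-activations, which is what the pointer offsets below assume.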
3248 const int16* input_gate_input_ptr = activ_temp_data_int16;
3249 const int16* input_modulation_gate_input_ptr =
3250 activ_temp_data_int16 + output_depth;
3251 const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
3252 const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
3253 const int16* prev_state_ptr = prev_state_data_int16;
3254 int16* output_state_data_ptr = output_state_data_int16;
3255 uint8* output_activ_data_ptr = output_activ_data_uint8;
3256
3257 for (int b = 0; b < outer_size; ++b) {
3258 int c = 0;
3259 #ifdef GEMMLOWP_NEON
3260 for (; c <= output_depth - 8; c += 8) {
3261 // Define the fixed-point data types that we will use here. All use
3262 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3263 // They only differ by the number of integral vs. fractional bits,
3264 // determining the range of values that they can represent.
3265 //
3266 // F0 uses 0 integer bits, range [-1, 1].
3267 // This is the return type of math functions such as tanh, logistic,
3268 // whose range is in [-1, 1].
3269 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
3270 // F3 uses 3 integer bits, range [-8, 8].
3271 // This is the range of the previous fully-connected node's output,
3272 // which is our input here.
3273 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
3274 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3275 // 2^StateIntegerBits]. It's used to represent the internal state, whose
3276 // number of integer bits is currently dictated by the model. See comment
3277 // on the StateIntegerBits template parameter above.
3278 using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
3279 // Implementation of input gate, using fixed-point logistic function.
3280 F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
3281 input_gate_input_ptr += 8;
3282 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3283 // Implementation of input modulation gate, using fixed-point tanh
3284 // function.
3285 F3 input_modulation_gate_input =
3286 F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
3287 input_modulation_gate_input_ptr += 8;
3288 F0 input_modulation_gate_output =
3289 gemmlowp::tanh(input_modulation_gate_input);
3290 // Implementation of forget gate, using fixed-point logistic function.
3291 F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
3292 forget_gate_input_ptr += 8;
3293 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3294 // Implementation of output gate, using fixed-point logistic function.
3295 F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
3296 output_gate_input_ptr += 8;
3297 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3298 // Implementation of internal multiplication nodes, still in fixed-point.
3299 F0 input_times_input_modulation =
3300 input_gate_output * input_modulation_gate_output;
3301 FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
3302 prev_state_ptr += 8;
3303 FS prev_state_times_forget_state = forget_gate_output * prev_state;
3304 // Implementation of internal addition node, saturating.
3305 FS new_state = gemmlowp::SaturatingAdd(
3306 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3307 prev_state_times_forget_state);
3308 // Implementation of last internal Tanh node, still in fixed-point.
3309 // Since a Tanh fixed-point implementation is specialized for a given
3310       // number of integer bits, and each specialization can have a substantial
3311 // code size, and we already used above a Tanh on an input with 3 integer
3312 // bits, and per the table in the above function comment there is no
3313 // significant accuracy to be lost by clamping to [-8, +8] for a
3314 // 3-integer-bits representation, let us just do that. This helps people
3315 // porting this to targets where code footprint must be minimized.
3316 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3317 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3318 // Store the new internal state back to memory, as 16-bit integers.
3319 // Note: here we store the original value with StateIntegerBits, not
3320 // the rescaled 3-integer-bits value fed to tanh.
3321 vst1q_s16(output_state_data_ptr, new_state.raw());
3322 output_state_data_ptr += 8;
3323 // Down-scale the output activations to 8-bit integers, saturating,
3324 // and store back to memory.
3325 int16x8_t rescaled_output_activ =
3326 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3327 int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
3328 uint8x8_t uint8_output_activ =
3329 vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
3330 vst1_u8(output_activ_data_ptr, uint8_output_activ);
3331 output_activ_data_ptr += 8;
3332 }
3333 #endif
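    // Scalar epilogue: handles any channels left over by the NEON loop (and the
    // whole depth when NEON is unavailable), mirroring the vectorized math above.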
3334 for (; c < output_depth; ++c) {
3335 // Define the fixed-point data types that we will use here. All use
3336 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3337 // They only differ by the number of integral vs. fractional bits,
3338 // determining the range of values that they can represent.
3339 //
3340 // F0 uses 0 integer bits, range [-1, 1].
3341 // This is the return type of math functions such as tanh, logistic,
3342 // whose range is in [-1, 1].
3343 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3344 // F3 uses 3 integer bits, range [-8, 8].
3345 // This is the range of the previous fully-connected node's output,
3346 // which is our input here.
3347 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
3348 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3349 // 2^StateIntegerBits]. It's used to represent the internal state, whose
3350 // number of integer bits is currently dictated by the model. See comment
3351 // on the StateIntegerBits template parameter above.
3352 using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
3353 // Implementation of input gate, using fixed-point logistic function.
3354 F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
3355 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3356 // Implementation of input modulation gate, using fixed-point tanh
3357 // function.
3358 F3 input_modulation_gate_input =
3359 F3::FromRaw(*input_modulation_gate_input_ptr++);
3360 F0 input_modulation_gate_output =
3361 gemmlowp::tanh(input_modulation_gate_input);
3362 // Implementation of forget gate, using fixed-point logistic function.
3363 F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
3364 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3365 // Implementation of output gate, using fixed-point logistic function.
3366 F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
3367 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3368 // Implementation of internal multiplication nodes, still in fixed-point.
3369 F0 input_times_input_modulation =
3370 input_gate_output * input_modulation_gate_output;
3371 FS prev_state = FS::FromRaw(*prev_state_ptr++);
3372 FS prev_state_times_forget_state = forget_gate_output * prev_state;
3373 // Implementation of internal addition node, saturating.
3374 FS new_state = gemmlowp::SaturatingAdd(
3375 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3376 prev_state_times_forget_state);
3377 // Implementation of last internal Tanh node, still in fixed-point.
3378 // Since a Tanh fixed-point implementation is specialized for a given
3379       // number of integer bits, and each specialization can have a substantial
3380 // code size, and we already used above a Tanh on an input with 3 integer
3381 // bits, and per the table in the above function comment there is no
3382 // significant accuracy to be lost by clamping to [-8, +8] for a
3383 // 3-integer-bits representation, let us just do that. This helps people
3384 // porting this to targets where code footprint must be minimized.
3385 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3386 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3387 // Store the new internal state back to memory, as 16-bit integers.
3388 // Note: here we store the original value with StateIntegerBits, not
3389 // the rescaled 3-integer-bits value fed to tanh.
3390 *output_state_data_ptr++ = new_state.raw();
3391 // Down-scale the output activations to 8-bit integers, saturating,
3392 // and store back to memory.
3393 int16 rescaled_output_activ =
3394 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3395 int16 clamped_output_activ =
3396 std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
3397 *output_activ_data_ptr++ = 128 + clamped_output_activ;
3398 }
3399 input_gate_input_ptr += 3 * output_depth;
3400 input_modulation_gate_input_ptr += 3 * output_depth;
3401 forget_gate_input_ptr += 3 * output_depth;
3402 output_gate_input_ptr += 3 * output_depth;
3403 }
3404 }
3405
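// Legacy Dims<4>-based entry point: packs the scalar quantization parameters
// into an LstmCellParams struct and forwards to the RuntimeShape-based
// LstmCell overload above.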
3406 template <int StateIntegerBits>
3407 void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
3408 const uint8* prev_activ_data_uint8,
3409 const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
3410 const Dims<4>& weights_dims, const int32* bias_data_int32,
3411 const Dims<4>& bias_dims, const int16* prev_state_data_int16,
3412 const Dims<4>& prev_state_dims, int16* output_state_data_int16,
3413 const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
3414 const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
3415 const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
3416 const Dims<4>& activ_temp_dims, int32 weights_zero_point,
3417 int32 accum_multiplier, int accum_shift,
3418 gemmlowp::GemmContext* gemmlowp_context) {
3419 tflite::LstmCellParams op_params;
3420 op_params.weights_zero_point = weights_zero_point;
3421 op_params.accum_multiplier = accum_multiplier;
3422 op_params.accum_shift = accum_shift;
3423
3424 LstmCell<StateIntegerBits>(
3425 op_params, DimsToShape(input_dims), input_data_uint8,
3426 DimsToShape(prev_activ_dims), prev_activ_data_uint8,
3427 DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
3428 bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
3429 DimsToShape(output_state_dims), output_state_data_int16,
3430 DimsToShape(output_activ_dims), output_activ_data_uint8,
3431 DimsToShape(concat_temp_dims), concat_temp_data_uint8,
3432 DimsToShape(activ_temp_dims), activ_temp_data_int16, gemmlowp_context);
3433 }
3434
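// Legacy BroadcastDiv: builds ArithmeticParams from the activation bounds and
// dispatches to BroadcastDivSlow with RuntimeShape arguments.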
3435 template <typename T>
3436 void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
3437 const T* input2_data, const Dims<4>& input2_dims,
3438 T output_activation_min, T output_activation_max,
3439 T* output_data, const Dims<4>& output_dims) {
3440 tflite::ArithmeticParams op_params;
3441 SetActivationParams(output_activation_min, output_activation_max, &op_params);
3442
3443 BroadcastDivSlow(op_params, DimsToShape(input1_dims), input1_data,
3444 DimsToShape(input2_dims), input2_data,
3445 DimsToShape(output_dims), output_data);
3446 }
3447
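// Legacy L2Normalization entry points. The float overload only supports
// FusedActivationFunctionType::kNone (enforced by the static_assert below);
// the uint8 overload threads the input zero point through
// L2NormalizationParams.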
3448 template <FusedActivationFunctionType Ac>
3449 void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
3450 float* output_data, const RuntimeShape& output_shape) {
3451 static_assert(Ac == FusedActivationFunctionType::kNone, "");
3452 tflite::L2NormalizationParams op_params;
3453 // No params need to be set for float, but reserved in signature for future
3454 // activations.
3455
3456 L2Normalization(op_params, input_shape, input_data, output_shape,
3457 output_data);
3458 }
3459
3460 inline void L2Normalization(const uint8* input_data,
3461 const RuntimeShape& input_shape,
3462 int32 input_zero_point, uint8* output_data,
3463 const RuntimeShape& output_shape) {
3464 tflite::L2NormalizationParams op_params;
3465 op_params.input_zero_point = input_zero_point;
3466
3467 L2Normalization(op_params, input_shape, input_data, output_shape,
3468 output_data);
3469 }
3470
3471 template <FusedActivationFunctionType Ac>
3472 void L2Normalization(const float* input_data, const Dims<4>& input_dims,
3473 float* output_data, const Dims<4>& output_dims) {
3474 L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
3475 DimsToShape(output_dims));
3476 }
3477
3478 inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
3479 int32 input_zero_point, uint8* output_data,
3480 const Dims<4>& output_dims) {
3481 L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
3482 output_data, DimsToShape(output_dims));
3483 }
3484
3485 inline void Relu(const float* input_data, const Dims<4>& input_dims,
3486 float* output_data, const Dims<4>& output_dims) {
3487 Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
3488 output_data);
3489 }
3490
3491 // legacy, for compatibility with old checked-in code
3492 template <FusedActivationFunctionType Ac>
3493 void Add(const float* input1_data, const Dims<4>& input1_dims,
3494 const float* input2_data, const Dims<4>& input2_dims,
3495 float* output_data, const Dims<4>& output_dims) {
3496 float output_activation_min, output_activation_max;
3497 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3498
3499 tflite::ArithmeticParams op_params;
3500 op_params.float_activation_min = output_activation_min;
3501 op_params.float_activation_max = output_activation_max;
3502 Add(op_params, DimsToShape(input1_dims), input1_data,
3503 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3504 output_data);
3505 }
3506
3507 template <FusedActivationFunctionType Ac>
3508 inline void Add(int left_shift, const uint8* input1_data,
3509 const Dims<4>& input1_dims, int32 input1_offset,
3510 int32 input1_multiplier, int input1_shift,
3511 const uint8* input2_data, const Dims<4>& input2_dims,
3512 int32 input2_offset, int32 input2_multiplier, int input2_shift,
3513 int32 output_offset, int32 output_multiplier, int output_shift,
3514 int32 output_activation_min, int32 output_activation_max,
3515 uint8* output_data, const Dims<4>& output_dims) {
3516 constexpr int kReverseShift = -1;
3517 static_assert(Ac == FusedActivationFunctionType::kNone ||
3518 Ac == FusedActivationFunctionType::kRelu ||
3519 Ac == FusedActivationFunctionType::kRelu6 ||
3520 Ac == FusedActivationFunctionType::kRelu1,
3521 "");
3522 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3523 if (Ac == FusedActivationFunctionType::kNone) {
3524 TFLITE_DCHECK_EQ(output_activation_min, 0);
3525 TFLITE_DCHECK_EQ(output_activation_max, 255);
3526 }
3527
3528 tflite::ArithmeticParams op_params;
3529 op_params.left_shift = left_shift;
3530 op_params.input1_offset = input1_offset;
3531 op_params.input1_multiplier = input1_multiplier;
3532 op_params.input1_shift = kReverseShift * input1_shift;
3533 op_params.input2_offset = input2_offset;
3534 op_params.input2_multiplier = input2_multiplier;
3535 op_params.input2_shift = kReverseShift * input2_shift;
3536 op_params.output_offset = output_offset;
3537 op_params.output_multiplier = output_multiplier;
3538 op_params.output_shift = kReverseShift * output_shift;
3539 op_params.quantized_activation_min = output_activation_min;
3540 op_params.quantized_activation_max = output_activation_max;
3541 Add(op_params, DimsToShape(input1_dims), input1_data,
3542 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3543 output_data);
3544 }
3545
3546 template <FusedActivationFunctionType Ac>
3547 void Add(const int32* input1_data, const Dims<4>& input1_dims,
3548 const int32* input2_data, const Dims<4>& input2_dims,
3549 int32* output_data, const Dims<4>& output_dims) {
3550 ruy::profiler::ScopeLabel label("Add/int32");
3551 TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
3552
3553 tflite::ArithmeticParams op_params;
3554 op_params.quantized_activation_min = std::numeric_limits<int32>::min();
3555 op_params.quantized_activation_max = std::numeric_limits<int32>::max();
3556 Add(op_params, DimsToShape(input1_dims), input1_data,
3557 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3558 output_data);
3559 }
3560
3561 template <typename T>
3562 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
3563 const T* input2_data, const Dims<4>& input2_dims,
3564 T output_activation_min, T output_activation_max,
3565 T* output_data, const Dims<4>& output_dims) {
3566 tflite::ArithmeticParams op_params;
3567 op_params.float_activation_min = output_activation_min;
3568 op_params.float_activation_max = output_activation_max;
3569 BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
3570 DimsToShape(input2_dims), input2_data,
3571 DimsToShape(output_dims), output_data);
3572 }
3573
3574 template <FusedActivationFunctionType Ac>
3575 inline void BroadcastAdd(int left_shift, const uint8* input1_data,
3576 const Dims<4>& input1_dims, int32 input1_offset,
3577 int32 input1_multiplier, int input1_shift,
3578 const uint8* input2_data, const Dims<4>& input2_dims,
3579 int32 input2_offset, int32 input2_multiplier,
3580 int input2_shift, int32 output_offset,
3581 int32 output_multiplier, int output_shift,
3582 int32 output_activation_min,
3583 int32 output_activation_max, uint8* output_data,
3584 const Dims<4>& output_dims) {
3585 constexpr int kReverseShift = -1;
3586 static_assert(Ac == FusedActivationFunctionType::kNone ||
3587 Ac == FusedActivationFunctionType::kRelu ||
3588 Ac == FusedActivationFunctionType::kRelu6 ||
3589 Ac == FusedActivationFunctionType::kRelu1,
3590 "");
3591 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3592 if (Ac == FusedActivationFunctionType::kNone) {
3593 TFLITE_DCHECK_EQ(output_activation_min, 0);
3594 TFLITE_DCHECK_EQ(output_activation_max, 255);
3595 }
3596
3597 tflite::ArithmeticParams op_params;
3598 op_params.left_shift = left_shift;
3599 op_params.input1_offset = input1_offset;
3600 op_params.input1_multiplier = input1_multiplier;
3601 op_params.input1_shift = kReverseShift * input1_shift;
3602 op_params.input2_offset = input2_offset;
3603 op_params.input2_multiplier = input2_multiplier;
3604 op_params.input2_shift = kReverseShift * input2_shift;
3605 op_params.output_offset = output_offset;
3606 op_params.output_multiplier = output_multiplier;
3607 op_params.output_shift = kReverseShift * output_shift;
3608 op_params.quantized_activation_min = output_activation_min;
3609 op_params.quantized_activation_max = output_activation_max;
3610 BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
3611 DimsToShape(input2_dims), input2_data,
3612 DimsToShape(output_dims), output_data);
3613 }
3614
3615 template <FusedActivationFunctionType Ac>
3616 inline void BroadcastAddFivefold(
3617 int y0, int y1, int y2, int y3, int y4, int left_shift,
3618 const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
3619 int32 input1_multiplier, int input1_shift, const uint8* input2_data,
3620 const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
3621 int input2_shift, int32 output_offset, int32 output_multiplier,
3622 int output_shift, int32 output_activation_min, int32 output_activation_max,
3623 uint8* output_data, const Dims<4>& output_dims) {
3624 constexpr int kReverseShift = -1;
3625 static_assert(Ac == FusedActivationFunctionType::kNone ||
3626 Ac == FusedActivationFunctionType::kRelu ||
3627 Ac == FusedActivationFunctionType::kRelu6 ||
3628 Ac == FusedActivationFunctionType::kRelu1,
3629 "");
3630 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3631 if (Ac == FusedActivationFunctionType::kNone) {
3632 TFLITE_DCHECK_EQ(output_activation_min, 0);
3633 TFLITE_DCHECK_EQ(output_activation_max, 255);
3634 }
3635 tflite::ArithmeticParams op_params;
3636 op_params.broadcast_category =
3637 tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
3638 op_params.left_shift = left_shift;
3639 op_params.input1_offset = input1_offset;
3640 op_params.input1_multiplier = input1_multiplier;
3641 op_params.input1_shift = kReverseShift * input1_shift;
3642 op_params.input2_offset = input2_offset;
3643 op_params.input2_multiplier = input2_multiplier;
3644 op_params.input2_shift = kReverseShift * input2_shift;
3645 op_params.output_offset = output_offset;
3646 op_params.output_multiplier = output_multiplier;
3647 op_params.output_shift = kReverseShift * output_shift;
3648 op_params.quantized_activation_min = output_activation_min;
3649 op_params.quantized_activation_max = output_activation_max;
3650 op_params.broadcast_shape[4] = y0;
3651 op_params.broadcast_shape[3] = y1;
3652 op_params.broadcast_shape[2] = y2;
3653 op_params.broadcast_shape[1] = y3;
3654 op_params.broadcast_shape[0] = y4;
3655 BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
3656 DimsToShape(input2_dims), input2_data,
3657 DimsToShape(output_dims), output_data);
3658 }
3659
3660 // legacy, for compatibility with old checked-in code
3661 template <FusedActivationFunctionType Ac, typename T>
3662 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
3663 const T* input2_data, const Dims<4>& input2_dims,
3664 T* output_data, const Dims<4>& output_dims) {
3665 T output_activation_min, output_activation_max;
3666 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3667
3668 BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
3669 output_activation_min, output_activation_max, output_data,
3670 output_dims);
3671 }
3672
3673 template <FusedActivationFunctionType Ac>
3674 inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
3675 int input1_shift, const int16* input2_data,
3676 const Dims<4>& input2_dims, int input2_shift,
3677 int16 output_activation_min, int16 output_activation_max,
3678 int16* output_data, const Dims<4>& output_dims) {
3679 constexpr int kReverseShift = -1;
3680 static_assert(Ac == FusedActivationFunctionType::kNone ||
3681 Ac == FusedActivationFunctionType::kRelu ||
3682 Ac == FusedActivationFunctionType::kRelu6 ||
3683 Ac == FusedActivationFunctionType::kRelu1,
3684 "");
3685 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3686 if (Ac == FusedActivationFunctionType::kNone) {
3687 TFLITE_DCHECK_EQ(output_activation_min, -32768);
3688 TFLITE_DCHECK_EQ(output_activation_max, 32767);
3689 }
3690
3691 tflite::ArithmeticParams op_params;
3692 op_params.input1_shift = kReverseShift * input1_shift;
3693 op_params.input2_shift = kReverseShift * input2_shift;
3694 op_params.quantized_activation_min = output_activation_min;
3695 op_params.quantized_activation_max = output_activation_max;
3696 Add(op_params, DimsToShape(input1_dims), input1_data,
3697 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3698 output_data);
3699 }
3700
3701 inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
3702 const float* input2_data, const Dims<4>& input2_dims,
3703 float* output_data, const Dims<4>& output_dims) {
3704 float output_activation_min, output_activation_max;
3705 GetActivationMinMax(FusedActivationFunctionType::kNone,
3706 &output_activation_min, &output_activation_max);
3707 tflite::ArithmeticParams op_params;
3708 op_params.float_activation_min = output_activation_min;
3709 op_params.float_activation_max = output_activation_max;
3710 Sub(op_params, DimsToShape(input1_dims), input1_data,
3711 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3712 output_data);
3713 }
3714
3715 template <typename T>
3716 void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
3717 const Dims<4>& input2_dims, T* output_data,
3718 const Dims<4>& output_dims) {
3719 T output_activation_min, output_activation_max;
3720 GetActivationMinMax(FusedActivationFunctionType::kNone,
3721 &output_activation_min, &output_activation_max);
3722 tflite::ArithmeticParams op_params;
3723 op_params.quantized_activation_min = output_activation_min;
3724 op_params.quantized_activation_max = output_activation_max;
3725 Sub(op_params, DimsToShape(input1_dims), input1_data,
3726 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
3727 output_data);
3728 }
3729
3730 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
3731 int32 input1_offset, const uint8* input2_data,
3732 const Dims<4>& input2_dims, int32 input2_offset,
3733 int32 output_offset, int32 output_multiplier,
3734 int output_shift, int32 output_activation_min,
3735 int32 output_activation_max, uint8* output_data,
3736 const Dims<4>& output_dims) {
3737 tflite::ArithmeticParams op_params;
3738 SetActivationParams(output_activation_min, output_activation_max, &op_params);
3739 op_params.input1_offset = input1_offset;
3740 op_params.input2_offset = input2_offset;
3741 op_params.output_offset = output_offset;
3742 op_params.output_multiplier = output_multiplier;
3743 op_params.output_shift = kReverseShift * output_shift;
3744
3745 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
3746 DimsToShape(input2_dims), input2_data,
3747 DimsToShape(output_dims), output_data);
3748 }
3749
3750 // legacy, for compatibility with old checked-in code
3751 template <FusedActivationFunctionType Ac>
3752 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
3753 int32 input1_offset, const uint8* input2_data,
3754 const Dims<4>& input2_dims, int32 input2_offset,
3755 int32 output_offset, int32 output_multiplier,
3756 int output_shift, int32 output_activation_min,
3757 int32 output_activation_max, uint8* output_data,
3758 const Dims<4>& output_dims) {
3759 BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
3760 input2_dims, input2_offset, output_offset, output_multiplier,
3761 output_shift, output_activation_min, output_activation_max,
3762 output_data, output_dims);
3763 }
3764
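// Legacy pooling entry points (AveragePool, MaxPool, L2Pool): each unpacks
// the stride/padding/filter arguments and activation bounds into a PoolParams
// struct and forwards to the corresponding params/shape-based implementation.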
3765 inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
3766 int stride_width, int stride_height, int pad_width,
3767 int pad_height, int kwidth, int kheight,
3768 float output_activation_min,
3769 float output_activation_max, float* output_data,
3770 const Dims<4>& output_dims) {
3771 tflite::PoolParams params;
3772 params.stride_height = stride_height;
3773 params.stride_width = stride_width;
3774 params.filter_height = kheight;
3775 params.filter_width = kwidth;
3776 params.padding_values.height = pad_height;
3777 params.padding_values.width = pad_width;
3778 params.float_activation_min = output_activation_min;
3779 params.float_activation_max = output_activation_max;
3780 AveragePool(params, DimsToShape(input_dims), input_data,
3781 DimsToShape(output_dims), output_data);
3782 }
3783
3784 // legacy, for compatibility with old checked-in code
3785 template <FusedActivationFunctionType Ac>
3786 void AveragePool(const float* input_data, const Dims<4>& input_dims,
3787 int stride_width, int stride_height, int pad_width,
3788 int pad_height, int kwidth, int kheight, float* output_data,
3789 const Dims<4>& output_dims) {
3790 float output_activation_min, output_activation_max;
3791 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3792
3793 AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
3794 pad_height, kwidth, kheight, output_activation_min,
3795 output_activation_max, output_data, output_dims);
3796 }
3797
3798 // legacy, for compatibility with old checked-in code
3799 template <FusedActivationFunctionType Ac>
3800 void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
3801 int pad_width, int pad_height, int filter_width,
3802 int filter_height, float* output_data,
3803 const Dims<4>& output_dims) {
3804 AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
3805 filter_width, filter_height, output_data, output_dims);
3806 }
3807
3808 inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
3809 int stride_width, int stride_height, int pad_width,
3810 int pad_height, int filter_width, int filter_height,
3811 int32 output_activation_min,
3812 int32 output_activation_max, uint8* output_data,
3813 const Dims<4>& output_dims) {
3814 tflite::PoolParams params;
3815 params.stride_height = stride_height;
3816 params.stride_width = stride_width;
3817 params.filter_height = filter_height;
3818 params.filter_width = filter_width;
3819 params.padding_values.height = pad_height;
3820 params.padding_values.width = pad_width;
3821 params.quantized_activation_min = output_activation_min;
3822 params.quantized_activation_max = output_activation_max;
3823 AveragePool(params, DimsToShape(input_dims), input_data,
3824 DimsToShape(output_dims), output_data);
3825 }
3826
3827 // legacy, for compatibility with old checked-in code
3828 template <FusedActivationFunctionType Ac>
3829 void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
3830 int stride_width, int stride_height, int pad_width,
3831 int pad_height, int filter_width, int filter_height,
3832 int32 output_activation_min, int32 output_activation_max,
3833 uint8* output_data, const Dims<4>& output_dims) {
3834 static_assert(Ac == FusedActivationFunctionType::kNone ||
3835 Ac == FusedActivationFunctionType::kRelu ||
3836 Ac == FusedActivationFunctionType::kRelu6 ||
3837 Ac == FusedActivationFunctionType::kRelu1,
3838 "");
3839 if (Ac == FusedActivationFunctionType::kNone) {
3840 TFLITE_DCHECK_EQ(output_activation_min, 0);
3841 TFLITE_DCHECK_EQ(output_activation_max, 255);
3842 }
3843 AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
3844 pad_height, filter_width, filter_height, output_activation_min,
3845 output_activation_max, output_data, output_dims);
3846 }
3847
3848 // legacy, for compatibility with old checked-in code
3849 template <FusedActivationFunctionType Ac>
3850 void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
3851 int pad_width, int pad_height, int filter_width,
3852 int filter_height, int32 output_activation_min,
3853 int32 output_activation_max, uint8* output_data,
3854 const Dims<4>& output_dims) {
3855 AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
3856 filter_width, filter_height, output_activation_min,
3857 output_activation_max, output_data, output_dims);
3858 }
3859
3860 inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
3861 int stride_width, int stride_height, int pad_width,
3862 int pad_height, int kwidth, int kheight,
3863 float output_activation_min, float output_activation_max,
3864 float* output_data, const Dims<4>& output_dims) {
3865 tflite::PoolParams params;
3866 params.stride_height = stride_height;
3867 params.stride_width = stride_width;
3868 params.filter_height = kheight;
3869 params.filter_width = kwidth;
3870 params.padding_values.height = pad_height;
3871 params.padding_values.width = pad_width;
3872 params.float_activation_min = output_activation_min;
3873 params.float_activation_max = output_activation_max;
3874 MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
3875 output_data);
3876 }
3877
3878 // legacy, for compatibility with old checked-in code
3879 template <FusedActivationFunctionType Ac>
3880 void MaxPool(const float* input_data, const Dims<4>& input_dims,
3881 int stride_width, int stride_height, int pad_width, int pad_height,
3882 int kwidth, int kheight, float* output_data,
3883 const Dims<4>& output_dims) {
3884 float output_activation_min, output_activation_max;
3885 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3886 MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
3887 pad_height, kwidth, kheight, output_activation_min,
3888 output_activation_max, output_data, output_dims);
3889 }
3890
3891 // legacy, for compatibility with old checked-in code
3892 template <FusedActivationFunctionType Ac>
3893 void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
3894 int pad_width, int pad_height, int filter_width, int filter_height,
3895 float* output_data, const Dims<4>& output_dims) {
3896 MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
3897 filter_width, filter_height, output_data, output_dims);
3898 }
3899
3900 inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
3901 int stride_width, int stride_height, int pad_width,
3902 int pad_height, int filter_width, int filter_height,
3903 int32 output_activation_min, int32 output_activation_max,
3904 uint8* output_data, const Dims<4>& output_dims) {
3905 PoolParams params;
3906 params.stride_height = stride_height;
3907 params.stride_width = stride_width;
3908 params.filter_height = filter_height;
3909 params.filter_width = filter_width;
3910 params.padding_values.height = pad_height;
3911 params.padding_values.width = pad_width;
3912 params.quantized_activation_min = output_activation_min;
3913 params.quantized_activation_max = output_activation_max;
3914 MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
3915 output_data);
3916 }
3917
3918 // legacy, for compatibility with old checked-in code
3919 template <FusedActivationFunctionType Ac>
3920 void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
3921 int stride_width, int stride_height, int pad_width, int pad_height,
3922 int filter_width, int filter_height, int32 output_activation_min,
3923 int32 output_activation_max, uint8* output_data,
3924 const Dims<4>& output_dims) {
3925 static_assert(Ac == FusedActivationFunctionType::kNone ||
3926 Ac == FusedActivationFunctionType::kRelu ||
3927 Ac == FusedActivationFunctionType::kRelu6 ||
3928 Ac == FusedActivationFunctionType::kRelu1,
3929 "");
3930 if (Ac == FusedActivationFunctionType::kNone) {
3931 TFLITE_DCHECK_EQ(output_activation_min, 0);
3932 TFLITE_DCHECK_EQ(output_activation_max, 255);
3933 }
3934 MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
3935 pad_height, filter_width, filter_height, output_activation_min,
3936 output_activation_max, output_data, output_dims);
3937 }
3938
3939 // legacy, for compatibility with old checked-in code
3940 template <FusedActivationFunctionType Ac>
3941 void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
3942 int pad_width, int pad_height, int filter_width, int filter_height,
3943 int32 output_activation_min, int32 output_activation_max,
3944 uint8* output_data, const Dims<4>& output_dims) {
3945 MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
3946 filter_width, filter_height, output_activation_min,
3947 output_activation_max, output_data, output_dims);
3948 }
3949
3950 inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
3951 int stride_width, int stride_height, int pad_width,
3952 int pad_height, int filter_width, int filter_height,
3953 float output_activation_min, float output_activation_max,
3954 float* output_data, const Dims<4>& output_dims) {
3955 PoolParams params;
3956 params.stride_height = stride_height;
3957 params.stride_width = stride_width;
3958 params.filter_height = filter_height;
3959 params.filter_width = filter_width;
3960 params.padding_values.height = pad_height;
3961 params.padding_values.width = pad_width;
3962 params.float_activation_min = output_activation_min;
3963 params.float_activation_max = output_activation_max;
3964 L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
3965 output_data);
3966 }
3967
3968 // legacy, for compatibility with old checked-in code
3969 template <FusedActivationFunctionType Ac>
3970 void L2Pool(const float* input_data, const Dims<4>& input_dims,
3971 int stride_width, int stride_height, int pad_width, int pad_height,
3972 int filter_width, int filter_height, float* output_data,
3973 const Dims<4>& output_dims) {
3974 float output_activation_min, output_activation_max;
3975 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
3976 L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
3977 pad_height, filter_width, filter_height, output_activation_min,
3978 output_activation_max, output_data, output_dims);
3979 }
3980
3981 // legacy, for compatibility with old checked-in code
3982 template <FusedActivationFunctionType Ac>
3983 void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
3984 int pad_width, int pad_height, int filter_width, int filter_height,
3985 float* output_data, const Dims<4>& output_dims) {
3986 L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
3987 filter_width, filter_height, output_data, output_dims);
3988 }
3989
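// Quantized uint8 softmax: for each row, subtract the row maximum, compute
// exponentials of the (negative) differences in fixed point, accumulate their
// sum, and normalize each entry via a fixed-point reciprocal of that sum.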
3990 inline void Softmax(const SoftmaxParams& params,
3991 const RuntimeShape& input_shape, const uint8* input_data,
3992 const RuntimeShape& output_shape, uint8* output_data) {
3993 const int32 input_beta_multiplier = params.input_multiplier;
3994 const int32 input_beta_left_shift = params.input_left_shift;
3995 const int diff_min = params.diff_min;
3996 // The representation chosen for the input to the exp() function is Q5.26.
3997 // We need to leave extra space since values that we skip might be as large as
3998 // -32 before multiplying by input_beta_multiplier, and therefore as large as
3999 // -16 afterwards. Note that exp(-8) is definitely not insignificant to
4000 // accumulation, but exp(-16) definitely is.
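  // (For reference: in Q5.26 a raw int32 value v encodes v / 2^26, giving a
  // representable range of roughly [-32, +32).)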
4001 static const int kScaledDiffIntegerBits = 5;
4002 static const int kAccumulationIntegerBits = 12;
4003 using FixedPointScaledDiff =
4004 gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
4005 using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
4006 using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4007
4008 ruy::profiler::ScopeLabel label("Softmax/8bit");
4009 const int trailing_dim = input_shape.DimensionsCount() - 1;
4010 const int outer_size =
4011 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4012 const int depth =
4013 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4014
4015 for (int b = 0; b < outer_size; ++b) {
4016 const uint8* input_data_ptr = input_data + b * depth;
4017 uint8* output_data_ptr = output_data + b * depth;
4018
4019 // Determine the largest entry in the current row
4020 uint8 max_in_row = 0;
4021 {
4022 int c = 0;
4023 #ifdef USE_NEON
4024 uint8x16_t max16_0 = vdupq_n_u8(0);
4025 uint8x16_t max16_1 = vdupq_n_u8(0);
4026 for (; c <= depth - 32; c += 32) {
4027 max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
4028 max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
4029 }
4030 uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
4031 if (c <= depth - 16) {
4032 max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
4033 c += 16;
4034 }
4035 uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
4036 if (c <= depth - 8) {
4037 max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
4038 c += 8;
4039 }
4040 uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
4041 uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
4042 uint8x8_t max1 = vpmax_u8(max2, max2);
4043 max_in_row = vget_lane_u8(max1, 0);
4044 #endif
4045 for (; c < depth; ++c) {
4046 max_in_row = std::max(max_in_row, input_data_ptr[c]);
4047 }
4048 }
4049
4050 #ifdef USE_NEON
4051 using FixedPointAccumInt32x4 =
4052 gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
4053 using FixedPointScaledDiffInt32x4 =
4054 gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
4055 using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
4056 FixedPoint0Int32x4 input_beta_multiplier_f0 =
4057 FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
4058 int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
4059 #endif
4060
4061 // Compute the sum of exponentials of the differences of entries in the
4062 // current row from the largest entry in the current row.
4063 FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
4064 {
4065 int c = 0;
4066 #ifdef USE_NEON
4067 int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
4068 FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
4069 FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
4070 FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
4071 for (; c <= depth - 8; c += 8) {
4072 uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
4073 int16x8_t input_diff_s16 =
4074 vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
4075 int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
4076 int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
4077 int32x4_t mask_0 =
4078 gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
4079 int32x4_t mask_1 =
4080 gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
4081 FixedPointScaledDiffInt32x4 scaled_diff_0 =
4082 input_beta_multiplier_f0 *
4083 FixedPointScaledDiffInt32x4::FromRaw(
4084 gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
4085 FixedPointScaledDiffInt32x4 scaled_diff_1 =
4086 input_beta_multiplier_f0 *
4087 FixedPointScaledDiffInt32x4::FromRaw(
4088 gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
4089 FixedPointAccumInt32x4 exps_0 =
4090 gemmlowp::Rescale<kAccumulationIntegerBits>(
4091 exp_on_negative_values(scaled_diff_0));
4092 FixedPointAccumInt32x4 exps_1 =
4093 gemmlowp::Rescale<kAccumulationIntegerBits>(
4094 exp_on_negative_values(scaled_diff_1));
4095 FixedPointAccumInt32x4 masked_exps_0 =
4096 SelectUsingMask(mask_0, exps_0, zeros);
4097 FixedPointAccumInt32x4 masked_exps_1 =
4098 SelectUsingMask(mask_1, exps_1, zeros);
4099 sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
4100 sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
4101 }
4102 int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
4103 int32x2_t sum_of_exps_reduced_2 =
4104 vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
4105 vget_high_s32(sum_of_exps_reduced_4));
4106 int32x2_t sum_of_exps_reduced_1 =
4107 vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
4108 sum_of_exps =
4109 FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
4110 #endif
4111 for (; c < depth; ++c) {
4112 int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
4113 if (input_diff >= diff_min) {
4114 const int32 input_diff_rescaled =
4115 MultiplyByQuantizedMultiplierGreaterThanOne(
4116 input_diff, input_beta_multiplier, input_beta_left_shift);
4117 const FixedPointScaledDiff scaled_diff_f8 =
4118 FixedPointScaledDiff::FromRaw(input_diff_rescaled);
4119 sum_of_exps =
4120 sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
4121 exp_on_negative_values(scaled_diff_f8));
4122 }
4123 }
4124 }
4125
4126 // Compute the fixed-point multiplier and shift that we need to apply to
4127 // perform a division by the above-computed sum-of-exponentials.
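    // Roughly, multiplying each exponential by shifted_scale and right-shifting
    // by (num_bits_over_unit + 31 - 8) yields about 2^8 * exp / sum_of_exps,
    // which is then clamped to the uint8 range [0, 255].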
4128 int num_bits_over_unit = 0;
4129 FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
4130 sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));
4131
4132 // Compute the quotients of exponentials of differences of entries in the
4133 // current row from the largest entry, over the previously-computed sum of
4134 // exponentials.
4135 {
4136 int c = 0;
4137 #ifdef USE_NEON
4138 int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
4139 for (; c <= depth - 8; c += 8) {
4140 uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
4141 int16x8_t input_diff_s16 =
4142 vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
4143 int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
4144 int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
4145 uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
4146 FixedPointScaledDiffInt32x4 scaled_diff_0 =
4147 input_beta_multiplier_f0 *
4148 FixedPointScaledDiffInt32x4::FromRaw(
4149 gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
4150 FixedPointScaledDiffInt32x4 scaled_diff_1 =
4151 input_beta_multiplier_f0 *
4152 FixedPointScaledDiffInt32x4::FromRaw(
4153 gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
4154 FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
4155 FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
4156 int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
4157 vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
4158 num_bits_over_unit + 31 - 8);
4159 int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
4160 vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
4161 num_bits_over_unit + 31 - 8);
4162 int16x8_t output_s16 =
4163 vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
4164 uint8x8_t output_u8 = vqmovun_s16(output_s16);
4165 uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
4166 vst1_u8(output_data_ptr + c, masked_output);
4167 }
4168 #endif
4169 for (; c < depth; ++c) {
4170 int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
4171 if (input_diff >= diff_min) {
4172 const int32 input_diff_rescaled =
4173 MultiplyByQuantizedMultiplierGreaterThanOne(
4174 input_diff, input_beta_multiplier, input_beta_left_shift);
4175 const FixedPointScaledDiff scaled_diff_f8 =
4176 FixedPointScaledDiff::FromRaw(input_diff_rescaled);
4177
4178 FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
4179 int32 unsat_output = gemmlowp::RoundingDivideByPOT(
4180 (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
4181
4182 output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);
4183
4184 } else {
4185 output_data_ptr[c] = 0;
4186 }
4187 }
4188 }
4189 }
4190 }
4191
4192 inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
4193 float beta, float* output_data,
4194 const RuntimeShape& output_shape) {
4195 SoftmaxParams params;
4196 params.beta = beta;
4197 Softmax(params, input_shape, input_data, output_shape, output_data);
4198 }
4199
4200 inline void Softmax(const float* input_data, const Dims<4>& input_dims,
4201 float beta, float* output_data,
4202 const Dims<4>& output_dims) {
4203 Softmax(input_data, DimsToShape(input_dims), beta, output_data,
4204 DimsToShape(output_dims));
4205 }
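// A minimal usage sketch (hypothetical shapes and data) of the legacy float
// entry points above; both forward to the SoftmaxParams-based Softmax:
//
//   Dims<4> dims;  // describing, e.g., a {1, 1, 1, depth} activation
//   Softmax(input, dims, /*beta=*/1.0f, output, dims);
//   // ...is equivalent to:
//   SoftmaxParams params;
//   params.beta = 1.0f;
//   Softmax(params, DimsToShape(dims), input, DimsToShape(dims), output);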
4206
4207 inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
4208 int32 input_beta_multiplier, int32 input_beta_left_shift,
4209 int diff_min, uint8* output_data,
4210 const RuntimeShape& output_shape) {
4211 SoftmaxParams params;
4212 params.input_multiplier = input_beta_multiplier;
4213 params.input_left_shift = input_beta_left_shift;
4214 params.diff_min = diff_min;
4215 Softmax(params, input_shape, input_data, output_shape, output_data);
4216 }
4217 inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
4218 int32 input_beta_multiplier, int32 input_beta_left_shift,
4219 int diff_min, uint8* output_data,
4220 const Dims<4>& output_dims) {
4221 Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
4222 input_beta_left_shift, diff_min, output_data,
4223 DimsToShape(output_dims));
4224 }
4225
4226 inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
4227 float* output_data, const RuntimeShape& output_shape) {
4228 SoftmaxParams params;
4229 // No params currently used for float LogSoftmax.
4230 LogSoftmax(params, input_shape, input_data, output_shape, output_data);
4231 }
4232
4233 inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
4234 float* output_data, const Dims<4>& output_dims) {
4235 LogSoftmax(input_data, DimsToShape(input_dims), output_data,
4236 DimsToShape(output_dims));
4237 }
4238
4239 inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
4240 int32 input_multiplier, int32 input_left_shift,
4241 int32 reverse_scaling_divisor,
4242 int32 reverse_scaling_right_shift, int diff_min,
4243 uint8* output_data, const RuntimeShape& output_shape) {
4244 SoftmaxParams params;
4245 params.input_multiplier = input_multiplier;
4246 params.input_left_shift = input_left_shift;
4247 params.reverse_scaling_divisor = reverse_scaling_divisor;
4248 params.reverse_scaling_right_shift = reverse_scaling_right_shift;
4249 params.diff_min = diff_min;
4250 reference_ops::LogSoftmax(params, input_shape, input_data, output_shape,
4251 output_data);
4252 }
4253
4254 inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
4255 int32 input_multiplier, int32 input_left_shift,
4256 int32 reverse_scaling_divisor,
4257 int32 reverse_scaling_right_shift, int diff_min,
4258 uint8* output_data, const Dims<4>& output_dims) {
4259 reference_ops::LogSoftmax(
4260 input_data, DimsToShape(input_dims), input_multiplier, input_left_shift,
4261 reverse_scaling_divisor, reverse_scaling_right_shift, diff_min,
4262 output_data, DimsToShape(output_dims));
4263 }
4264
4265 inline void Logistic(const LogisticParams& params,
4266 const RuntimeShape& input_shape, const uint8* input_data,
4267 const RuntimeShape& output_shape, uint8* output_data) {
4268 ruy::profiler::ScopeLabel label("Logistic/Uint8");
4269 const int32 input_zero_point = params.input_zero_point;
4270 const int32 input_range_radius = params.input_range_radius;
4271 const int32 input_multiplier = params.input_multiplier;
4272 const int input_left_shift = params.input_left_shift;
4273 const int size = MatchingFlatSize(input_shape, output_shape);
4274
4275 int c = 0;
4276 #ifdef USE_NEON
4277 // Handle 16 values at a time
4278 for (; c <= size - 16; c += 16) {
4279 // Read input uint8 values, cast to int16 and subtract input_zero_point
4280 uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
4281 int16x8_t input_val_centered_0 =
4282 vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
4283 vdupq_n_s16(input_zero_point));
4284 int16x8_t input_val_centered_1 =
4285 vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
4286 vdupq_n_s16(input_zero_point));
4287
4288 // Prepare the bit masks that we will use at the end to implement the logic
4289 // that was expressed in the scalar code with branching:
4290 // if (input_val_centered < -input_range_radius) {
4291 // output_val = 0;
4292 // } else if (input_val_centered > input_range_radius) {
4293 // output_val = 255;
4294 // } else {
4295 // ...
4296 uint16x8_t mask_rightclamp_0 =
4297 vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
4298 uint16x8_t mask_rightclamp_1 =
4299 vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
4300 uint16x8_t mask_leftclamp_0 =
4301 vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
4302 uint16x8_t mask_leftclamp_1 =
4303 vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
4304 uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
4305 vshrn_n_u16(mask_rightclamp_1, 8));
4306 uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
4307 vshrn_n_u16(mask_leftclamp_1, 8));
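// The vcgtq/vcgeq comparisons yield all-ones or all-zeros 16-bit lanes;
// shifting right by 8 and narrowing turns them into 0xFF / 0x00 bytes that
// are OR-ed / AND-ed into the result at the end of the loop.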
4308
4309 // This performs what is expressed in the scalar code as
4310 // const int32 input_val_rescaled =
4311 // MultiplyByQuantizedMultiplierGreaterThanOne(
4312 // input_val_centered, input_multiplier, input_left_shift);
4313 int32x4_t input_val_rescaled_0 =
4314 vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
4315 vdupq_n_s32(input_left_shift));
4316 int32x4_t input_val_rescaled_1 =
4317 vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
4318 vdupq_n_s32(input_left_shift));
4319 int32x4_t input_val_rescaled_2 =
4320 vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
4321 vdupq_n_s32(input_left_shift));
4322 int32x4_t input_val_rescaled_3 =
4323 vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
4324 vdupq_n_s32(input_left_shift));
4325 input_val_rescaled_0 =
4326 vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
4327 input_val_rescaled_1 =
4328 vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
4329 input_val_rescaled_2 =
4330 vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
4331 input_val_rescaled_3 =
4332 vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
4333
4334 // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
4335 using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
4336 using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
4337 const FixedPoint4 input_val_f4_0 =
4338 FixedPoint4::FromRaw(input_val_rescaled_0);
4339 const FixedPoint4 input_val_f4_1 =
4340 FixedPoint4::FromRaw(input_val_rescaled_1);
4341 const FixedPoint4 input_val_f4_2 =
4342 FixedPoint4::FromRaw(input_val_rescaled_2);
4343 const FixedPoint4 input_val_f4_3 =
4344 FixedPoint4::FromRaw(input_val_rescaled_3);
4345 const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
4346 const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
4347 const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
4348 const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
4349
4350 // Divide by 2^23 as in the scalar code
4351 using gemmlowp::RoundingDivideByPOT;
4352 int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
4353 int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
4354 int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
4355 int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);
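// gemmlowp::logistic returns a Q0.31 value in [0, 1); dividing by 2^23
// rescales it to [0, 256), and the saturating narrowing below clamps the
// corner case 256 down to 255.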
4356
4357 // Cast output values to uint8, saturating
4358 int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
4359 vqmovn_s32(output_val_s32_1));
4360 int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
4361 vqmovn_s32(output_val_s32_3));
4362 uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
4363 vqmovun_s16(output_val_s16_1));
4364
4365 // Perform the bit-masking with the bit masks computed at the beginning,
4366 // see the comment there.
4367 output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
4368 output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
4369
4370 // Store back to memory
4371 vst1q_u8(output_data + c, output_val_u8);
4372 }
4373 #endif
4374 // Leftover loop: handle one value at a time with scalar code.
4375 for (; c < size; ++c) {
4376 const uint8 input_val_u8 = input_data[c];
4377 const int32 input_val_centered =
4378 static_cast<int32>(input_val_u8) - input_zero_point;
4379 uint8 output_val;
4380 if (input_val_centered < -input_range_radius) {
4381 output_val = 0;
4382 } else if (input_val_centered > input_range_radius) {
4383 output_val = 255;
4384 } else {
4385 const int32 input_val_rescaled =
4386 MultiplyByQuantizedMultiplierGreaterThanOne(
4387 input_val_centered, input_multiplier, input_left_shift);
4388 using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
4389 using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4390 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
4391 const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
4392 using gemmlowp::RoundingDivideByPOT;
4393 int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
4394 if (output_val_s32 == 256) {
4395 output_val_s32 = 255;
4396 }
4397 TFLITE_DCHECK_GE(output_val_s32, 0);
4398 TFLITE_DCHECK_LE(output_val_s32, 255);
4399 output_val = static_cast<uint8>(output_val_s32);
4400 }
4401 output_data[c] = output_val;
4402 }
4403 }
4404
4405 inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
4406 int32 input_zero_point, int32 input_range_radius,
4407 int32 input_multiplier, int input_left_shift,
4408 uint8* output_data, const RuntimeShape& output_shape) {
4409 LogisticParams params;
4410 params.input_zero_point = input_zero_point;
4411 params.input_range_radius = input_range_radius;
4412 params.input_multiplier = input_multiplier;
4413 params.input_left_shift = input_left_shift;
4414 Logistic(params, input_shape, input_data, output_shape, output_data);
4415 }
4416
4417 inline void Logistic(const float* input_data, const Dims<4>& input_dims,
4418 float* output_data, const Dims<4>& output_dims) {
4419 Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
4420 output_data);
4421 }
4422
4423 inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
4424 int32 input_zero_point, int32 input_range_radius,
4425 int32 input_multiplier, int input_left_shift,
4426 uint8* output_data, const Dims<4>& output_dims) {
4427 Logistic(input_data, DimsToShape(input_dims), input_zero_point,
4428 input_range_radius, input_multiplier, input_left_shift, output_data,
4429 DimsToShape(output_dims));
4430 }
4431
4432 inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
4433 const RuntimeShape& output_shape, int16* output_data) {
4434 LogisticParams params;
4435 // No params currently needed by int16 Logistic.
4436 Logistic(params, input_shape, input_data, output_shape, output_data);
4437 }
4438
4439 inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
4440 int16* output_data, const RuntimeShape& output_shape) {
4441 LogisticParams params;
4442 // No params currently needed by int16 Logistic.
4443 Logistic(params, input_shape, input_data, output_shape, output_data);
4444 }
4445
4446 inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
4447 int16* output_data, const Dims<4>& output_dims) {
4448 Logistic(input_data, DimsToShape(input_dims), output_data,
4449 DimsToShape(output_dims));
4450 }
4451
4452 inline void Tanh(const float* input_data, const Dims<4>& input_dims,
4453 float* output_data, const Dims<4>& output_dims) {
4454 Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
4455 output_data);
4456 }
4457
4458 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
4459 const uint8* input_data, const RuntimeShape& output_shape,
4460 uint8* output_data) {
4461 // Note that this is almost the exact same code as in Logistic().
4462 ruy::profiler::ScopeLabel label("Tanh");
4463 const int32 input_zero_point = params.input_zero_point;
4464 const int32 input_range_radius = params.input_range_radius;
4465 const int32 input_multiplier = params.input_multiplier;
4466 const int input_left_shift = params.input_left_shift;
4467 const int size = MatchingFlatSize(input_shape, output_shape);
4468
4469 int c = 0;
4470 int32_t output_zero_point = 128;
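// tanh is symmetric around zero, so the uint8 output is quantized with a
// zero point of 128 (the divide by 2^24 below corresponds to an output
// scale of 1/128).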
4471 #ifdef USE_NEON
4472 // Handle 16 values at a time
4473 for (; c <= size - 16; c += 16) {
4474 // Read input uint8 values, cast to int16 and subtract input_zero_point
4475 uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
4476 int16x8_t input_val_centered_0 =
4477 vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
4478 vdupq_n_s16(input_zero_point));
4479 int16x8_t input_val_centered_1 =
4480 vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
4481 vdupq_n_s16(input_zero_point));
4482
4483 // Prepare the bit masks that we will use at the end to implement the logic
4484 // that was expressed in the scalar code with branching:
4485 // if (input_val_centered < -input_range_radius) {
4486 // output_val = 0;
4487 // } else if (input_val_centered > input_range_radius) {
4488 // output_val = 255;
4489 // } else {
4490 // ...
4491 uint16x8_t mask_rightclamp_0 =
4492 vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
4493 uint16x8_t mask_rightclamp_1 =
4494 vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
4495 uint16x8_t mask_leftclamp_0 =
4496 vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
4497 uint16x8_t mask_leftclamp_1 =
4498 vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
4499 uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
4500 vshrn_n_u16(mask_rightclamp_1, 8));
4501 uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
4502 vshrn_n_u16(mask_leftclamp_1, 8));
4503
4504 // This performs what is expressed in the scalar code as
4505 // const int32 input_val_rescaled =
4506 // MultiplyByQuantizedMultiplierGreaterThanOne(
4507 // input_val_centered, input_multiplier, input_left_shift);
4508 int32x4_t input_val_rescaled_0 =
4509 vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
4510 vdupq_n_s32(input_left_shift));
4511 int32x4_t input_val_rescaled_1 =
4512 vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
4513 vdupq_n_s32(input_left_shift));
4514 int32x4_t input_val_rescaled_2 =
4515 vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
4516 vdupq_n_s32(input_left_shift));
4517 int32x4_t input_val_rescaled_3 =
4518 vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
4519 vdupq_n_s32(input_left_shift));
4520 input_val_rescaled_0 =
4521 vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
4522 input_val_rescaled_1 =
4523 vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
4524 input_val_rescaled_2 =
4525 vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
4526 input_val_rescaled_3 =
4527 vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
4528
4529 // Invoke gemmlowp::tanh on FixedPoint wrapping int32x4_t
4530 using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
4531 using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
4532 const FixedPoint4 input_val_f4_0 =
4533 FixedPoint4::FromRaw(input_val_rescaled_0);
4534 const FixedPoint4 input_val_f4_1 =
4535 FixedPoint4::FromRaw(input_val_rescaled_1);
4536 const FixedPoint4 input_val_f4_2 =
4537 FixedPoint4::FromRaw(input_val_rescaled_2);
4538 const FixedPoint4 input_val_f4_3 =
4539 FixedPoint4::FromRaw(input_val_rescaled_3);
4540 const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
4541 const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
4542 const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
4543 const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
4544
4545 // Divide by 2^24 as in the scalar code
4546 using gemmlowp::RoundingDivideByPOT;
4547 int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 24);
4548 int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 24);
4549 int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 24);
4550 int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 24);
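// gemmlowp::tanh returns a Q0.31 value in (-1, 1); dividing by 2^24 maps it
// to (-128, 128), and adding the zero point of 128 below places the result
// in (0, 256) before saturating to uint8.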
4551
4552 // Add the output zero point
4553 int32x4_t output_zero_point_s32 = vdupq_n_s32(output_zero_point);
4554 output_val_s32_0 = vaddq_s32(output_val_s32_0, output_zero_point_s32);
4555 output_val_s32_1 = vaddq_s32(output_val_s32_1, output_zero_point_s32);
4556 output_val_s32_2 = vaddq_s32(output_val_s32_2, output_zero_point_s32);
4557 output_val_s32_3 = vaddq_s32(output_val_s32_3, output_zero_point_s32);
4558
4559 // Cast output values to uint8, saturating
4560 int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
4561 vqmovn_s32(output_val_s32_1));
4562 int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
4563 vqmovn_s32(output_val_s32_3));
4564 uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
4565 vqmovun_s16(output_val_s16_1));
4566
4567 // Perform the bit-masking with the bit masks computed at the beginning,
4568 // see the comment there.
4569 output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
4570 output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
4571
4572 // Store back to memory
4573 vst1q_u8(output_data + c, output_val_u8);
4574 }
4575 #endif
4576 // Leftover loop: handle one value at a time with scalar code.
4577 for (; c < size; ++c) {
4578 const uint8 input_val_u8 = input_data[c];
4579 const int32 input_val_centered =
4580 static_cast<int32>(input_val_u8) - input_zero_point;
4581 uint8 output_val;
4582 if (input_val_centered < -input_range_radius) {
4583 output_val = 0;
4584 } else if (input_val_centered > input_range_radius) {
4585 output_val = 255;
4586 } else {
4587 const int32 input_val_rescaled =
4588 MultiplyByQuantizedMultiplierGreaterThanOne(
4589 input_val_centered, input_multiplier, input_left_shift);
4590 using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
4591 using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4592 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
4593 const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
4594 using gemmlowp::RoundingDivideByPOT;
4595 int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24);
4596 output_val_s32 += output_zero_point;
4597 if (output_val_s32 == 256) {
4598 output_val_s32 = 255;
4599 }
4600 TFLITE_DCHECK_GE(output_val_s32, 0);
4601 TFLITE_DCHECK_LE(output_val_s32, 255);
4602 output_val = static_cast<uint8>(output_val_s32);
4603 }
4604 output_data[c] = output_val;
4605 }
4606 }
4607
4608 inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
4609 int32 input_zero_point, int32 input_range_radius,
4610 int32 input_multiplier, int input_left_shift,
4611 uint8* output_data, const RuntimeShape& output_shape) {
4612 TanhParams params;
4613 params.input_zero_point = input_zero_point;
4614 params.input_range_radius = input_range_radius;
4615 params.input_multiplier = input_multiplier;
4616 params.input_left_shift = input_left_shift;
4617 Tanh(params, input_shape, input_data, output_shape, output_data);
4618 }
4619
4620 inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
4621 int32 input_zero_point, int32 input_range_radius,
4622 int32 input_multiplier, int input_left_shift,
4623 uint8* output_data, const Dims<4>& output_dims) {
4624 Tanh(input_data, DimsToShape(input_dims), input_zero_point,
4625 input_range_radius, input_multiplier, input_left_shift, output_data,
4626 DimsToShape(output_dims));
4627 }
4628
4629 inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
4630 int input_left_shift, int16* output_data,
4631 const RuntimeShape& output_shape) {
4632 TanhParams params;
4633 params.input_left_shift = input_left_shift;
4634 Tanh(params, input_shape, input_data, output_shape, output_data);
4635 }
4636
4637 inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
4638 int input_left_shift, int16* output_data,
4639 const Dims<4>& output_dims) {
4640 Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
4641 DimsToShape(output_dims));
4642 }
4643
4644 template <typename T>
4645 inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
4646 int block_size, T* output_data,
4647 const Dims<4>& output_dims) {
4648 tflite::DepthToSpaceParams op_params;
4649 op_params.block_size = block_size;
4650
4651 DepthToSpace(op_params, DimsToShape(input_dims), input_data,
4652 DimsToShape(output_dims), output_data);
4653 }
4654
4655 template <typename T>
4656 inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
4657 int block_size, T* output_data,
4658 const Dims<4>& output_dims) {
4659 tflite::SpaceToDepthParams op_params;
4660 op_params.block_size = block_size;
4661
4662 SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
4663 DimsToShape(output_dims), output_data);
4664 }
4665
4666 inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
4667 const float* input2_data, const Dims<4>& input2_dims,
4668 float output_activation_min, float output_activation_max,
4669 float* output_data, const Dims<4>& output_dims) {
4670 tflite::ArithmeticParams op_params;
4671 op_params.float_activation_min = output_activation_min;
4672 op_params.float_activation_max = output_activation_max;
4673
4674 Mul(op_params, DimsToShape(input1_dims), input1_data,
4675 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
4676 output_data);
4677 }
4678
4679 template <FusedActivationFunctionType Ac>
4680 void Mul(const float* input1_data, const Dims<4>& input1_dims,
4681 const float* input2_data, const Dims<4>& input2_dims,
4682 float* output_data, const Dims<4>& output_dims) {
4683 float output_activation_min, output_activation_max;
4684 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
4685
4686 Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
4687 output_activation_max, output_data, output_dims);
4688 }
4689
4690 inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
4691 const int32* input2_data, const Dims<4>& input2_dims,
4692 int32 output_activation_min, int32 output_activation_max,
4693 int32* output_data, const Dims<4>& output_dims) {
4694 tflite::ArithmeticParams op_params;
4695 op_params.quantized_activation_min = output_activation_min;
4696 op_params.quantized_activation_max = output_activation_max;
4697
4698 Mul(op_params, DimsToShape(input1_dims), input1_data,
4699 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
4700 output_data);
4701 }
4702
4703 template <FusedActivationFunctionType Ac>
4704 void Mul(const int32* input1_data, const Dims<4>& input1_dims,
4705 const int32* input2_data, const Dims<4>& input2_dims,
4706 int32* output_data, const Dims<4>& output_dims) {
4707 TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
4708 tflite::ArithmeticParams op_params;
4709 // No parameters needed.
4710
4711 MulNoActivation(op_params, DimsToShape(input1_dims), input1_data,
4712 DimsToShape(input2_dims), input2_data,
4713 DimsToShape(output_dims), output_data);
4714 }
4715
4716 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
4717 const int16* input2_data, const Dims<4>& input2_dims,
4718 int16* output_data, const Dims<4>& output_dims) {
4719 tflite::ArithmeticParams op_params;
4720 // No parameters needed.
4721
4722 Mul(op_params, DimsToShape(input1_dims), input1_data,
4723 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
4724 output_data);
4725 }
4726
4727 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
4728 const int16* input2_data, const Dims<4>& input2_dims,
4729 int32 output_offset, int32 output_activation_min,
4730 int32 output_activation_max, uint8* output_data,
4731 const Dims<4>& output_dims) {
4732 tflite::ArithmeticParams op_params;
4733 op_params.output_offset = output_offset;
4734 op_params.quantized_activation_min = output_activation_min;
4735 op_params.quantized_activation_max = output_activation_max;
4736
4737 Mul(op_params, DimsToShape(input1_dims), input1_data,
4738 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
4739 output_data);
4740 }
4741
4742 template <typename T>
4743 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
4744 const T* input2_data, const Dims<4>& input2_dims,
4745 T output_activation_min, T output_activation_max,
4746 T* output_data, const Dims<4>& output_dims) {
4747 tflite::ArithmeticParams op_params;
4748 SetActivationParams(output_activation_min, output_activation_max, &op_params);
4749
4750 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
4751 DimsToShape(input2_dims), input2_data,
4752 DimsToShape(output_dims), output_data);
4753 }
4754
4755 // For compatibility with old checked-in code
4756 template <FusedActivationFunctionType Ac>
4757 inline void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
4758 const float* input2_data, const Dims<4>& input2_dims,
4759 float* output_data, const Dims<4>& output_dims) {
4760 tflite::ArithmeticParams op_params;
4761 float float_activation_min;
4762 float float_activation_max;
4763 GetActivationMinMax(Ac, &float_activation_min, &float_activation_max);
4764 SetActivationParams(float_activation_min, float_activation_max, &op_params);
4765
4766 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
4767 DimsToShape(input2_dims), input2_data,
4768 DimsToShape(output_dims), output_data);
4769 }
4770
4771 inline void LocalResponseNormalization(const float* input_data,
4772 const Dims<4>& input_dims, int range,
4773 float bias, float alpha, float beta,
4774 float* output_data,
4775 const Dims<4>& output_dims) {
4776 tflite::LocalResponseNormalizationParams op_params;
4777 op_params.range = range;
4778 op_params.bias = bias;
4779 op_params.alpha = alpha;
4780 op_params.beta = beta;
4781
4782 LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
4783 DimsToShape(output_dims), output_data);
4784 }
4785
4786 template <typename SrcT, typename DstT>
4787 void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
4788 const Dims<4>& output_dims) {
4789 Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
4790 output_data);
4791 }
4792
4793 inline void Floor(const float* input_data, const Dims<4>& input_dims,
4794 float* output_data, const Dims<4>& output_dims) {
4795 Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
4796 output_data);
4797 }
4798
4799 inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
4800 const int32* output_size_data,
4801 const Dims<4>& output_size_dims, float* output_data,
4802 const Dims<4>& output_dims, bool align_corners) {
4803 tflite::ResizeBilinearParams op_params;
4804 op_params.align_corners = align_corners;
4805 op_params.half_pixel_centers = false;
4806 ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
4807 DimsToShape(output_size_dims), output_size_data,
4808 DimsToShape(output_dims), output_data);
4809 }
4810
4811 inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
4812 const int32* output_size_data,
4813 const Dims<4>& output_size_dims, uint8* output_data,
4814 const Dims<4>& output_dims, bool align_corners) {
4815 tflite::ResizeBilinearParams op_params;
4816 op_params.align_corners = align_corners;
4817 op_params.half_pixel_centers = false;
4818 ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
4819 DimsToShape(output_size_dims), output_size_data,
4820 DimsToShape(output_dims), output_data);
4821 }
4822
4823 // legacy, for compatibility with old checked-in code
4824 inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
4825 const int32* output_size_data,
4826 const Dims<4>& output_size_dims, float* output_data,
4827 const Dims<4>& output_dims) {
4828 ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
4829 output_data, output_dims, /*align_corners=*/false);
4830 }
4831
4832 // legacy, for compatibility with old checked-in code
4833 inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
4834 const int32* output_size_data,
4835 const Dims<4>& output_size_dims, uint8* output_data,
4836 const Dims<4>& output_dims) {
4837 ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
4838 output_data, output_dims, /*align_corners=*/false);
4839 }
4840
4841 template <typename T>
4842 inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
4843 const int32* block_shape_data,
4844 const Dims<4>& block_shape_dims,
4845 const int32* crops_data, const Dims<4>& crops_dims,
4846 T* output_data, const Dims<4>& output_dims) {
4847 BatchToSpaceND(DimsToShape(input_dims), input_data,
4848 DimsToShape(block_shape_dims), block_shape_data,
4849 DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
4850 output_data);
4851 }
4852
4853 // Legacy signature; this function covered both Pad and PadV2.
4854 template <typename T>
4855 inline void PadV2(const T* input_data, const Dims<4>& input_dims,
4856 const std::vector<int>& left_paddings,
4857 const std::vector<int>& right_paddings, T* output_data,
4858 const Dims<4>& output_dims, const T pad_value) {
4859 TFLITE_DCHECK_EQ(left_paddings.size(), 4);
4860 TFLITE_DCHECK_EQ(right_paddings.size(), 4);
4861 tflite::PadParams op_params;
4862 op_params.left_padding_count = 4;
4863 op_params.right_padding_count = 4;
4864 for (int i = 0; i < 4; ++i) {
4865 op_params.left_padding[i] = left_paddings[3 - i];
4866 op_params.right_padding[i] = right_paddings[3 - i];
4867 }
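// Legacy Dims<4> lists dimensions innermost-first, while PadParams (and
// RuntimeShape) use outermost-first order, hence the [3 - i] reversal above.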
4868 const T pad_value_copy = pad_value;
4869
4870 Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
4871 DimsToShape(output_dims), output_data);
4872 }
4873
4874 // Old Pad that calls legacy PadV2.
4875 template <typename T>
4876 inline void Pad(const T* input_data, const Dims<4>& input_dims,
4877 const std::vector<int>& left_paddings,
4878 const std::vector<int>& right_paddings, T* output_data,
4879 const Dims<4>& output_dims, const int32_t pad_value) {
4880 const T converted_pad_value = static_cast<T>(pad_value);
4881 PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
4882 output_dims, converted_pad_value);
4883 }
4884
4885 // Old Pad that only padded with 0.
4886 template <typename T>
4887 inline void Pad(const T* input_data, const Dims<4>& input_dims,
4888 const std::vector<int>& left_paddings,
4889 const std::vector<int>& right_paddings, T* output_data,
4890 const Dims<4>& output_dims) {
4891 const T pad_value = static_cast<T>(0);
4892 PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
4893 output_dims, pad_value);
4894 }
4895
4896 template <typename T>
4897 inline void Slice(const T* input_data, const Dims<4>& input_dims,
4898 const std::vector<int>& begin, const std::vector<int>& size,
4899 T* output_data, const Dims<4>& output_dims) {
4900 tflite::SliceParams op_params;
4901 op_params.begin_count = 4;
4902 op_params.size_count = 4;
4903 for (int i = 0; i < 4; ++i) {
4904 op_params.begin[i] = begin[3 - i];
4905 op_params.size[i] = size[3 - i];
4906 }
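// Same Dims<4>-to-RuntimeShape axis reversal as in PadV2 above: the legacy
// begin/size vectors are given innermost-first and are flipped for
// SliceParams.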
4907
4908 Slice(op_params, DimsToShape(input_dims), input_data,
4909 DimsToShape(output_dims), output_data);
4910 }
4911
4912 template <typename T>
4913 void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
4914 const T* input2_data, T* output_data,
4915 const Dims<4>& output_dims) {
4916 Minimum(DimsToShape(input1_dims), input1_data, input2_data,
4917 DimsToShape(output_dims), output_data);
4918 }
4919
4920 template <typename T>
4921 void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
4922 const T* input2_data, T* output_data,
4923 const Dims<4>& output_dims) {
4924 Maximum(DimsToShape(input1_dims), input1_data, input2_data,
4925 DimsToShape(output_dims), output_data);
4926 }
4927
4928 inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
4929 int32 zero_point, double scale, float* output_data,
4930 const Dims<4>& output_dims) {
4931 tflite::DequantizationParams op_params;
4932 op_params.zero_point = zero_point;
4933 op_params.scale = scale;
4934
4935 Dequantize(op_params, DimsToShape(input_dims), input_data,
4936 DimsToShape(output_dims), output_data);
4937 }
4938
4939 template <typename T>
4940 void Transpose(const T* input, const Dims<4>& input_dims, T* output,
4941 const Dims<4>& output_dims, const int* permuted_axes) {
4942 TransposeParams params;
4943 params.perm_count = 4;
4944 for (int i = 0; i < 4; ++i) {
4945 params.perm[i] = 3 - permuted_axes[3 - i];
4946 }
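// Converting the legacy permutation reverses both the order of the entries
// and the axis numbering itself: axis k in Dims<4> is axis 3 - k in
// RuntimeShape, hence params.perm[i] = 3 - permuted_axes[3 - i].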
4947 Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
4948 output);
4949 }
4950
4951 template <typename T>
4952 inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
4953 int begin_mask, int end_mask, int shrink_axis_mask,
4954 const std::vector<int>& start_indices,
4955 const std::vector<int>& stop_indices,
4956 const std::vector<int>& strides, T* output_data,
4957 const Dims<4>& output_dims) {
4958 TFLITE_DCHECK_EQ(start_indices.size(), 4);
4959 auto op_params = strided_slice::BuildStridedSliceParams(
4960 begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
4961 strides);
4962 reference_ops::StridedSliceReverseIndices(&op_params);
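// The legacy start/stop/stride vectors are in Dims<4> (innermost-first)
// order; StridedSliceReverseIndices flips them into the order expected by
// the params-based StridedSlice below.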
4963
4964 StridedSlice(op_params, DimsToShape(input_dims), input_data,
4965 DimsToShape(output_dims), output_data);
4966 }
4967
4968 } // namespace optimized_ops
4969 } // namespace tflite
4970 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
4971