1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
17
18 #include <assert.h>
19 #include <stdint.h>
20 #include <sys/types.h>
21
22 #include <algorithm>
23 #include <cmath>
24 #include <cstdint>
25 #include <limits>
26 #include <memory>
27 #include <tuple>
28 #include <type_traits>
29
30 #include "tensorflow/lite/kernels/internal/common.h"
31 #include "tensorflow/lite/kernels/internal/compatibility.h"
32 #include "tensorflow/lite/kernels/internal/reference/add.h"
33 #include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
34
35 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
36 #include <Accelerate/Accelerate.h>
37 #endif
38
39 #include "Eigen/Core"
40 #include "fixedpoint/fixedpoint.h"
41 #include "ruy/profiler/instrumentation.h" // from @ruy
42 #include "tensorflow/lite/c/common.h"
43 #include "tensorflow/lite/kernels/cpu_backend_context.h"
44 #include "tensorflow/lite/kernels/cpu_backend_gemm.h"
45 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
46 #include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
47 #include "tensorflow/lite/kernels/internal/cppmath.h"
48 #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
49 #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
50 #include "tensorflow/lite/kernels/internal/quantization_util.h"
51 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
52 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
53 #include "tensorflow/lite/kernels/internal/tensor.h"
54 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
55 #include "tensorflow/lite/kernels/internal/transpose_utils.h"
56 #include "tensorflow/lite/kernels/internal/types.h"
57 #include "unsupported/Eigen/CXX11/Tensor"
58
59 #if __aarch64__ && __clang__
60 #define TFLITE_SOFTMAX_USE_UINT16_LUT
61 #endif
62
63 namespace tflite {
64 namespace optimized_ops {
65
66 // Unoptimized reference ops:
67 using reference_ops::ArgMax;
68 using reference_ops::ArgMinMax;
69 using reference_ops::Broadcast4DSlowGreater;
70 using reference_ops::Broadcast4DSlowGreaterEqual;
71 using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
72 using reference_ops::Broadcast4DSlowGreaterWithScaling;
73 using reference_ops::Broadcast4DSlowLess;
74 using reference_ops::Broadcast4DSlowLessEqual;
75 using reference_ops::Broadcast4DSlowLessEqualWithScaling;
76 using reference_ops::Broadcast4DSlowLessWithScaling;
77 using reference_ops::BroadcastAdd4DSlow;
78 using reference_ops::BroadcastMul4DSlow;
79 using reference_ops::BroadcastSub16POTSlow;
80 using reference_ops::BroadcastSubSlow;
81 using reference_ops::Concatenation;
82 using reference_ops::ConcatenationWithScaling;
83 using reference_ops::DepthConcatenation;
84 using reference_ops::Div;
85 using reference_ops::Elu;
86 using reference_ops::FakeQuant;
87 using reference_ops::Fill;
88 using reference_ops::Gather;
89 using reference_ops::Greater;
90 using reference_ops::GreaterEqual;
91 using reference_ops::GreaterEqualWithScaling;
92 using reference_ops::GreaterWithScaling;
93 using reference_ops::LeakyRelu;
94 using reference_ops::Less;
95 using reference_ops::LessEqual;
96 using reference_ops::LessEqualWithScaling;
97 using reference_ops::LessWithScaling;
98 using reference_ops::Mean;
99 using reference_ops::ProcessBroadcastShapes;
100 using reference_ops::RankOneSelect;
101 using reference_ops::Relu1;
102 using reference_ops::Relu6;
103 using reference_ops::ReluX;
104 using reference_ops::Round;
105 using reference_ops::Select;
106 using reference_ops::SpaceToBatchND;
107 using reference_ops::Split;
108 using reference_ops::Sub16;
109
110 // TODO(b/80247582) Remove this constant.
111 // This will be phased out as the shifts are revised with more thought. Use of a
112 // constant enables us to track progress on this work.
113 //
114 // Used to convert from old-style shifts (right) to new-style (left).
115 static constexpr int kReverseShift = -1;
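// Editor's example (illustrative, not part of the original header): a kernel
// that previously stored an old-style right shift of 3 would now store
// kReverseShift * 3 == -3; MultiplyByQuantizedMultiplier treats negative
// shifts as right shifts, so the arithmetic is unchanged.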
116
117 // Make a local VectorMap typedef that allows mapping a float array as an
118 // Eigen vector expression. The std::conditional here constructs the
119 // suitable Eigen type for the constness of the data: for const data, we
120 // need to produce
121 // Eigen::Map<const Eigen::Matrix<float, ...>>
122 // and not the more straightforward
123 // Eigen::Map<Eigen::Matrix<const float, ...>>
124 template <typename Scalar>
125 using VectorMap = typename std::conditional<
126 std::is_const<Scalar>::value,
127 Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
128 Eigen::Dynamic, 1>>,
129 Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, 1>>>::type;
130
131 template <typename Scalar>
132 VectorMap<Scalar> MapAsVector(Scalar* data, const RuntimeShape& shape) {
133 const int size = shape.FlatSize();
134 return VectorMap<Scalar>(data, size, 1);
135 }
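// Editor's sketch of how the const handling above plays out (the buffer and
// shape below are made up for illustration):
//
//   float buf[6] = {0, 1, 2, 3, 4, 5};
//   RuntimeShape shape({2, 3});
//   VectorMap<float> v = MapAsVector(buf, shape);          // length 6, writable
//   const float* cbuf = buf;
//   VectorMap<const float> cv = MapAsVector(cbuf, shape);  // read-only map
//   v(0) = 42.0f;   // OK: Map over non-const data
//   // cv(0) = ...  // would not compile: Eigen::Map<const Eigen::Matrix<...>>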
136
137 // Make a local MatrixMap typedef that allows mapping a float array as an
138 // Eigen matrix expression. The same explanation as for VectorMap above
139 // also applies here.
140 template <typename Scalar>
141 using MatrixMap = typename std::conditional<
142 std::is_const<Scalar>::value,
143 Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
144 Eigen::Dynamic, Eigen::Dynamic>>,
145 Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
146
147 template <typename Scalar>
148 MatrixMap<Scalar> MapAsMatrixWithLastDimAsRows(Scalar* data,
149 const RuntimeShape& shape) {
150 const int dims_count = shape.DimensionsCount();
151 const int rows = shape.Dims(dims_count - 1);
152 const int cols = FlatSizeSkipDim(shape, dims_count - 1);
153 return MatrixMap<Scalar>(data, rows, cols);
154 }
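// Editor's example (shapes made up): for an NHWC shape {2, 3, 4, 8},
// MapAsMatrixWithLastDimAsRows returns an 8 x 24 column-major matrix whose
// columns are the contiguous depth-8 "pixels" of the row-major buffer:
//
//   RuntimeShape nhwc({2, 3, 4, 8});
//   const float* activations = ...;  // 2 * 3 * 4 * 8 floats
//   auto mat = MapAsMatrixWithLastDimAsRows(activations, nhwc);
//   // mat.rows() == 8, mat.cols() == 24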
155
156 template <typename Scalar>
157 MatrixMap<Scalar> MapAsMatrixWithFirstDimAsCols(Scalar* data,
158 const RuntimeShape& shape) {
159 const int cols = shape.Dims(0);
160 const int rows = FlatSizeSkipDim(shape, 0);
161 return MatrixMap<Scalar>(data, rows, cols);
162 }
163
164 template <typename Scalar>
165 using ArrayMap = typename std::conditional<
166 std::is_const<Scalar>::value,
167 Eigen::Map<const Eigen::Array<typename std::remove_const<Scalar>::type,
168 Eigen::Dynamic, Eigen::Dynamic>>,
169 Eigen::Map<Eigen::Array<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
170
171 template <typename Scalar>
172 ArrayMap<Scalar> MapAsArrayWithLastDimAsRows(Scalar* data,
173 const RuntimeShape& shape) {
174 const int dims_count = shape.DimensionsCount();
175 const int rows = shape.Dims(dims_count - 1);
176 const int cols = FlatSizeSkipDim(shape, dims_count - 1);
177 return ArrayMap<Scalar>(data, rows, cols);
178 }
179
180 // Copied from tensorflow/core/framework/tensor_types.h
181 template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
182 struct TTypes {
183 // Rank-1 tensor (vector) of scalar type T.
184 typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
185 Eigen::Aligned>
186 Flat;
187 typedef Eigen::TensorMap<
188 Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>>
189 UnalignedConstMatrix;
190 };
191
192 // TODO(b/62193649): this function is only needed as long
193 // as we have the --variable_batch hack.
194 template <typename Scalar>
195 MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
196 const RuntimeShape& shape,
197 int rows) {
198 const int flatsize = shape.FlatSize();
199 TFLITE_DCHECK_EQ(flatsize % rows, 0);
200 const int cols = flatsize / rows;
201 return MatrixMap<Scalar>(data, rows, cols);
202 }
203
204 template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
205 inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
206 const RuntimeShape& unswitched_input1_shape,
207 const T* unswitched_input1_data,
208 const RuntimeShape& unswitched_input2_shape,
209 const T* unswitched_input2_data,
210 const RuntimeShape& output_shape,
211 T* output_data, ElementwiseF elementwise_f,
212 ScalarBroadcastF scalar_broadcast_f) {
213 ArithmeticParams switched_params = unswitched_params;
214 switched_params.input1_offset = unswitched_params.input2_offset;
215 switched_params.input1_multiplier = unswitched_params.input2_multiplier;
216 switched_params.input1_shift = unswitched_params.input2_shift;
217 switched_params.input2_offset = unswitched_params.input1_offset;
218 switched_params.input2_multiplier = unswitched_params.input1_multiplier;
219 switched_params.input2_shift = unswitched_params.input1_shift;
220
221 const bool use_unswitched =
222 unswitched_params.broadcast_category ==
223 tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
224
225 const ArithmeticParams& params =
226 use_unswitched ? unswitched_params : switched_params;
227 const T* input1_data =
228 use_unswitched ? unswitched_input1_data : unswitched_input2_data;
229 const T* input2_data =
230 use_unswitched ? unswitched_input2_data : unswitched_input1_data;
231
232 // Fivefold nested loops. The second input resets its position for each
233 // iteration of the second loop. The first input is held in place through
234 // the fourth loop (its y4-element section is reused y3 times) and then
235 // advances. The innermost loop is an elementwise op over array sections.
236 T* output_data_ptr = output_data;
237 const T* input1_data_ptr = input1_data;
238 const T* input2_data_reset = input2_data;
239 // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
240 // between input shapes. y3 for input 1 is always broadcast, and so the
241 // dimension there is 1, whereas optionally y1 might be broadcast for
242 // input 2. Put another way, input1.shape.FlatSize = y0 * y1 * y2 * y4,
243 // input2.shape.FlatSize = y0 * y2 * y3 * y4.
244 int y0 = params.broadcast_shape[0];
245 int y1 = params.broadcast_shape[1];
246 int y2 = params.broadcast_shape[2];
247 int y3 = params.broadcast_shape[3];
248 int y4 = params.broadcast_shape[4];
249 if (y4 > 1) {
250 // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
251 // dimension.
252 for (int i0 = 0; i0 < y0; ++i0) {
253 const T* input2_data_ptr = nullptr;
254 for (int i1 = 0; i1 < y1; ++i1) {
255 input2_data_ptr = input2_data_reset;
256 for (int i2 = 0; i2 < y2; ++i2) {
257 for (int i3 = 0; i3 < y3; ++i3) {
258 elementwise_f(y4, params, input1_data_ptr, input2_data_ptr,
259 output_data_ptr);
260 input2_data_ptr += y4;
261 output_data_ptr += y4;
262 }
263 // We have broadcast y4 of input1 data y3 times, and now move on.
264 input1_data_ptr += y4;
265 }
266 }
267 // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
268 input2_data_reset = input2_data_ptr;
269 }
270 } else {
271 // Special case of y4 == 1, in which the innermost loop is a single
272 // element and can be combined with the next (y3) as an inner broadcast.
273 //
274 // Note that this handles the case of pure scalar broadcast when
275 // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
276 // broadcast with batch (as y2 > 1).
277 //
278 // NOTE The process is the same as the above general case except
279 // simplified for y4 == 1, and the loop over y3 is contained within the
280 // scalar_broadcast_f call (e.g. AddScalarBroadcast).
281 for (int i0 = 0; i0 < y0; ++i0) {
282 const T* input2_data_ptr = nullptr;
283 for (int i1 = 0; i1 < y1; ++i1) {
284 input2_data_ptr = input2_data_reset;
285 for (int i2 = 0; i2 < y2; ++i2) {
286 scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr,
287 output_data_ptr);
288 input2_data_ptr += y3;
289 output_data_ptr += y3;
290 input1_data_ptr += 1;
291 }
292 }
293 input2_data_reset = input2_data_ptr;
294 }
295 }
296 }
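// Editor's worked example of the fivefold decomposition above (shapes made
// up): broadcasting input1 of shape {4, 1, 8} against input2 of shape
// {1, 5, 8} gives y0 = 1, y1 = 4, y2 = 1, y3 = 5, y4 = 8. Since y4 > 1 the
// general path runs: for each of the y1 = 4 outer steps, one 8-element block
// of input1 is reused against the 5 consecutive 8-element slices of input2
// (y3 = 5), then input1 advances to its next block, producing
// 4 * 5 * 8 = 160 outputs, which matches the broadcast output shape
// {4, 5, 8}.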
297
298 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
299
300 // Looks up each element of <indices> in <table>, returns them in a vector.
301 inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4],
302 uint8x16_t indices) {
303 // Look up in 1st quarter of the table: top 2 bits of indices == 00
304 uint8x16_t output1 = vqtbl4q_u8(table[0], indices);
305 // Look up in 2nd quarter of the table: top 2 bits of indices == 01
306 uint8x16_t output2 =
307 vqtbl4q_u8(table[1], veorq_u8(indices, vdupq_n_u8(0x40)));
308 // Look up in 3rd quarter of the table: top 2 bits of indices == 10
309 uint8x16_t output3 =
310 vqtbl4q_u8(table[2], veorq_u8(indices, vdupq_n_u8(0x80)));
311 // Look up in 4th quarter of the table: top 2 bits of indices == 11
312 uint8x16_t output4 =
313 vqtbl4q_u8(table[3], veorq_u8(indices, vdupq_n_u8(0xc0)));
314
315 // Combine result of the 4 lookups.
316 return vorrq_u8(vorrq_u8(output1, output2), vorrq_u8(output3, output4));
317 }
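// Editor's note on the XOR trick above: vqtbl4q_u8 returns 0 for any index
// byte >= 64, so each lookup contributes non-zero results only for the
// 64-entry quarter it owns. For example, an index byte of 0x85 (top bits 10)
// is out of range for table[0], table[1] and table[3], but table[2] is
// queried with 0x85 ^ 0x80 == 0x05 and returns entry 133 of the full
// 256-entry table; OR-ing the four partial results reassembles the lookup.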
318
319 #endif
320
321 inline void AddBiasAndEvalActivationFunction(float output_activation_min,
322 float output_activation_max,
323 const RuntimeShape& bias_shape,
324 const float* bias_data,
325 const RuntimeShape& array_shape,
326 float* array_data) {
327 BiasAndClamp(output_activation_min, output_activation_max,
328 bias_shape.FlatSize(), bias_data, array_shape.FlatSize(),
329 array_data);
330 }
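// Editor's example (sizes made up): with a bias of shape {8} and an
// activation array of shape {2, 8}, BiasAndClamp adds the 8 bias values to
// each of the 2 rows and clamps every element to
// [output_activation_min, output_activation_max]. bias_shape.FlatSize() is
// expected to evenly divide array_shape.FlatSize().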
331
332 inline void FullyConnected(
333 const FullyConnectedParams& params, const RuntimeShape& input_shape,
334 const float* input_data, const RuntimeShape& weights_shape,
335 const float* weights_data, const RuntimeShape& bias_shape,
336 const float* optional_bias_data, const RuntimeShape& output_shape,
337 float* output_data, CpuBackendContext* cpu_backend_context) {
338 ruy::profiler::ScopeLabel label("FullyConnected");
339 const int dims_count = weights_shape.DimensionsCount();
340 const int input_rows = weights_shape.Dims(dims_count - 1);
341 cpu_backend_gemm::MatrixParams<float> rhs_params;
342 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
343 rhs_params.rows = input_rows;
344 rhs_params.cols = input_shape.FlatSize() / input_rows;
345 rhs_params.cache_policy =
346 cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
347 TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
348 cpu_backend_gemm::MatrixParams<float> lhs_params;
349 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
350 lhs_params.cols = weights_shape.Dims(dims_count - 1);
351 lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
352 lhs_params.cache_policy =
353 cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
354 cpu_backend_gemm::MatrixParams<float> dst_params;
355 dst_params.order = cpu_backend_gemm::Order::kColMajor;
356 dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
357 dst_params.cols =
358 FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
359 cpu_backend_gemm::GemmParams<float, float> gemm_params;
360 gemm_params.bias = optional_bias_data;
361 gemm_params.clamp_min = params.float_activation_min;
362 gemm_params.clamp_max = params.float_activation_max;
363 cpu_backend_gemm::Gemm(lhs_params, weights_data, rhs_params, input_data,
364 dst_params, output_data, gemm_params,
365 cpu_backend_context);
366 }
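// Editor's example of the GEMM wiring above (shapes made up): with
// weights_shape {64, 128} (output_depth x accum_depth) and input_shape
// {2, 128} (batches x accum_depth), lhs is the 64x128 row-major weight
// matrix, rhs is the 128x2 column-major activation matrix and dst is the
// 64x2 column-major output, so each output column is one batch element; the
// optional bias and the activation clamp are fused into the GEMM epilogue.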
367
368 inline void FullyConnected(
369 const FullyConnectedParams& params, const RuntimeShape& input_shape,
370 const uint8* input_data, const RuntimeShape& filter_shape,
371 const uint8* filter_data, const RuntimeShape& bias_shape,
372 const int32* bias_data, const RuntimeShape& output_shape,
373 uint8* output_data, CpuBackendContext* cpu_backend_context) {
374 ruy::profiler::ScopeLabel label("FullyConnected/8bit");
375 const int32 input_offset = params.input_offset;
376 const int32 filter_offset = params.weights_offset;
377 const int32 output_offset = params.output_offset;
378 const int32 output_multiplier = params.output_multiplier;
379 const int output_shift = params.output_shift;
380 const int32 output_activation_min = params.quantized_activation_min;
381 const int32 output_activation_max = params.quantized_activation_max;
382 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
383 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
384 // TODO(b/62193649): This really should be:
385 // const int batches = ArraySize(output_dims, 1);
386 // but the current --variable_batch hack consists of overwriting the 3rd
387 // dimension with the runtime batch size, as we don't keep track, for each
388 // array, of which of its dimensions is the batch dimension.
389 const int output_dim_count = output_shape.DimensionsCount();
390 const int filter_dim_count = filter_shape.DimensionsCount();
391 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
392 const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
393 const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
394 TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
395 const int output_rows = output_shape.Dims(output_dim_count - 1);
396 TFLITE_DCHECK_EQ(output_rows, filter_rows);
397 if (bias_data) {
398 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
399 }
400
401 cpu_backend_gemm::MatrixParams<uint8> lhs_params;
402 lhs_params.rows = filter_rows;
403 lhs_params.cols = filter_cols;
404 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
405 lhs_params.zero_point = -filter_offset;
406 lhs_params.cache_policy =
407 cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
408 cpu_backend_gemm::MatrixParams<uint8> rhs_params;
409 rhs_params.rows = filter_cols;
410 rhs_params.cols = batches;
411 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
412 rhs_params.zero_point = -input_offset;
413 rhs_params.cache_policy =
414 cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
415 cpu_backend_gemm::MatrixParams<uint8> dst_params;
416 dst_params.rows = filter_rows;
417 dst_params.cols = batches;
418 dst_params.order = cpu_backend_gemm::Order::kColMajor;
419 dst_params.zero_point = output_offset;
420 cpu_backend_gemm::GemmParams<int32, uint8> gemm_params;
421 gemm_params.bias = bias_data;
422 gemm_params.clamp_min = output_activation_min;
423 gemm_params.clamp_max = output_activation_max;
424 gemm_params.multiplier_fixedpoint = output_multiplier;
425 gemm_params.multiplier_exponent = output_shift;
426 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
427 dst_params, output_data, gemm_params,
428 cpu_backend_context);
429 }
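// Editor's note on the quantized wiring above: params.weights_offset and
// params.input_offset are conventionally the negated zero points, so
// assigning zero_point = -filter_offset / -input_offset hands the original
// zero points back to cpu_backend_gemm. The int32 accumulators are then
// downscaled roughly as
//   output = clamp(output_offset +
//                  MultiplyByQuantizedMultiplier(acc, output_multiplier,
//                                                output_shift))
// via multiplier_fixedpoint / multiplier_exponent.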
430
431 inline void FullyConnected(
432 const FullyConnectedParams& params, const RuntimeShape& input_shape,
433 const uint8* input_data, const RuntimeShape& filter_shape,
434 const uint8* filter_data, const RuntimeShape& bias_shape,
435 const int32* bias_data_int32, const RuntimeShape& output_shape,
436 int16* output_data, CpuBackendContext* cpu_backend_context) {
437 ruy::profiler::ScopeLabel label("FullyConnected/Uint8Int16");
438 const int32 input_offset = params.input_offset;
439 const int32 filter_offset = params.weights_offset;
440 const int32 output_offset = params.output_offset;
441 const int32 output_multiplier = params.output_multiplier;
442 const int output_shift = params.output_shift;
443 const int32 output_activation_min = params.quantized_activation_min;
444 const int32 output_activation_max = params.quantized_activation_max;
445 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
446 TFLITE_DCHECK_EQ(output_offset, 0);
447 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
448 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
449
450 // TODO(b/62193649): This really should be:
451 // const int batches = ArraySize(output_dims, 1);
452 // but the current --variable_batch hack consists of overwriting the 3rd
453 // dimension with the runtime batch size, as we don't keep track, for each
454 // array, of which of its dimensions is the batch dimension.
455 const int output_dim_count = output_shape.DimensionsCount();
456 const int filter_dim_count = filter_shape.DimensionsCount();
457 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
458 const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
459 output_shape, output_dim_count - 1);
460 const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
461
462 cpu_backend_gemm::MatrixParams<uint8> lhs_params;
463 lhs_params.rows = output_depth;
464 lhs_params.cols = accum_depth;
465 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
466 lhs_params.zero_point = -filter_offset;
467 lhs_params.cache_policy =
468 cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
469 cpu_backend_gemm::MatrixParams<uint8> rhs_params;
470 rhs_params.rows = accum_depth;
471 rhs_params.cols = batches;
472 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
473 rhs_params.zero_point = -input_offset;
474 rhs_params.cache_policy =
475 cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
476 cpu_backend_gemm::MatrixParams<int16> dst_params;
477 dst_params.rows = output_depth;
478 dst_params.cols = batches;
479 dst_params.order = cpu_backend_gemm::Order::kColMajor;
480 dst_params.zero_point = 0;
481 cpu_backend_gemm::GemmParams<int32, int16> gemm_params;
482 gemm_params.bias = bias_data_int32;
483 gemm_params.clamp_min = output_activation_min;
484 gemm_params.clamp_max = output_activation_max;
485 gemm_params.multiplier_fixedpoint = output_multiplier;
486 gemm_params.multiplier_exponent = output_shift;
487 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
488 dst_params, output_data, gemm_params,
489 cpu_backend_context);
490 }
491
492 // Internal function doing the actual arithmetic work for
493 // ShuffledFullyConnected.
494 // May be called either directly by it (single-threaded case) or may be used
495 // as the 'task' for worker threads to run (multi-threaded case, see
496 // ShuffledFullyConnectedWorkerTask below).
497 inline void ShuffledFullyConnectedWorkerImpl(
498 const uint8* shuffled_input_workspace_data,
499 const int8* shuffled_weights_data, int batches, int output_depth,
500 int output_stride, int accum_depth, const int32* bias_data,
501 int32 output_multiplier, int output_shift, int16* output_data) {
502 #if defined USE_NEON
503 const int8* shuffled_weights_ptr = shuffled_weights_data;
504 if (batches == 1) {
505 const int right_shift = output_shift > 0 ? 0 : -output_shift;
506 const int left_shift = output_shift > 0 ? output_shift : 0;
507 for (int c = 0; c < output_depth; c += 4) {
508 // Accumulation loop.
509 int32x4_t row_accum0 = vdupq_n_s32(0);
510 int32x4_t row_accum1 = vdupq_n_s32(0);
511 int32x4_t row_accum2 = vdupq_n_s32(0);
512 int32x4_t row_accum3 = vdupq_n_s32(0);
513 for (int d = 0; d < accum_depth; d += 16) {
514 int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
515 int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
516 int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
517 int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
518 shuffled_weights_ptr += 64;
519 int8x16_t input =
520 vreinterpretq_s8_u8(vld1q_u8(shuffled_input_workspace_data + d));
521 int16x8_t local_accum0 =
522 vmull_s8(vget_low_s8(weights0), vget_low_s8(input));
523 int16x8_t local_accum1 =
524 vmull_s8(vget_low_s8(weights1), vget_low_s8(input));
525 int16x8_t local_accum2 =
526 vmull_s8(vget_low_s8(weights2), vget_low_s8(input));
527 int16x8_t local_accum3 =
528 vmull_s8(vget_low_s8(weights3), vget_low_s8(input));
529 local_accum0 =
530 vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input));
531 local_accum1 =
532 vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input));
533 local_accum2 =
534 vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input));
535 local_accum3 =
536 vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input));
537 row_accum0 = vpadalq_s16(row_accum0, local_accum0);
538 row_accum1 = vpadalq_s16(row_accum1, local_accum1);
539 row_accum2 = vpadalq_s16(row_accum2, local_accum2);
540 row_accum3 = vpadalq_s16(row_accum3, local_accum3);
541 }
542 // Horizontally reduce accumulators
543 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
544 pairwise_reduced_acc_2, pairwise_reduced_acc_3;
545 pairwise_reduced_acc_0 =
546 vpadd_s32(vget_low_s32(row_accum0), vget_high_s32(row_accum0));
547 pairwise_reduced_acc_1 =
548 vpadd_s32(vget_low_s32(row_accum1), vget_high_s32(row_accum1));
549 pairwise_reduced_acc_2 =
550 vpadd_s32(vget_low_s32(row_accum2), vget_high_s32(row_accum2));
551 pairwise_reduced_acc_3 =
552 vpadd_s32(vget_low_s32(row_accum3), vget_high_s32(row_accum3));
553 const int32x2_t reduced_lo =
554 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
555 const int32x2_t reduced_hi =
556 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
557 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
558 // Add bias values.
559 int32x4_t bias_vec = vld1q_s32(bias_data + c);
560 reduced = vaddq_s32(reduced, bias_vec);
561 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
562 // Multiply by the fixed-point multiplier.
563 reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
564 // Rounding-shift-right.
565 using gemmlowp::RoundingDivideByPOT;
566 reduced = RoundingDivideByPOT(reduced, right_shift);
567 // Narrow values down to 16 bit signed.
568 const int16x4_t res16 = vqmovn_s32(reduced);
569 vst1_s16(output_data + c, res16);
570 }
571 } else if (batches == 4) {
572 const int right_shift = output_shift > 0 ? 0 : -output_shift;
573 const int left_shift = output_shift > 0 ? output_shift : 0;
574 for (int c = 0; c < output_depth; c += 4) {
575 const int8* shuffled_input_ptr =
576 reinterpret_cast<const int8*>(shuffled_input_workspace_data);
577 // Accumulation loop.
578 int32x4_t row_accum00 = vdupq_n_s32(0);
579 int32x4_t row_accum10 = vdupq_n_s32(0);
580 int32x4_t row_accum20 = vdupq_n_s32(0);
581 int32x4_t row_accum30 = vdupq_n_s32(0);
582 int32x4_t row_accum01 = vdupq_n_s32(0);
583 int32x4_t row_accum11 = vdupq_n_s32(0);
584 int32x4_t row_accum21 = vdupq_n_s32(0);
585 int32x4_t row_accum31 = vdupq_n_s32(0);
586 int32x4_t row_accum02 = vdupq_n_s32(0);
587 int32x4_t row_accum12 = vdupq_n_s32(0);
588 int32x4_t row_accum22 = vdupq_n_s32(0);
589 int32x4_t row_accum32 = vdupq_n_s32(0);
590 int32x4_t row_accum03 = vdupq_n_s32(0);
591 int32x4_t row_accum13 = vdupq_n_s32(0);
592 int32x4_t row_accum23 = vdupq_n_s32(0);
593 int32x4_t row_accum33 = vdupq_n_s32(0);
594 for (int d = 0; d < accum_depth; d += 16) {
595 int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
596 int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
597 int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
598 int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
599 shuffled_weights_ptr += 64;
600 int8x16_t input0 = vld1q_s8(shuffled_input_ptr + 0);
601 int8x16_t input1 = vld1q_s8(shuffled_input_ptr + 16);
602 int8x16_t input2 = vld1q_s8(shuffled_input_ptr + 32);
603 int8x16_t input3 = vld1q_s8(shuffled_input_ptr + 48);
604 shuffled_input_ptr += 64;
605 int16x8_t local_accum0, local_accum1, local_accum2, local_accum3;
606 #define TFLITE_SHUFFLED_FC_ACCUM(B) \
607 local_accum0 = vmull_s8(vget_low_s8(weights0), vget_low_s8(input##B)); \
608 local_accum1 = vmull_s8(vget_low_s8(weights1), vget_low_s8(input##B)); \
609 local_accum2 = vmull_s8(vget_low_s8(weights2), vget_low_s8(input##B)); \
610 local_accum3 = vmull_s8(vget_low_s8(weights3), vget_low_s8(input##B)); \
611 local_accum0 = \
612 vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input##B)); \
613 local_accum1 = \
614 vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input##B)); \
615 local_accum2 = \
616 vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input##B)); \
617 local_accum3 = \
618 vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input##B)); \
619 row_accum0##B = vpadalq_s16(row_accum0##B, local_accum0); \
620 row_accum1##B = vpadalq_s16(row_accum1##B, local_accum1); \
621 row_accum2##B = vpadalq_s16(row_accum2##B, local_accum2); \
622 row_accum3##B = vpadalq_s16(row_accum3##B, local_accum3);
623
624 TFLITE_SHUFFLED_FC_ACCUM(0)
625 TFLITE_SHUFFLED_FC_ACCUM(1)
626 TFLITE_SHUFFLED_FC_ACCUM(2)
627 TFLITE_SHUFFLED_FC_ACCUM(3)
628
629 #undef TFLITE_SHUFFLED_FC_ACCUM
630 }
631 // Horizontally reduce accumulators
632
633 #define TFLITE_SHUFFLED_FC_STORE(B) \
634 { \
635 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1, \
636 pairwise_reduced_acc_2, pairwise_reduced_acc_3; \
637 pairwise_reduced_acc_0 = \
638 vpadd_s32(vget_low_s32(row_accum0##B), vget_high_s32(row_accum0##B)); \
639 pairwise_reduced_acc_1 = \
640 vpadd_s32(vget_low_s32(row_accum1##B), vget_high_s32(row_accum1##B)); \
641 pairwise_reduced_acc_2 = \
642 vpadd_s32(vget_low_s32(row_accum2##B), vget_high_s32(row_accum2##B)); \
643 pairwise_reduced_acc_3 = \
644 vpadd_s32(vget_low_s32(row_accum3##B), vget_high_s32(row_accum3##B)); \
645 const int32x2_t reduced_lo = \
646 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1); \
647 const int32x2_t reduced_hi = \
648 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3); \
649 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); \
650 int32x4_t bias_vec = vld1q_s32(bias_data + c); \
651 reduced = vaddq_s32(reduced, bias_vec); \
652 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift)); \
653 reduced = vqrdmulhq_n_s32(reduced, output_multiplier); \
654 using gemmlowp::RoundingDivideByPOT; \
655 reduced = RoundingDivideByPOT(reduced, right_shift); \
656 const int16x4_t res16 = vqmovn_s32(reduced); \
657 vst1_s16(output_data + c + B * output_stride, res16); \
658 }
659
660 TFLITE_SHUFFLED_FC_STORE(0);
661 TFLITE_SHUFFLED_FC_STORE(1);
662 TFLITE_SHUFFLED_FC_STORE(2);
663 TFLITE_SHUFFLED_FC_STORE(3);
664
665 #undef TFLITE_SHUFFLED_FC_STORE
666 }
667 } else {
668 TFLITE_DCHECK(false);
669 return;
670 }
671 #else
672 if (batches == 1) {
673 int16* output_ptr = output_data;
674 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
675 // so that just reinterpreting them as int8 values is equivalent to
676 // subtracting 128 from them, thus implementing for free the subtraction of
677 // the zero_point value 128.
678 const int8* shuffled_weights_ptr =
679 reinterpret_cast<const int8*>(shuffled_weights_data);
680 // Likewise, we preshuffled and pre-xored the input data above.
681 const int8* shuffled_input_data =
682 reinterpret_cast<const int8*>(shuffled_input_workspace_data);
683 for (int c = 0; c < output_depth; c += 4) {
684 // Internal accumulation.
685 // Initialize accumulators to zero; the bias is added after the loop.
686 int32 accum[4] = {0};
687 // Accumulation loop.
688 for (int d = 0; d < accum_depth; d += 16) {
689 for (int i = 0; i < 4; i++) {
690 for (int j = 0; j < 16; j++) {
691 int8 input_val = shuffled_input_data[d + j];
692 int8 weights_val = *shuffled_weights_ptr++;
693 accum[i] += weights_val * input_val;
694 }
695 }
696 }
697 for (int i = 0; i < 4; i++) {
698 // Add bias value
699 int acc = accum[i] + bias_data[c + i];
700 // Down-scale the final int32 accumulator to the scale used by our
701 // (16-bit, typically 3 integer bits) fixed-point format. The quantized
702 // multiplier and shift here have been pre-computed offline
703 // (e.g. by toco).
704 acc =
705 MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
706 // Saturate, cast to int16, and store to output array.
707 acc = std::max(acc, -32768);
708 acc = std::min(acc, 32767);
709 output_ptr[c + i] = acc;
710 }
711 }
712 } else if (batches == 4) {
713 int16* output_ptr = output_data;
714 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
715 // so that just reinterpreting them as int8 values is equivalent to
716 // subtracting 128 from them, thus implementing for free the subtraction of
717 // the zero_point value 128.
718 const int8* shuffled_weights_ptr =
719 reinterpret_cast<const int8*>(shuffled_weights_data);
720 // Likewise, we preshuffled and pre-xored the input data above.
721 const int8* shuffled_input_data =
722 reinterpret_cast<const int8*>(shuffled_input_workspace_data);
723 for (int c = 0; c < output_depth; c += 4) {
724 const int8* shuffled_input_ptr = shuffled_input_data;
725 // Accumulation loop.
726 // Internal accumulation.
727 // Initialize accumulators to zero; the bias is added after the loop.
728 int32 accum[4][4];
729 for (int i = 0; i < 4; i++) {
730 for (int b = 0; b < 4; b++) {
731 accum[i][b] = 0;
732 }
733 }
734 for (int d = 0; d < accum_depth; d += 16) {
735 for (int i = 0; i < 4; i++) {
736 for (int b = 0; b < 4; b++) {
737 for (int j = 0; j < 16; j++) {
738 int8 input_val = shuffled_input_ptr[16 * b + j];
739 int8 weights_val = shuffled_weights_ptr[16 * i + j];
740 accum[i][b] += weights_val * input_val;
741 }
742 }
743 }
744 shuffled_input_ptr += 64;
745 shuffled_weights_ptr += 64;
746 }
747 for (int i = 0; i < 4; i++) {
748 for (int b = 0; b < 4; b++) {
749 // Add bias value
750 int acc = accum[i][b] + bias_data[c + i];
751 // Down-scale the final int32 accumulator to the scale used by our
752 // (16-bit, typically 3 integer bits) fixed-point format. The
753 // quantized multiplier and shift here have been pre-computed offline
754 // (e.g. by toco).
755 acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
756 output_shift);
757 // Saturate, cast to int16, and store to output array.
758 acc = std::max(acc, -32768);
759 acc = std::min(acc, 32767);
760 output_ptr[b * output_stride + c + i] = acc;
761 }
762 }
763 }
764 } else {
765 TFLITE_DCHECK(false);
766 return;
767 }
768 #endif
769 }
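// Editor's note on the NEON paths above: the signed output_shift is split
// into left_shift (applied before the fixed-point multiply when
// output_shift > 0) and right_shift (applied after it when output_shift < 0).
// For example, output_shift == -7 gives left_shift == 0 and right_shift == 7,
// so the sequence vshlq_s32 -> vqrdmulhq_n_s32 -> RoundingDivideByPOT
// computes the same value as
// MultiplyByQuantizedMultiplier(acc, output_multiplier, -7) in the non-NEON
// branch.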
770
771 // Wraps ShuffledFullyConnectedWorkerImpl into a Task class
772 // so it can be run on the cpu_backend_threadpool.
773 struct ShuffledFullyConnectedWorkerTask : cpu_backend_threadpool::Task {
774 ShuffledFullyConnectedWorkerTask(const uint8* input_data,
775 const int8* shuffled_weights_data,
776 int batches, int output_depth,
777 int output_stride, int accum_depth,
778 const int32* bias_data,
779 int32 output_multiplier, int output_shift,
780 int16* output_data)
781 : input_data_(input_data),
782 shuffled_weights_data_(shuffled_weights_data),
783 batches_(batches),
784 output_depth_(output_depth),
785 output_stride_(output_stride),
786 accum_depth_(accum_depth),
787 bias_data_(bias_data),
788 output_multiplier_(output_multiplier),
789 output_shift_(output_shift),
790 output_data_(output_data) {}
791
792 void Run() override {
793 ShuffledFullyConnectedWorkerImpl(
794 input_data_, shuffled_weights_data_, batches_, output_depth_,
795 output_stride_, accum_depth_, bias_data_, output_multiplier_,
796 output_shift_, output_data_);
797 }
798
799 const uint8* input_data_;
800 const int8* shuffled_weights_data_;
801 int batches_;
802 int output_depth_;
803 int output_stride_;
804 int accum_depth_;
805 const int32* bias_data_;
806 int32 output_multiplier_;
807 int output_shift_;
808 int16* output_data_;
809 };
810
811 inline void ShuffledFullyConnected(
812 const FullyConnectedParams& params, const RuntimeShape& input_shape,
813 const uint8* input_data, const RuntimeShape& weights_shape,
814 const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
815 const int32* bias_data, const RuntimeShape& output_shape,
816 int16* output_data, uint8* shuffled_input_workspace_data,
817 CpuBackendContext* cpu_backend_context) {
818 ruy::profiler::ScopeLabel label("ShuffledFullyConnected/8bit");
819 const int32 output_multiplier = params.output_multiplier;
820 const int output_shift = params.output_shift;
821 const int32 output_activation_min = params.quantized_activation_min;
822 const int32 output_activation_max = params.quantized_activation_max;
823 TFLITE_DCHECK_EQ(output_activation_min, -32768);
824 TFLITE_DCHECK_EQ(output_activation_max, 32767);
825 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
826 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
827 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
828 // TODO(b/62193649): This really should be:
829 // const int batches = ArraySize(output_dims, 1);
830 // but the current --variable_batch hack consists of overwriting the 3rd
831 // dimension with the runtime batch size, as we don't keep track, for each
832 // array, of which of its dimensions is the batch dimension.
833 const int output_dim_count = output_shape.DimensionsCount();
834 const int weights_dim_count = weights_shape.DimensionsCount();
835 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
836 const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
837 output_shape, output_dim_count - 1);
838 const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
839 TFLITE_DCHECK((accum_depth % 16) == 0);
840 TFLITE_DCHECK((output_depth % 4) == 0);
841 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
842 // so that just reinterpreting them as int8 values is equivalent to
843 // subtracting 128 from them, thus implementing for free the subtraction of
844 // the zero_point value 128.
845 const int8* int8_shuffled_weights_data =
846 reinterpret_cast<const int8*>(shuffled_weights_data);
847
848 // Shuffling and xoring of input activations into the workspace buffer
849 if (batches == 1) {
850 #ifdef USE_NEON
851 const uint8x16_t signbit = vdupq_n_u8(0x80);
852 for (int i = 0; i < accum_depth; i += 16) {
853 uint8x16_t val = vld1q_u8(input_data + i);
854 val = veorq_u8(val, signbit);
855 vst1q_u8(shuffled_input_workspace_data + i, val);
856 }
857 #else
858 for (int i = 0; i < accum_depth; i++) {
859 shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
860 }
861 #endif
862 } else if (batches == 4) {
863 uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
864 int c = 0;
865 #ifdef USE_NEON
866 const uint8x16_t signbit = vdupq_n_u8(0x80);
867 for (c = 0; c < accum_depth; c += 16) {
868 const uint8* src_data_ptr = input_data + c;
869 uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
870 uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
871 uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
872 uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
873 val0 = veorq_u8(val0, signbit);
874 val1 = veorq_u8(val1, signbit);
875 val2 = veorq_u8(val2, signbit);
876 val3 = veorq_u8(val3, signbit);
877 vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
878 vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
879 vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
880 vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
881 shuffled_input_workspace_ptr += 64;
882 }
883 #else
884 for (c = 0; c < accum_depth; c += 16) {
885 for (int b = 0; b < 4; b++) {
886 const uint8* src_data_ptr = input_data + b * accum_depth + c;
887 for (int j = 0; j < 16; j++) {
888 uint8 src_val = *src_data_ptr++;
889 // Flip the sign bit, so that the kernel will only need to
890 // reinterpret these uint8 values as int8, getting for free the
891 // subtraction of the zero_point value 128.
892 uint8 dst_val = src_val ^ 0x80;
893 *shuffled_input_workspace_ptr++ = dst_val;
894 }
895 }
896 }
897 #endif
898 } else {
899 TFLITE_DCHECK(false);
900 return;
901 }
902
903 static constexpr int kKernelRows = 4;
904 const int thread_count =
905 LegacyHowManyThreads<kKernelRows>(cpu_backend_context->max_num_threads(),
906 output_depth, batches, accum_depth);
907 if (thread_count == 1) {
908 // Single-thread case: do the computation on the current thread, don't
909 // use a threadpool
910 ShuffledFullyConnectedWorkerImpl(
911 shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
912 output_depth, output_depth, accum_depth, bias_data, output_multiplier,
913 output_shift, output_data);
914 return;
915 }
916
917 // Multi-threaded case: use the cpu_backend_context's threadpool.
918 TFLITE_DCHECK_GT(thread_count, 1);
919 std::vector<ShuffledFullyConnectedWorkerTask> tasks;
920 // TODO(b/131746020) don't create new heap allocations every time.
921 // At least we make it a single heap allocation by using reserve().
922 tasks.reserve(thread_count);
923 const int kRowsPerWorker =
924 RoundUp<kKernelRows>(CeilQuotient(output_depth, thread_count));
925 int row_start = 0;
926 for (int i = 0; i < thread_count; i++) {
927 int row_end = std::min(output_depth, row_start + kRowsPerWorker);
928 tasks.emplace_back(shuffled_input_workspace_data,
929 int8_shuffled_weights_data + row_start * accum_depth,
930 batches, row_end - row_start, output_depth, accum_depth,
931 bias_data + row_start, output_multiplier, output_shift,
932 output_data + row_start);
933 row_start = row_end;
934 }
935 TFLITE_DCHECK_EQ(row_start, output_depth);
936 cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
937 cpu_backend_context);
938 }
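// Editor's worked example of the row partitioning above (numbers made up):
// with output_depth == 1000 and thread_count == 4, kRowsPerWorker ==
// RoundUp<4>(CeilQuotient(1000, 4)) == 252, so the four tasks cover rows
// [0, 252), [252, 504), [504, 756) and [756, 1000). Each task receives the
// full output_depth as its output_stride so that its rows land at the
// correct offsets of the shared output buffer.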
939
940 #ifdef USE_NEON
941
942 inline int32x4_t RoundToNearest(const float32x4_t input) {
943 #if defined(__aarch64__) || defined(__SSSE3__)
944 // Note: vcvtnq_s32_f32 is not available in ARMv7
945 return vcvtnq_s32_f32(input);
946 #else
947 static const float32x4_t zero_val_dup = vdupq_n_f32(0.0f);
948 static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
949 static const float32x4_t minus_point5_val_dup = vdupq_n_f32(-0.5f);
950
951 const uint32x4_t mask = vcltq_f32(input, zero_val_dup);
952 const float32x4_t round =
953 vbslq_f32(mask, minus_point5_val_dup, point5_val_dup);
954 return vcvtq_s32_f32(vaddq_f32(input, round));
955 #endif // defined(__aarch64__) || defined(__SSSE3__)
956 }
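// Editor's note: the ARMv7 fallback above rounds halfway cases away from
// zero (it adds +/-0.5 and truncates), whereas vcvtnq_s32_f32 rounds them to
// even; e.g. 2.5f becomes 3 on the fallback path but 2 on aarch64/SSSE3.
// Non-halfway inputs round identically on both paths.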
957
958 inline uint32x4_t RoundToNearestUnsigned(const float32x4_t input) {
959 #if defined(__aarch64__)
960 // Note that vcvtnq_u32_f32 is not available in ARMv7 or in arm_neon_sse.h.
961 return vcvtnq_u32_f32(input);
962 #else
963 static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
964
965 return vcvtq_u32_f32(vaddq_f32(input, point5_val_dup));
966 #endif // defined(__aarch64__)
967 }
968
969 #endif // USE_NEON
970
971 inline void MeanImpl(const tflite::MeanParams& op_params,
972 const RuntimeShape& input_shape, const uint8_t* input_data,
973 int32 multiplier, int32 shift, int32 bias,
974 const RuntimeShape& output_shape, uint8_t* output_data,
975 int start_depth, int end_depth) {
976 ruy::profiler::ScopeLabel label("Mean4D/Uint8/MeanImpl");
977
978 // The current implementation only supports 4-D tensors and simultaneous
979 // reduction over width and height.
980 const int output_batch = output_shape.Dims(0);
981 const int output_height = output_shape.Dims(1);
982 const int output_width = output_shape.Dims(2);
983 const int input_height = input_shape.Dims(1);
984 const int input_width = input_shape.Dims(2);
985
986 TFLITE_CHECK_EQ(op_params.axis_count, 2);
987 TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
988 (op_params.axis[0] == 2 && op_params.axis[1] == 1));
989 TFLITE_CHECK_EQ(output_height, 1);
990 TFLITE_CHECK_EQ(output_width, 1);
991
992 constexpr int32_t kMinValue = std::numeric_limits<uint8_t>::min();
993 constexpr int32_t kMaxValue = std::numeric_limits<uint8_t>::max();
994
995 #ifdef USE_NEON
996 const int32x4_t bias_dup = vdupq_n_s32(bias);
997 const int32x4_t min_dup = vdupq_n_s32(kMinValue);
998 const int32x4_t max_dup = vdupq_n_s32(kMaxValue);
999 #endif // USE_NEON
1000
1001 for (int out_b = 0; out_b < output_batch; ++out_b) {
1002 int out_d = start_depth;
1003 #ifdef USE_NEON
1004
1005 for (; out_d <= end_depth - 16; out_d += 16) {
1006 int32x4x4_t temp_sum;
1007 temp_sum.val[0] = vdupq_n_s32(0);
1008 temp_sum.val[1] = vdupq_n_s32(0);
1009 temp_sum.val[2] = vdupq_n_s32(0);
1010 temp_sum.val[3] = vdupq_n_s32(0);
1011 for (int in_h = 0; in_h < input_height; ++in_h) {
1012 for (int in_w = 0; in_w < input_width; ++in_w) {
1013 const uint8_t* input_data_ptr =
1014 input_data + Offset(input_shape, out_b, in_h, in_w, out_d);
1015 uint8x16_t input_data_val = vld1q_u8(input_data_ptr);
1016
1017 int16x8_t input_data_low_shift =
1018 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_data_val)));
1019 int16x8_t input_data_high_shift =
1020 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_data_val)));
1021
1022 int32x4_t input_low_low =
1023 vmovl_s16(vget_low_s16(input_data_low_shift));
1024 int32x4_t input_high_low =
1025 vmovl_s16(vget_high_s16(input_data_low_shift));
1026 int32x4_t input_low_high =
1027 vmovl_s16(vget_low_s16(input_data_high_shift));
1028 int32x4_t input_high_high =
1029 vmovl_s16(vget_high_s16(input_data_high_shift));
1030
1031 temp_sum.val[0] = vaddq_s32(temp_sum.val[0], input_low_low);
1032 temp_sum.val[1] = vaddq_s32(temp_sum.val[1], input_high_low);
1033 temp_sum.val[2] = vaddq_s32(temp_sum.val[2], input_low_high);
1034 temp_sum.val[3] = vaddq_s32(temp_sum.val[3], input_high_high);
1035 }
1036 }
1037
1038 temp_sum =
1039 MultiplyByQuantizedMultiplier4Rows(temp_sum, multiplier, shift);
1040
1041 temp_sum.val[0] = vaddq_s32(temp_sum.val[0], bias_dup);
1042 temp_sum.val[1] = vaddq_s32(temp_sum.val[1], bias_dup);
1043 temp_sum.val[2] = vaddq_s32(temp_sum.val[2], bias_dup);
1044 temp_sum.val[3] = vaddq_s32(temp_sum.val[3], bias_dup);
1045
1046 temp_sum.val[0] = vminq_s32(vmaxq_s32(temp_sum.val[0], min_dup), max_dup);
1047 temp_sum.val[1] = vminq_s32(vmaxq_s32(temp_sum.val[1], min_dup), max_dup);
1048 temp_sum.val[2] = vminq_s32(vmaxq_s32(temp_sum.val[2], min_dup), max_dup);
1049 temp_sum.val[3] = vminq_s32(vmaxq_s32(temp_sum.val[3], min_dup), max_dup);
1050
1051 uint16x4_t narrowed_low_low =
1052 vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[0]));
1053 uint16x4_t narrowed_high_low =
1054 vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[1]));
1055 uint16x4_t narrowed_low_high =
1056 vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[2]));
1057 uint16x4_t narrowed_high_high =
1058 vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[3]));
1059
1060 uint16x8_t combined_low =
1061 vcombine_u16(narrowed_low_low, narrowed_high_low);
1062 uint16x8_t combined_high =
1063 vcombine_u16(narrowed_low_high, narrowed_high_high);
1064
1065 uint8x8_t narrowed_low = vmovn_u16(combined_low);
1066 uint8x8_t narrowed_high = vmovn_u16(combined_high);
1067
1068 uint8x16_t combined_output = vcombine_u8(narrowed_low, narrowed_high);
1069
1070 uint8_t* output_data_ptr =
1071 output_data + Offset(output_shape, out_b, 0, 0, out_d);
1072 vst1q_u8(output_data_ptr, combined_output);
1073 }
1074 #endif // USE_NEON
1075
1076 for (; out_d < end_depth; ++out_d) {
1077 int acc = 0;
1078 for (int in_h = 0; in_h < input_height; ++in_h) {
1079 for (int in_w = 0; in_w < input_width; ++in_w) {
1080 acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
1081 }
1082 }
1083
1084 acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
1085 acc += bias;
1086 acc = std::min(std::max(acc, kMinValue), kMaxValue);
1087 output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
1088 static_cast<uint8_t>(acc);
1089 }
1090 }
1091 }
1092
1093 struct MeanWorkerTask : cpu_backend_threadpool::Task {
1094 MeanWorkerTask(const tflite::MeanParams& op_params,
1095 const RuntimeShape& input_shape, const uint8_t* input_data,
1096 int32 multiplier, int32 shift, int32 bias,
1097 const RuntimeShape& output_shape, uint8_t* output_data,
1098 int start_depth, int end_depth)
1099 : op_params(op_params),
1100 input_shape(input_shape),
1101 input_data(input_data),
1102 multiplier(multiplier),
1103 shift(shift),
1104 bias(bias),
1105 output_shape(output_shape),
1106 output_data(output_data),
1107 start_depth(start_depth),
1108 end_depth(end_depth) {}
1109
1110 void Run() override {
1111 MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
1112 output_shape, output_data, start_depth, end_depth);
1113 }
1114
1115 private:
1116 const tflite::MeanParams& op_params;
1117 const RuntimeShape& input_shape;
1118 const uint8_t* input_data;
1119 int32 multiplier;
1120 int32 shift;
1121 int32 bias;
1122 const RuntimeShape& output_shape;
1123 uint8_t* output_data;
1124 int start_depth;
1125 int end_depth;
1126 };
1127
1128 inline void Mean(const tflite::MeanParams& op_params,
1129 const RuntimeShape& unextended_input_shape,
1130 const uint8_t* input_data, int32 input_zero_point,
1131 float input_scale, const RuntimeShape& unextended_output_shape,
1132 uint8_t* output_data, int32 output_zero_point,
1133 float output_scale, CpuBackendContext* cpu_backend_context) {
1134 ruy::profiler::ScopeLabel label("Mean4D/Uint8");
1135 // The current implementation only supports 4-D tensors and simultaneous
1136 // reduction over width and height.
1137 TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
1138 TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1139 const RuntimeShape input_shape =
1140 RuntimeShape::ExtendedShape(4, unextended_input_shape);
1141 const RuntimeShape output_shape =
1142 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1143 const int output_height = output_shape.Dims(1);
1144 const int output_width = output_shape.Dims(2);
1145 const int output_depth = output_shape.Dims(3);
1146
1147 TFLITE_CHECK_EQ(op_params.axis_count, 2);
1148 TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
1149 (op_params.axis[0] == 2 && op_params.axis[1] == 1));
1150 TFLITE_CHECK_EQ(output_height, 1);
1151 TFLITE_CHECK_EQ(output_width, 1);
1152
1153 const int input_height = input_shape.Dims(1);
1154 const int input_width = input_shape.Dims(2);
1155 const float num_elements_in_axis = input_width * input_height;
1156
1157 int32 bias =
1158 output_zero_point -
1159 static_cast<int32>(input_zero_point * input_scale / output_scale);
1160 float real_scale = input_scale / (num_elements_in_axis * output_scale);
1161
1162 int32 multiplier, shift;
1163 QuantizeMultiplier(real_scale, &multiplier, &shift);
1164
1165 constexpr int kMinDepthPerThread = 8;
1166 int thread_count = output_depth / kMinDepthPerThread;
1167 thread_count = thread_count > 0 ? thread_count : 1;
1168 const int capped_thread_count =
1169 std::min(thread_count, cpu_backend_context->max_num_threads());
1170
1171 if (capped_thread_count == 1) {
1172 MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
1173 output_shape, output_data, 0, output_depth);
1174 } else {
1175 // Instead of parallelizing over the batch dimension, we parallelize over
1176 // output_depth, since the batch size is typically 1.
1177 std::vector<MeanWorkerTask> tasks;
1178 // TODO(b/131746020) don't create new heap allocations every time.
1179 // At least we make it a single heap allocation by using reserve().
1180 tasks.reserve(capped_thread_count);
1181 int depth_start = 0;
1182 for (int i = 0; i < capped_thread_count; ++i) {
1183 // Try to distribute the tasks as evenly as possible.
1184 int depth_end = depth_start +
1185 (output_depth - depth_start) / (capped_thread_count - i);
1186 tasks.emplace_back(op_params, input_shape, input_data, multiplier, shift,
1187 bias, output_shape, output_data, depth_start,
1188 depth_end);
1189 depth_start = depth_end;
1190 }
1191 cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
1192 cpu_backend_context);
1193 }
1194 }
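// Editor's worked example of the quantization above (values made up): with
// input_scale == 0.5f, output_scale == 0.25f, input_zero_point == 128,
// output_zero_point == 0 and a 4x4 spatial reduction
// (num_elements_in_axis == 16), bias == 0 - (int)(128 * 0.5 / 0.25) == -256
// and real_scale == 0.5 / (16 * 0.25) == 0.125, so each output is
// approximately 0.125 * sum(inputs) - 256, clamped to [0, 255] in MeanImpl.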
1195
1196 template <typename T, typename U>
1197 inline bool MeanGeneral(const T* input_data, const int* input_dims,
1198 const int input_num_dims, T* output_data,
1199 const int* output_dims, const int output_num_dims,
1200 const int* axis, const int num_axis_dimensions,
1201 bool keep_dims, int* temp_index, int* resolved_axis,
1202 U* temp_sum) {
1203 return reference_ops::Mean(input_data, input_dims, input_num_dims,
1204 output_data, output_dims, output_num_dims, axis,
1205 num_axis_dimensions, keep_dims, temp_index,
1206 resolved_axis, temp_sum);
1207 }
1208
1209 template <>
1210 inline bool MeanGeneral<float, float>(
1211 const float* input_data, const int* input_dims, const int input_num_dims,
1212 float* output_data, const int* output_dims, const int output_num_dims,
1213 const int* axis, const int num_axis_dimensions, bool keep_dims,
1214 int* temp_index, int* resolved_axis, float* temp_sum) {
1215 // Handle reduce_mean over the last dimension.
1216 if (num_axis_dimensions == 1 && axis[0] == (input_num_dims - 1)) {
1217 ruy::profiler::ScopeLabel label("MeanLastDim/Float");
1218 int output_size = 1;
1219 for (int i = 0; i < input_num_dims - 1; ++i) {
1220 output_size *= input_dims[i];
1221 }
1222 const int last_input_dim = input_dims[axis[0]];
1223
1224 // TODO(b/152563685): Consider using Eigen to cover more general cases.
1225 const MatrixMap<const float> in_mat(input_data, last_input_dim,
1226 output_size);
1227 VectorMap<float> out(output_data, output_size, 1);
1228 out = (in_mat.array().colwise().sum()) / static_cast<float>(last_input_dim);
1229 return true;
1230 }
1231
1232 return reference_ops::Mean(input_data, input_dims, input_num_dims,
1233 output_data, output_dims, output_num_dims, axis,
1234 num_axis_dimensions, keep_dims, temp_index,
1235 resolved_axis, temp_sum);
1236 }
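// Editor's example of the fast path above (values made up): for
// input_dims == {2, 3} with axis == {1}, output_size == 2 and
// last_input_dim == 3, so the Eigen expression averages each length-3 row:
// an input of {1, 2, 3, 4, 5, 6} produces {2.0f, 5.0f}.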
1237
1238 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
1239 const float* input_data, const RuntimeShape& filter_shape,
1240 const float* filter_data, const RuntimeShape& bias_shape,
1241 const float* bias_data, const RuntimeShape& output_shape,
1242 float* output_data, const RuntimeShape& im2col_shape,
1243 float* im2col_data, CpuBackendContext* cpu_backend_context) {
1244 const int stride_width = params.stride_width;
1245 const int stride_height = params.stride_height;
1246 const int dilation_width_factor = params.dilation_width_factor;
1247 const int dilation_height_factor = params.dilation_height_factor;
1248 const float output_activation_min = params.float_activation_min;
1249 const float output_activation_max = params.float_activation_max;
1250 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1251 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1252 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1253
1254 ruy::profiler::ScopeLabel label("Conv");
1255
1256 // NB: the float 0.0f value is represented by all zero bytes.
1257 const uint8 float_zero_byte = 0x00;
1258 const float* gemm_input_data = nullptr;
1259 const RuntimeShape* gemm_input_shape = nullptr;
1260 const int filter_width = filter_shape.Dims(2);
1261 const int filter_height = filter_shape.Dims(1);
1262 const bool need_dilated_im2col =
1263 dilation_width_factor != 1 || dilation_height_factor != 1;
1264 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1265 filter_width != 1 || filter_height != 1;
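  // im2col is only skipped for a 1x1 filter with unit stride and no dilation;
  // in that case every input pixel already is exactly one GEMM input patch,
  // so the GEMM below can read the input activations directly.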
1266 if (need_dilated_im2col) {
1267 DilatedIm2col(params, float_zero_byte, input_shape, input_data,
1268 filter_shape, output_shape, im2col_data);
1269 gemm_input_data = im2col_data;
1270 gemm_input_shape = &im2col_shape;
1271 } else if (need_im2col) {
1272 TFLITE_DCHECK(im2col_data);
1273 Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
1274 input_data, im2col_shape, im2col_data);
1275 gemm_input_data = im2col_data;
1276 gemm_input_shape = &im2col_shape;
1277 } else {
1278 TFLITE_DCHECK(!im2col_data);
1279 gemm_input_data = input_data;
1280 gemm_input_shape = &input_shape;
1281 }
1282
1283 const int gemm_input_dims = gemm_input_shape->DimensionsCount();
1284 int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
1285 int n = output_shape.Dims(3);
1286 int k = gemm_input_shape->Dims(gemm_input_dims - 1);
1287
1288 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
  // The following code computes matrix multiplication c = a * transpose(b)
  // with CBLAS, where:
  // * `a` is a matrix with dimensions (m, k).
  // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
  // * `c` is a matrix with dimensions (m, n).
  // The naming of the variables is aligned with the CBLAS specification here.
1295 const float* a = gemm_input_data;
1296 const float* b = filter_data;
1297 float* c = output_data;
  // The strides (leading dimensions) of matrices a, b, and c, respectively.
1299 int stride_a = k;
1300 int stride_b = k;
1301 int stride_c = n;
1302
1303 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
1304 stride_a, b, stride_b, 0.0f, c, stride_c);
1305 optimized_ops::AddBiasAndEvalActivationFunction(
1306 output_activation_min, output_activation_max, bias_shape, bias_data,
1307 output_shape, output_data);
1308 #else
1309 // When an optimized CBLAS implementation is not available, fall back
1310 // to using cpu_backend_gemm.
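  // The same product as in the CBLAS branch is expressed here with the filter
  // as the row-major LHS (output_depth x k) and the im2col patches as the
  // column-major RHS (k x number of output pixels across the batch). The
  // column-major destination is then bit-identical to the row-major NHWC
  // output, since each output pixel's channels are contiguous in memory.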
1311 cpu_backend_gemm::MatrixParams<float> lhs_params;
1312 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1313 lhs_params.rows = n;
1314 lhs_params.cols = k;
1315 cpu_backend_gemm::MatrixParams<float> rhs_params;
1316 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1317 rhs_params.rows = k;
1318 rhs_params.cols = m;
1319 cpu_backend_gemm::MatrixParams<float> dst_params;
1320 dst_params.order = cpu_backend_gemm::Order::kColMajor;
1321 dst_params.rows = n;
1322 dst_params.cols = m;
1323 cpu_backend_gemm::GemmParams<float, float> gemm_params;
1324 gemm_params.bias = bias_data;
1325 gemm_params.clamp_min = output_activation_min;
1326 gemm_params.clamp_max = output_activation_max;
1327 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1328 dst_params, output_data, gemm_params,
1329 cpu_backend_context);
1330 #endif // defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
1331 }
1332
inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
1334 const RuntimeShape& input_shape,
1335 const int8_t* input_data,
1336 const RuntimeShape& filter_shape,
1337 const int8_t* filter_data,
1338 const RuntimeShape& bias_shape, const float* bias_data,
1339 const RuntimeShape& accum_scratch_shape,
1340 int32_t* accum_scratch, const RuntimeShape& output_shape,
1341 float* output_data, const RuntimeShape& im2col_shape,
1342 int8_t* im2col_data, CpuBackendContext* context) {
1343 const int stride_width = params.stride_width;
1344 const int stride_height = params.stride_height;
1345 const int dilation_width_factor = params.dilation_width_factor;
1346 const int dilation_height_factor = params.dilation_height_factor;
1347 const float output_activation_min = params.float_activation_min;
1348 const float output_activation_max = params.float_activation_max;
1349 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1350 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1351 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1352
1353 const int batch_size = input_shape.Dims(0);
1354 const int filter_width = filter_shape.Dims(2);
1355 const int filter_height = filter_shape.Dims(1);
1356
1357 const int input_zero_point = 0;
1358 const int8_t* gemm_input_data = nullptr;
1359 int num_input;
1360 const bool need_dilated_im2col =
1361 dilation_width_factor != 1 || dilation_height_factor != 1;
1362 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1363 filter_width != 1 || filter_height != 1;
1364
1365 if (need_dilated_im2col) {
1366 DilatedIm2col(params, input_zero_point, input_shape, input_data,
1367 filter_shape, output_shape, im2col_data);
1368 gemm_input_data = im2col_data;
1369 num_input = im2col_shape.FlatSize();
1370 } else if (need_im2col) {
1371 TFLITE_DCHECK(im2col_data);
    // Symmetric quantization assumes a zero point of 0.
1373
1374 Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
1375 input_data, im2col_shape, im2col_data);
1376 gemm_input_data = im2col_data;
1377 num_input = im2col_shape.FlatSize();
1378 } else {
1379 TFLITE_DCHECK(!im2col_data);
1380 gemm_input_data = input_data;
1381 num_input = input_shape.FlatSize();
1382 }
1383
1384 // Flatten 4D matrices into 2D matrices for matrix multiplication.
1385
1386 // Flatten so that each filter has its own row.
1387 const int filter_rows = filter_shape.Dims(0);
1388 const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1389
1390 // In MatrixBatchVectorMultiplyAccumulate, each output value is the
1391 // dot product of one row of the first matrix with one row of the second
  // matrix. Therefore, the number of columns in the two matrices must match.
1393 //
1394 // After Im2Col, each input patch becomes a row.
1395 const int gemm_input_cols = filter_cols;
1396 const int gemm_input_rows = num_input / gemm_input_cols;
1397
1398 const int output_cols = output_shape.Dims(3);
1399 const int output_rows = FlatSizeSkipDim(output_shape, 3);
1400 TFLITE_DCHECK_EQ(output_cols, filter_rows);
1401 TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
1402 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_cols);
1403
1404 // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
1405 // input matrix has its own scale factor. This code duplicates the scale
1406 // factors for each row in the same batch.
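  // For example, with batch_size = 2 and rows_per_batch = 3, the first two
  // entries [s0, s1] are expanded in place to [s0, s0, s0, s1, s1, s1].
  // Iterating backwards keeps the in-place expansion safe: entry i reads from
  // index i / rows_per_batch <= i, which has not been overwritten yet.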
1407 const int rows_per_batch = gemm_input_rows / batch_size;
1408 for (int i = gemm_input_rows - 1; i >= 0; --i) {
1409 scaling_factors_ptr[i] = scaling_factors_ptr[i / rows_per_batch];
1410 }
1411
1412 std::fill_n(output_data, output_rows * output_cols, 0.0f);
1413
1414 // The scratch buffer must have the same size as the output.
1415 TFLITE_DCHECK_EQ(accum_scratch_shape.FlatSize(), output_shape.FlatSize());
1416 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
1417 filter_data, filter_rows, filter_cols, gemm_input_data,
1418 scaling_factors_ptr, /*n_batch=*/gemm_input_rows, accum_scratch,
1419 output_data, context);
1420 AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
1421 bias_shape, bias_data, output_shape,
1422 output_data);
1423 }
1424
inline void HybridConvPerChannel(
1426 const ConvParams& params, float* scaling_factors_ptr,
1427 const RuntimeShape& input_shape, const int8_t* input_data,
1428 const RuntimeShape& filter_shape, const int8_t* filter_data,
1429 const RuntimeShape& bias_shape, const float* bias_data,
1430 const RuntimeShape& output_shape, float* output_data,
1431 const RuntimeShape& im2col_shape, int8_t* im2col_data,
1432 const float* per_channel_scale, int32_t* input_offset,
1433 const RuntimeShape& scratch_shape, int32_t* scratch, int32_t* row_sums,
1434 bool* compute_row_sums, CpuBackendContext* cpu_backend_context) {
1435 ruy::profiler::ScopeLabel label("ConvHybridPerChannel");
1436 const int stride_width = params.stride_width;
1437 const int stride_height = params.stride_height;
1438 const int dilation_width_factor = params.dilation_width_factor;
1439 const int dilation_height_factor = params.dilation_height_factor;
1440 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1441 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1442 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1443
1444 const int8* gemm_input_data = nullptr;
1445 const RuntimeShape* gemm_input_shape = nullptr;
1446 const int filter_width = filter_shape.Dims(2);
1447 const int filter_height = filter_shape.Dims(1);
1448 const bool need_dilated_im2col =
1449 dilation_width_factor != 1 || dilation_height_factor != 1;
1450 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1451 filter_width != 1 || filter_height != 1;
1452
1453 const int batch_size = input_shape.Dims(0);
1454
1455 if (need_dilated_im2col) {
1456 TFLITE_DCHECK(im2col_data);
1457 optimized_ops::DilatedIm2col(params, input_shape, input_data, filter_shape,
1458 output_shape, im2col_data, input_offset,
1459 batch_size);
1460 gemm_input_data = im2col_data;
1461 gemm_input_shape = &im2col_shape;
1462 } else if (need_im2col) {
1463 Im2col(params, filter_height, filter_width, input_offset, batch_size,
1464 input_shape, input_data, im2col_shape, im2col_data);
1465 gemm_input_data = im2col_data;
1466 gemm_input_shape = &im2col_shape;
1467 } else {
1468 TFLITE_DCHECK(!im2col_data);
1469 gemm_input_data = input_data;
1470 gemm_input_shape = &input_shape;
1471 }
1472
1473 const int filter_rows = filter_shape.Dims(0);
1474 const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1475
1476 const int gemm_input_rows = gemm_input_shape->Dims(3);
1477 const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
1478 const int output_rows = output_shape.Dims(3);
1479 const int output_cols =
1480 output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
1481
1482 TFLITE_DCHECK_EQ(output_rows, filter_rows);
1483 TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
1484 TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
1485 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1486 TFLITE_DCHECK_EQ(scratch_shape.FlatSize(), output_shape.FlatSize());
1487 if (!compute_row_sums || *compute_row_sums) {
1488 tensor_utils::ReductionSumVector(filter_data, row_sums, filter_rows,
1489 filter_cols);
1490 if (compute_row_sums) {
1491 *compute_row_sums = false;
1492 }
1493 }
1494
1495 cpu_backend_gemm::MatrixParams<int8> lhs_params;
1496 lhs_params.rows = filter_rows;
1497 lhs_params.cols = filter_cols;
1498 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1499
1500 cpu_backend_gemm::MatrixParams<int8> rhs_params;
1501 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1502 rhs_params.rows = gemm_input_rows;
1503 rhs_params.cols = gemm_input_cols;
1504
1505 cpu_backend_gemm::MatrixParams<int32> dst_params;
1506 dst_params.order = cpu_backend_gemm::Order::kColMajor;
1507 dst_params.rows = output_rows;
1508 dst_params.cols = output_cols;
1509
1510 // TODO(b/149003801): Use hybrid gemm once supported in Ruy.
1511 cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
1512 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1513 dst_params, scratch, gemm_params, cpu_backend_context);
1514
1515 MatrixMap<float> out_mat(output_data, filter_rows, output_cols);
1516 MatrixMap<int32_t> in_mat(scratch, filter_rows, output_cols);
1517 VectorMap<const float> bias_data_vec(bias_data, filter_rows, 1);
1518 VectorMap<int32_t> row_sums_vec(row_sums, filter_rows, 1);
1519 VectorMap<const float> per_channel_scale_vec(per_channel_scale, filter_rows,
1520 1);
1521 const int cols_per_batch = output_cols / batch_size;
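  // Each scratch column c holds the raw int32 accumulators of one output
  // pixel, belonging to batch b = c / cols_per_batch. The filter is
  // symmetrically quantized but the input has a per-batch zero point, so the
  // dequantized result is (acc - zero_point * row_sum) * input_scale *
  // per_channel_scale + bias, clamped to the activation range; that is what
  // the per-column expression below computes.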
1522 for (int c = 0; c < output_cols; c++) {
1523 const int b = c / cols_per_batch;
1524 const float input_scale = scaling_factors_ptr[b];
1525 const int32_t zero_point = input_offset[b];
1526 out_mat.col(c) =
1527 (((in_mat.col(c) - (row_sums_vec * zero_point))
1528 .cast<float>()
1529 .cwiseProduct((per_channel_scale_vec * input_scale))) +
1530 bias_data_vec)
1531 .cwiseMin(params.float_activation_max)
1532 .cwiseMax(params.float_activation_min);
1533 }
1534 }
1535
inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
1537 const uint8* input_data, const RuntimeShape& filter_shape,
1538 const uint8* filter_data, const RuntimeShape& bias_shape,
1539 const int32* bias_data, const RuntimeShape& output_shape,
1540 uint8* output_data, const RuntimeShape& im2col_shape,
1541 uint8* im2col_data, CpuBackendContext* cpu_backend_context) {
1542 ruy::profiler::ScopeLabel label("Conv/8bit");
1543
1544 const int stride_width = params.stride_width;
1545 const int stride_height = params.stride_height;
1546 const int dilation_width_factor = params.dilation_width_factor;
1547 const int dilation_height_factor = params.dilation_height_factor;
1548 const int32 input_offset = params.input_offset;
1549 const int32 filter_offset = params.weights_offset;
1550 const int32 output_offset = params.output_offset;
1551 const int32 output_multiplier = params.output_multiplier;
1552 const int output_shift = params.output_shift;
1553 const int32 output_activation_min = params.quantized_activation_min;
1554 const int32 output_activation_max = params.quantized_activation_max;
1555 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1556 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1557 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1558
1559 const uint8* gemm_input_data = nullptr;
1560 const RuntimeShape* gemm_input_shape = nullptr;
1561 const int filter_width = filter_shape.Dims(2);
1562 const int filter_height = filter_shape.Dims(1);
1563 const bool need_dilated_im2col =
1564 dilation_width_factor != 1 || dilation_height_factor != 1;
1565 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1566 filter_width != 1 || filter_height != 1;
1567 if (need_dilated_im2col) {
1568 TFLITE_DCHECK(im2col_data);
1569 const int input_zero_point = -input_offset;
1570 TFLITE_DCHECK_GE(input_zero_point, 0);
1571 TFLITE_DCHECK_LE(input_zero_point, 255);
1572 DilatedIm2col(params, input_zero_point, input_shape, input_data,
1573 filter_shape, output_shape, im2col_data);
1574 gemm_input_data = im2col_data;
1575 gemm_input_shape = &im2col_shape;
1576 } else if (need_im2col) {
1577 TFLITE_DCHECK(im2col_data);
1578 const int input_zero_point = -input_offset;
1579 TFLITE_DCHECK_GE(input_zero_point, 0);
1580 TFLITE_DCHECK_LE(input_zero_point, 255);
1581 Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
1582 input_data, im2col_shape, im2col_data);
1583 gemm_input_data = im2col_data;
1584 gemm_input_shape = &im2col_shape;
1585 } else {
1586 TFLITE_DCHECK(!im2col_data);
1587 gemm_input_data = input_data;
1588 gemm_input_shape = &input_shape;
1589 }
1590
1591 const int gemm_input_rows = gemm_input_shape->Dims(3);
  // Using FlatSizeSkipDim causes a segfault in some contexts (see b/79927784).
  // The root cause has not yet been identified, though. The same applies below
  // to the other commented-out calls. This is a partial rollback of
  // cl/196819423.
1595 // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
1596 const int gemm_input_cols = gemm_input_shape->Dims(0) *
1597 gemm_input_shape->Dims(1) *
1598 gemm_input_shape->Dims(2);
1599 const int filter_rows = filter_shape.Dims(0);
1600 // See b/79927784.
1601 // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1602 const int filter_cols =
1603 filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
1604 const int output_rows = output_shape.Dims(3);
1605 // See b/79927784.
1606 // const int output_cols = FlatSizeSkipDim(output_shape, 3);
1607 const int output_cols =
1608 output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
1609 TFLITE_DCHECK_EQ(output_rows, filter_rows);
1610 TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
1611 TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
1612 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1613
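  // Note: params.input_offset and params.weights_offset hold the negated zero
  // points (see the input_zero_point computation above), so they are negated
  // again here to recover the zero points expected by cpu_backend_gemm, while
  // output_offset is already the output zero point.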
1614 cpu_backend_gemm::MatrixParams<uint8> lhs_params;
1615 lhs_params.rows = filter_rows;
1616 lhs_params.cols = filter_cols;
1617 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1618 lhs_params.zero_point = -filter_offset;
1619 cpu_backend_gemm::MatrixParams<uint8> rhs_params;
1620 rhs_params.rows = gemm_input_rows;
1621 rhs_params.cols = gemm_input_cols;
1622 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1623 rhs_params.zero_point = -input_offset;
1624 cpu_backend_gemm::MatrixParams<uint8> dst_params;
1625 dst_params.rows = output_rows;
1626 dst_params.cols = output_cols;
1627 dst_params.order = cpu_backend_gemm::Order::kColMajor;
1628 dst_params.zero_point = output_offset;
1629 cpu_backend_gemm::GemmParams<int32, uint8> gemm_params;
1630 gemm_params.bias = bias_data;
1631 gemm_params.clamp_min = output_activation_min;
1632 gemm_params.clamp_max = output_activation_max;
1633 gemm_params.multiplier_fixedpoint = output_multiplier;
1634 gemm_params.multiplier_exponent = output_shift;
1635 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1636 dst_params, output_data, gemm_params,
1637 cpu_backend_context);
1638 }
1639
1640 template <typename T>
inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
1642 const RuntimeShape& unextended_input_shape,
1643 const T* input_data,
1644 const RuntimeShape& unextended_output_shape,
1645 T* output_data) {
1646 ruy::profiler::ScopeLabel label("DepthToSpace");
1647
1648 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1649 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1650 const RuntimeShape input_shape =
1651 RuntimeShape::ExtendedShape(4, unextended_input_shape);
1652 const RuntimeShape output_shape =
1653 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1654
1655 const int input_depth = input_shape.Dims(3);
1656 const int input_width = input_shape.Dims(2);
1657 const int input_height = input_shape.Dims(1);
1658
1659 const int output_depth = output_shape.Dims(3);
1660 const int batch_size = output_shape.Dims(0);
1661
  // Number of contiguous values that we can copy in one iteration.
1663 const int stride = op_params.block_size * output_depth;
1664
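  // Within each input pixel the depth stores all block_size * block_size *
  // output_depth output values, grouped by offset_h first. So for a given
  // output row (in_h * block_size + offset_h) we copy `stride` contiguous
  // values per input pixel, expanding the width by block_size.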
1665 for (int batch = 0; batch < batch_size; ++batch) {
1666 for (int in_h = 0; in_h < input_height; ++in_h) {
1667 const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0);
1668 for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
1669 const T* src = input_ptr;
1670 for (int in_w = 0; in_w < input_width; ++in_w) {
1671 memcpy(output_data, src, stride * sizeof(T));
1672 output_data += stride;
1673 src += input_depth;
1674 }
1675 input_ptr += stride;
1676 }
1677 }
1678 }
1679 }
1680
1681 template <typename T>
inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
1683 const RuntimeShape& unextended_input_shape,
1684 const T* input_data,
1685 const RuntimeShape& unextended_output_shape,
1686 T* output_data) {
1687 ruy::profiler::ScopeLabel label("SpaceToDepth");
1688
1689 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1690 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1691 const RuntimeShape input_shape =
1692 RuntimeShape::ExtendedShape(4, unextended_input_shape);
1693 const RuntimeShape output_shape =
1694 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1695
1696 const int output_depth = output_shape.Dims(3);
1697 const int output_width = output_shape.Dims(2);
1698 const int output_height = output_shape.Dims(1);
1699
1700 const int input_depth = input_shape.Dims(3);
1701 const int batch_size = input_shape.Dims(0);
1702
  // Number of contiguous values that we can copy in one iteration.
1704 const int stride = op_params.block_size * input_depth;
1705
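  // The inverse of DepthToSpace: input row in_h = out_h * block_size +
  // offset_h is consumed sequentially, and each memcpy moves block_size
  // horizontally adjacent input pixels (`stride` values) into the offset_h
  // slice of a single output pixel's depth.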
1706 for (int batch = 0; batch < batch_size; ++batch) {
1707 for (int out_h = 0; out_h < output_height; ++out_h) {
1708 T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
1709 for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
1710 T* dst = output_ptr;
1711 for (int out_w = 0; out_w < output_width; ++out_w) {
1712 memcpy(dst, input_data, stride * sizeof(T));
1713 input_data += stride;
1714 dst += output_depth;
1715 }
1716 output_ptr += stride;
1717 }
1718 }
1719 }
1720 }
1721
inline void Relu(const RuntimeShape& input_shape, const float* input_data,
1723 const RuntimeShape& output_shape, float* output_data) {
1724 ruy::profiler::ScopeLabel label("Relu (not fused)");
1725
1726 const auto input = MapAsVector(input_data, input_shape);
1727 auto output = MapAsVector(output_data, output_shape);
1728 output = input.cwiseMax(0.0f);
1729 }
1730
1731 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
1732 const RuntimeShape& input_shape,
1733 const float* input_data,
1734 const RuntimeShape& output_shape,
1735 float* output_data, float epsilon = 1e-6) {
1736 ruy::profiler::ScopeLabel label("L2Normalization");
1737 const int trailing_dim = input_shape.DimensionsCount() - 1;
1738 const int outer_size =
1739 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1740 const int depth =
1741 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
1742 for (int i = 0; i < outer_size; ++i) {
1743 float squared_l2_norm = 0;
1744 for (int c = 0; c < depth; ++c) {
1745 const float val = input_data[c];
1746 squared_l2_norm += val * val;
1747 }
1748 float l2_norm = std::sqrt(squared_l2_norm);
1749 l2_norm = std::max(l2_norm, epsilon);
1750 for (int c = 0; c < depth; ++c) {
1751 *output_data = *input_data / l2_norm;
1752 ++output_data;
1753 ++input_data;
1754 }
1755 }
1756 }
1757
inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
1759 const RuntimeShape& input_shape,
1760 const uint8* input_data,
1761 const RuntimeShape& output_shape,
1762 uint8* output_data) {
1763 ruy::profiler::ScopeLabel label("L2Normalization/8bit");
1764 const int trailing_dim = input_shape.DimensionsCount() - 1;
1765 const int depth =
1766 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
1767 const int outer_size =
1768 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1769 const int32 input_zero_point = op_params.input_zero_point;
1770 for (int i = 0; i < outer_size; ++i) {
1771 int32 square_l2_norm = 0;
1772 for (int c = 0; c < depth; c++) {
1773 // Note that input_data advances by depth in the second pass below.
1774 int32 diff = input_data[c] - input_zero_point;
1775 square_l2_norm += diff * diff;
1776 }
1777 // TODO(b/29395854): add clamping to TOCO and TF Lite kernel
1778 // for all zero tensors in the input_data
1779 int32 inv_l2norm_multiplier;
1780 int inv_l2norm_shift;
1781 GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift,
1782 &inv_l2norm_multiplier, &inv_l2norm_shift);
1783
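    // The uint8 output is effectively quantized with a scale of 1/128 and a
    // zero point of 128, so the normalized value in [-1, 1] maps onto
    // [0, 255]; hence the factor of 128 applied to the rescaled difference
    // and the offset of 128 below.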
1784 for (int c = 0; c < depth; c++) {
1785 int32 diff = *input_data - input_zero_point;
1786 int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
1787 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
1788 int32 unclamped_output_val = 128 + rescaled_diff;
1789 int32 output_val = std::min(255, std::max(0, unclamped_output_val));
1790 *output_data = static_cast<uint8>(output_val);
1791 ++input_data;
1792 ++output_data;
1793 }
1794 }
1795 }
1796
inline void AddElementwise(int size, const ArithmeticParams& params,
1798 const float* input1_data, const float* input2_data,
1799 float* output_data) {
1800 int i = 0;
1801
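  // With NEON, process 16 floats per iteration (four q-registers), then four
  // at a time, and leave any remainder to the scalar tail loop at the end.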
1802 #ifdef USE_NEON
1803 const auto activation_min = vdupq_n_f32(params.float_activation_min);
1804 const auto activation_max = vdupq_n_f32(params.float_activation_max);
1805 for (; i <= size - 16; i += 16) {
1806 auto a10 = vld1q_f32(input1_data + i);
1807 auto a11 = vld1q_f32(input1_data + i + 4);
1808 auto a12 = vld1q_f32(input1_data + i + 8);
1809 auto a13 = vld1q_f32(input1_data + i + 12);
1810 auto a20 = vld1q_f32(input2_data + i);
1811 auto a21 = vld1q_f32(input2_data + i + 4);
1812 auto a22 = vld1q_f32(input2_data + i + 8);
1813 auto a23 = vld1q_f32(input2_data + i + 12);
1814 auto x0 = vaddq_f32(a10, a20);
1815 auto x1 = vaddq_f32(a11, a21);
1816 auto x2 = vaddq_f32(a12, a22);
1817 auto x3 = vaddq_f32(a13, a23);
1818 x0 = vmaxq_f32(activation_min, x0);
1819 x1 = vmaxq_f32(activation_min, x1);
1820 x2 = vmaxq_f32(activation_min, x2);
1821 x3 = vmaxq_f32(activation_min, x3);
1822 x0 = vminq_f32(activation_max, x0);
1823 x1 = vminq_f32(activation_max, x1);
1824 x2 = vminq_f32(activation_max, x2);
1825 x3 = vminq_f32(activation_max, x3);
1826 vst1q_f32(output_data + i, x0);
1827 vst1q_f32(output_data + i + 4, x1);
1828 vst1q_f32(output_data + i + 8, x2);
1829 vst1q_f32(output_data + i + 12, x3);
1830 }
1831 for (; i <= size - 4; i += 4) {
1832 auto a1 = vld1q_f32(input1_data + i);
1833 auto a2 = vld1q_f32(input2_data + i);
1834 auto x = vaddq_f32(a1, a2);
1835 x = vmaxq_f32(activation_min, x);
1836 x = vminq_f32(activation_max, x);
1837 vst1q_f32(output_data + i, x);
1838 }
1839 #endif // NEON
1840
1841 for (; i < size; i++) {
1842 auto x = input1_data[i] + input2_data[i];
1843 output_data[i] = ActivationFunctionWithMinMax(
1844 x, params.float_activation_min, params.float_activation_max);
1845 }
1846 }
1847
inline void Add(const ArithmeticParams& params,
1849 const RuntimeShape& input1_shape, const float* input1_data,
1850 const RuntimeShape& input2_shape, const float* input2_data,
1851 const RuntimeShape& output_shape, float* output_data) {
1852 ruy::profiler::ScopeLabel label("Add");
1853 const int flat_size =
1854 MatchingElementsSize(input1_shape, input2_shape, output_shape);
1855 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
1856 }
1857
1858 // Element-wise add that can often be used for inner loop of broadcast add as
1859 // well as the non-broadcast add.
inline void AddElementwise(int size, const ArithmeticParams& params,
1861 const uint8* input1_data, const uint8* input2_data,
1862 uint8* output_data) {
1863 ruy::profiler::ScopeLabel label("AddElementwise/8bit");
1864 int i = 0;
1865 TFLITE_DCHECK_GT(params.input1_offset, -256);
1866 TFLITE_DCHECK_GT(params.input2_offset, -256);
1867 TFLITE_DCHECK_LT(params.input1_offset, 256);
1868 TFLITE_DCHECK_LT(params.input2_offset, 256);
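  // Both the NEON path and the scalar tail below follow the same recipe:
  // offset each input by its (negated) zero point, shift left by
  // params.left_shift for headroom, rescale each operand to a common scale
  // with its own multiplier/shift, add, then rescale the sum to the output
  // scale, add the output zero point and clamp.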
1869 #ifdef USE_NEON
1870 const uint8x8_t output_activation_min_vector =
1871 vdup_n_u8(params.quantized_activation_min);
1872 const uint8x8_t output_activation_max_vector =
1873 vdup_n_u8(params.quantized_activation_max);
1874 for (; i <= size - 8; i += 8) {
1875 const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
1876 const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
1877 const int16x8_t input1_val_s16 =
1878 vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
1879 const int16x8_t input2_val_s16 =
1880 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
1881 const int16x8_t input1_val =
1882 vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
1883 const int16x8_t input2_val =
1884 vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
1885 const int16x4_t input1_val_high = vget_high_s16(input1_val);
1886 const int16x4_t input1_val_low = vget_low_s16(input1_val);
1887 const int16x4_t input2_val_high = vget_high_s16(input2_val);
1888 const int16x4_t input2_val_low = vget_low_s16(input2_val);
1889 int32x4_t x11 = vmovl_s16(input1_val_low);
1890 int32x4_t x12 = vmovl_s16(input1_val_high);
1891 int32x4_t x21 = vmovl_s16(input2_val_low);
1892 int32x4_t x22 = vmovl_s16(input2_val_high);
1893 const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
1894 x11 = vshlq_s32(x11, left_shift_dup);
1895 x12 = vshlq_s32(x12, left_shift_dup);
1896 x21 = vshlq_s32(x21, left_shift_dup);
1897 x22 = vshlq_s32(x22, left_shift_dup);
1898 x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
1899 x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
1900 x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
1901 x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
1902 const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
1903 const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
1904 x11 = vshlq_s32(x11, input1_shift_dup);
1905 x12 = vshlq_s32(x12, input1_shift_dup);
1906 x21 = vshlq_s32(x21, input2_shift_dup);
1907 x22 = vshlq_s32(x22, input2_shift_dup);
1908 int32x4_t s1 = vaddq_s32(x11, x21);
1909 int32x4_t s2 = vaddq_s32(x12, x22);
1910 s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
1911 s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
1912 using gemmlowp::RoundingDivideByPOT;
1913 s1 = RoundingDivideByPOT(s1, -params.output_shift);
1914 s2 = RoundingDivideByPOT(s2, -params.output_shift);
1915 const int16x4_t s1_narrowed = vmovn_s32(s1);
1916 const int16x4_t s2_narrowed = vmovn_s32(s2);
1917 const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
1918 vdupq_n_s16(params.output_offset));
1919 const uint8x8_t clamped =
1920 vmax_u8(output_activation_min_vector,
1921 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
1922 vst1_u8(output_data + i, clamped);
1923 }
1924 #endif // NEON
1925
1926 for (; i < size; ++i) {
1927 const int32 input1_val = params.input1_offset + input1_data[i];
1928 const int32 input2_val = params.input2_offset + input2_data[i];
1929 const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
1930 const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
1931 const int32 scaled_input1_val =
1932 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1933 shifted_input1_val, params.input1_multiplier, params.input1_shift);
1934 const int32 scaled_input2_val =
1935 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1936 shifted_input2_val, params.input2_multiplier, params.input2_shift);
1937 const int32 raw_sum = scaled_input1_val + scaled_input2_val;
1938 const int32 raw_output =
1939 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1940 raw_sum, params.output_multiplier, params.output_shift) +
1941 params.output_offset;
1942 const int32 clamped_output =
1943 std::min(params.quantized_activation_max,
1944 std::max(params.quantized_activation_min, raw_output));
1945 output_data[i] = static_cast<uint8>(clamped_output);
1946 }
1947 }
1948
1949 // Scalar-broadcast add that can be used for inner loop of more general
1950 // broadcast add, so that, for example, scalar-broadcast with batch will still
1951 // be fast.
inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
1953 uint8 input1_data, const uint8* input2_data,
1954 uint8* output_data) {
1955 using gemmlowp::RoundingDivideByPOT;
1956
1957 ruy::profiler::ScopeLabel label("AddScalarBroadcast/8bit");
1958 TFLITE_DCHECK_GT(params.input1_offset, -256);
1959 TFLITE_DCHECK_GT(params.input2_offset, -256);
1960 TFLITE_DCHECK_LT(params.input1_offset, 256);
1961 TFLITE_DCHECK_LT(params.input2_offset, 256);
1962
1963 int i = 0;
1964
1965 #ifdef USE_NEON
1966 const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
1967 const uint8x8_t output_activation_min_vector =
1968 vdup_n_u8(params.quantized_activation_min);
1969 const uint8x8_t output_activation_max_vector =
1970 vdup_n_u8(params.quantized_activation_max);
1971
1972 // Process broadcast scalar.
1973 const uint8x8_t input1_val_original = vdup_n_u8(input1_data);
1974 const int16x8_t input1_val_s16 =
1975 vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
1976 const int16x8_t input1_val =
1977 vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
1978 const int16x4_t input1_val_high = vget_high_s16(input1_val);
1979 const int16x4_t input1_val_low = vget_low_s16(input1_val);
1980 int32x4_t x11 = vmovl_s16(input1_val_low);
1981 int32x4_t x12 = vmovl_s16(input1_val_high);
1982 x11 = vshlq_s32(x11, left_shift_dup);
1983 x12 = vshlq_s32(x12, left_shift_dup);
1984 x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
1985 x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
1986 const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
1987 x11 = vshlq_s32(x11, input1_shift_dup);
1988 x12 = vshlq_s32(x12, input1_shift_dup);
1989
1990 for (; i <= size - 8; i += 8) {
1991 const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
1992 const int16x8_t input2_val_s16 =
1993 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
1994 const int16x8_t input2_val =
1995 vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
1996 const int16x4_t input2_val_high = vget_high_s16(input2_val);
1997 const int16x4_t input2_val_low = vget_low_s16(input2_val);
1998 int32x4_t x21 = vmovl_s16(input2_val_low);
1999 int32x4_t x22 = vmovl_s16(input2_val_high);
2000 x21 = vshlq_s32(x21, left_shift_dup);
2001 x22 = vshlq_s32(x22, left_shift_dup);
2002 x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
2003 x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
2004 const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
2005 x21 = vshlq_s32(x21, input2_shift_dup);
2006 x22 = vshlq_s32(x22, input2_shift_dup);
2007 int32x4_t s1 = vaddq_s32(x11, x21);
2008 int32x4_t s2 = vaddq_s32(x12, x22);
2009 s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
2010 s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
2011 s1 = RoundingDivideByPOT(s1, -params.output_shift);
2012 s2 = RoundingDivideByPOT(s2, -params.output_shift);
2013 const int16x4_t s1_narrowed = vmovn_s32(s1);
2014 const int16x4_t s2_narrowed = vmovn_s32(s2);
2015 const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
2016 vdupq_n_s16(params.output_offset));
2017 const uint8x8_t clamped =
2018 vmax_u8(output_activation_min_vector,
2019 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
2020 vst1_u8(output_data + i, clamped);
2021 }
2022 #endif // NEON
2023
2024 if (i < size) {
2025 // Process broadcast scalar.
2026 const int32 input1_val = params.input1_offset + input1_data;
2027 const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
2028 const int32 scaled_input1_val =
2029 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2030 shifted_input1_val, params.input1_multiplier, params.input1_shift);
2031
2032 for (; i < size; ++i) {
2033 const int32 input2_val = params.input2_offset + input2_data[i];
2034 const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
2035 const int32 scaled_input2_val =
2036 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2037 shifted_input2_val, params.input2_multiplier,
2038 params.input2_shift);
2039 const int32 raw_sum = scaled_input1_val + scaled_input2_val;
2040 const int32 raw_output =
2041 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2042 raw_sum, params.output_multiplier, params.output_shift) +
2043 params.output_offset;
2044 const int32 clamped_output =
2045 std::min(params.quantized_activation_max,
2046 std::max(params.quantized_activation_min, raw_output));
2047 output_data[i] = static_cast<uint8>(clamped_output);
2048 }
2049 }
2050 }
2051
2052 // Scalar-broadcast add that can be used for inner loop of more general
2053 // broadcast add, so that, for example, scalar-broadcast with batch will still
2054 // be fast.
inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
2056 float broadcast_value, const float* input2_data,
2057 float* output_data) {
2058 int i = 0;
2059 #ifdef USE_NEON
2060 const float32x4_t output_activation_min_vector =
2061 vdupq_n_f32(params.float_activation_min);
2062 const float32x4_t output_activation_max_vector =
2063 vdupq_n_f32(params.float_activation_max);
2064 const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
2065 for (; i <= size - 4; i += 4) {
2066 const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
2067
2068 const float32x4_t output =
2069 vaddq_f32(input2_val_original, broadcast_value_dup);
2070
2071 const float32x4_t clamped =
2072 vmaxq_f32(output_activation_min_vector,
2073 vminq_f32(output_activation_max_vector, output));
2074 vst1q_f32(output_data + i, clamped);
2075 }
2076 #endif // NEON
2077
2078 for (; i < size; ++i) {
2079 auto x = broadcast_value + input2_data[i];
2080 output_data[i] = ActivationFunctionWithMinMax(
2081 x, params.float_activation_min, params.float_activation_max);
2082 }
2083 }
2084
inline void Add(const ArithmeticParams& params,
2086 const RuntimeShape& input1_shape, const uint8* input1_data,
2087 const RuntimeShape& input2_shape, const uint8* input2_data,
2088 const RuntimeShape& output_shape, uint8* output_data) {
2089 TFLITE_DCHECK_LE(params.quantized_activation_min,
2090 params.quantized_activation_max);
2091 ruy::profiler::ScopeLabel label("Add/8bit");
2092 const int flat_size =
2093 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2094
2095 TFLITE_DCHECK_GT(params.input1_offset, -256);
2096 TFLITE_DCHECK_GT(params.input2_offset, -256);
2097 TFLITE_DCHECK_LT(params.input1_offset, 256);
2098 TFLITE_DCHECK_LT(params.input2_offset, 256);
2099 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
2100 }
2101
inline void Add(const ArithmeticParams& params,
2103 const RuntimeShape& input1_shape, const int16* input1_data,
2104 const RuntimeShape& input2_shape, const int16* input2_data,
2105 const RuntimeShape& output_shape, int16* output_data) {
2106 ruy::profiler::ScopeLabel label("Add/Int16");
2107 TFLITE_DCHECK_LE(params.quantized_activation_min,
2108 params.quantized_activation_max);
2109
2110 const int input1_shift = params.input1_shift;
2111 const int flat_size =
2112 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2113 const int16 output_activation_min = params.quantized_activation_min;
2114 const int16 output_activation_max = params.quantized_activation_max;
2115
2116 TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
2117 TFLITE_DCHECK_LE(input1_shift, 0);
2118 TFLITE_DCHECK_LE(params.input2_shift, 0);
2119 const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
2120 const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
2121 const int input_right_shift =
2122 input1_shift == 0 ? -params.input2_shift : -input1_shift;
2123
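  // The operand whose shift is 0 is already expressed at the output scale;
  // the other is rescaled to match with a rounding right shift, and the two
  // are then added with saturation in Q0.15.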
2124 for (int i = 0; i < flat_size; i++) {
2125 // F0 uses 0 integer bits, range [-1, 1].
2126 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2127
2128 F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
2129 F0 scaled_input = F0::FromRaw(
2130 gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
2131 F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
2132 const int16 raw_output = result.raw();
2133 const int16 clamped_output = std::min(
2134 output_activation_max, std::max(output_activation_min, raw_output));
2135 output_data[i] = clamped_output;
2136 }
2137 }
2138
inline void Add(const ArithmeticParams& params,
2140 const RuntimeShape& input1_shape, const int32* input1_data,
2141 const RuntimeShape& input2_shape, const int32* input2_data,
2142 const RuntimeShape& output_shape, int32* output_data) {
2143 ruy::profiler::ScopeLabel label("Add/int32");
2144
2145 auto input1_map = MapAsVector(input1_data, input1_shape);
2146 auto input2_map = MapAsVector(input2_data, input2_shape);
2147 auto output_map = MapAsVector(output_data, output_shape);
2148 if (input1_shape == input2_shape) {
2149 output_map.array() = (input1_map.array() + input2_map.array())
2150 .cwiseMax(params.quantized_activation_min)
2151 .cwiseMin(params.quantized_activation_max);
2152 } else if (input2_shape.FlatSize() == 1) {
2153 auto scalar = input2_data[0];
2154 output_map.array() = (input1_map.array() + scalar)
2155 .cwiseMax(params.quantized_activation_min)
2156 .cwiseMin(params.quantized_activation_max);
2157 } else if (input1_shape.FlatSize() == 1) {
2158 auto scalar = input1_data[0];
2159 output_map.array() = (scalar + input2_map.array())
2160 .cwiseMax(params.quantized_activation_min)
2161 .cwiseMin(params.quantized_activation_max);
2162 } else {
2163 reference_ops::BroadcastAdd4DSlow(params, input1_shape, input1_data,
2164 input2_shape, input2_data, output_shape,
2165 output_data);
2166 }
2167 }
2168
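// Dispatches a broadcast add: shapes that do not fit the fivefold broadcast
// pattern fall back to the generic (slow) 4D broadcast, while compatible
// shapes are handled by BinaryBroadcastFiveFold, using the elementwise and
// scalar-broadcast kernels above as its inner loops.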
2169 template <typename T>
inline void BroadcastAddDispatch(
2171 const ArithmeticParams& params, const RuntimeShape& input1_shape,
2172 const T* input1_data, const RuntimeShape& input2_shape,
2173 const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2174 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
2175 return BroadcastAdd4DSlow(params, input1_shape, input1_data, input2_shape,
2176 input2_data, output_shape, output_data);
2177 }
2178
2179 BinaryBroadcastFiveFold(
2180 params, input1_shape, input1_data, input2_shape, input2_data,
2181 output_shape, output_data,
2182 static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
2183 T*)>(AddElementwise),
2184 static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
2185 AddScalarBroadcast));
2186 }
2187
inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
2189 const RuntimeShape& unswitched_input1_shape,
2190 const uint8* unswitched_input1_data,
2191 const RuntimeShape& unswitched_input2_shape,
2192 const uint8* unswitched_input2_data,
2193 const RuntimeShape& output_shape,
2194 uint8* output_data) {
2195 BroadcastAddDispatch(unswitched_params, unswitched_input1_shape,
2196 unswitched_input1_data, unswitched_input2_shape,
2197 unswitched_input2_data, output_shape, output_data);
2198 }
2199
inline void BroadcastAddFivefold(const ArithmeticParams& params,
2201 const RuntimeShape& unswitched_input1_shape,
2202 const float* unswitched_input1_data,
2203 const RuntimeShape& unswitched_input2_shape,
2204 const float* unswitched_input2_data,
2205 const RuntimeShape& output_shape,
2206 float* output_data) {
2207 BroadcastAddDispatch(params, unswitched_input1_shape, unswitched_input1_data,
2208 unswitched_input2_shape, unswitched_input2_data,
2209 output_shape, output_data);
2210 }
2211
inline void MulElementwise(int size, const ArithmeticParams& params,
2213 const float* input1_data, const float* input2_data,
2214 float* output_data) {
2215 const float output_activation_min = params.float_activation_min;
2216 const float output_activation_max = params.float_activation_max;
2217
2218 int i = 0;
2219 #ifdef USE_NEON
2220 const auto activation_min = vdupq_n_f32(output_activation_min);
2221 const auto activation_max = vdupq_n_f32(output_activation_max);
2222 for (; i <= size - 16; i += 16) {
2223 auto a10 = vld1q_f32(input1_data + i);
2224 auto a11 = vld1q_f32(input1_data + i + 4);
2225 auto a12 = vld1q_f32(input1_data + i + 8);
2226 auto a13 = vld1q_f32(input1_data + i + 12);
2227 auto a20 = vld1q_f32(input2_data + i);
2228 auto a21 = vld1q_f32(input2_data + i + 4);
2229 auto a22 = vld1q_f32(input2_data + i + 8);
2230 auto a23 = vld1q_f32(input2_data + i + 12);
2231 auto x0 = vmulq_f32(a10, a20);
2232 auto x1 = vmulq_f32(a11, a21);
2233 auto x2 = vmulq_f32(a12, a22);
2234 auto x3 = vmulq_f32(a13, a23);
2235
2236 x0 = vmaxq_f32(activation_min, x0);
2237 x1 = vmaxq_f32(activation_min, x1);
2238 x2 = vmaxq_f32(activation_min, x2);
2239 x3 = vmaxq_f32(activation_min, x3);
2240 x0 = vminq_f32(activation_max, x0);
2241 x1 = vminq_f32(activation_max, x1);
2242 x2 = vminq_f32(activation_max, x2);
2243 x3 = vminq_f32(activation_max, x3);
2244
2245 vst1q_f32(output_data + i, x0);
2246 vst1q_f32(output_data + i + 4, x1);
2247 vst1q_f32(output_data + i + 8, x2);
2248 vst1q_f32(output_data + i + 12, x3);
2249 }
2250 for (; i <= size - 4; i += 4) {
2251 auto a1 = vld1q_f32(input1_data + i);
2252 auto a2 = vld1q_f32(input2_data + i);
2253 auto x = vmulq_f32(a1, a2);
2254
2255 x = vmaxq_f32(activation_min, x);
2256 x = vminq_f32(activation_max, x);
2257
2258 vst1q_f32(output_data + i, x);
2259 }
2260 #endif // NEON
2261
2262 for (; i < size; i++) {
2263 auto x = input1_data[i] * input2_data[i];
2264 output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
2265 output_activation_max);
2266 }
2267 }
2268
inline void Mul(const ArithmeticParams& params,
2270 const RuntimeShape& input1_shape, const float* input1_data,
2271 const RuntimeShape& input2_shape, const float* input2_data,
2272 const RuntimeShape& output_shape, float* output_data) {
2273 ruy::profiler::ScopeLabel label("Mul");
2274
2275 const int flat_size =
2276 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2277 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
2278 }
2279
inline void Mul(const ArithmeticParams& params,
2281 const RuntimeShape& input1_shape, const int32* input1_data,
2282 const RuntimeShape& input2_shape, const int32* input2_data,
2283 const RuntimeShape& output_shape, int32* output_data) {
2284 ruy::profiler::ScopeLabel label("Mul/int32/activation");
2285
2286 const int flat_size =
2287 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2288 const int32 output_activation_min = params.quantized_activation_min;
2289 const int32 output_activation_max = params.quantized_activation_max;
2290 for (int i = 0; i < flat_size; ++i) {
2291 output_data[i] = ActivationFunctionWithMinMax(
2292 input1_data[i] * input2_data[i], output_activation_min,
2293 output_activation_max);
2294 }
2295 }
2296
inline void MulNoActivation(const ArithmeticParams& params,
2298 const RuntimeShape& input1_shape,
2299 const int32* input1_data,
2300 const RuntimeShape& input2_shape,
2301 const int32* input2_data,
2302 const RuntimeShape& output_shape,
2303 int32* output_data) {
2304 ruy::profiler::ScopeLabel label("Mul/int32");
2305
2306 auto input1_map = MapAsVector(input1_data, input1_shape);
2307 auto input2_map = MapAsVector(input2_data, input2_shape);
2308 auto output_map = MapAsVector(output_data, output_shape);
2309 if (input1_shape == input2_shape) {
2310 output_map.array() = input1_map.array() * input2_map.array();
2311 } else if (input2_shape.FlatSize() == 1) {
2312 auto scalar = input2_data[0];
2313 output_map.array() = input1_map.array() * scalar;
2314 } else if (input1_shape.FlatSize() == 1) {
2315 auto scalar = input1_data[0];
2316 output_map.array() = scalar * input2_map.array();
2317 } else {
2318 reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data,
2319 input2_shape, input2_data, output_shape,
2320 output_data);
2321 }
2322 }
2323
inline void Mul(const ArithmeticParams& params,
2325 const RuntimeShape& input1_shape, const int16* input1_data,
2326 const RuntimeShape& input2_shape, const int16* input2_data,
2327 const RuntimeShape& output_shape, int16* output_data) {
2328 ruy::profiler::ScopeLabel label("Mul/Int16/NoActivation");
2329 // This is a copy of the reference implementation. We do not currently have a
2330 // properly optimized version.
2331
2332 const int flat_size =
2333 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2334
2335 for (int i = 0; i < flat_size; i++) {
2336 // F0 uses 0 integer bits, range [-1, 1].
2337 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2338
2339 F0 unclamped_result =
2340 F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
2341 output_data[i] = unclamped_result.raw();
2342 }
2343 }
2344
inline void Mul(const ArithmeticParams& params,
2346 const RuntimeShape& input1_shape, const int16* input1_data,
2347 const RuntimeShape& input2_shape, const int16* input2_data,
2348 const RuntimeShape& output_shape, uint8* output_data) {
2349 ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
2350 // This is a copy of the reference implementation. We do not currently have a
2351 // properly optimized version.
2352 const int32 output_activation_min = params.quantized_activation_min;
2353 const int32 output_activation_max = params.quantized_activation_max;
2354 const int32 output_offset = params.output_offset;
2355 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
2356
2357 const int flat_size =
2358 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2359
2360 for (int i = 0; i < flat_size; i++) {
2361 // F0 uses 0 integer bits, range [-1, 1].
2362 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2363
2364 F0 unclamped_result =
2365 F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
2366 int16 rescaled_result =
2367 gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
2368 int16 clamped_result =
2369 std::min<int16>(output_activation_max - output_offset, rescaled_result);
2370 clamped_result =
2371 std::max<int16>(output_activation_min - output_offset, clamped_result);
2372 output_data[i] = output_offset + clamped_result;
2373 }
2374 }
2375
2376 // Element-wise mul that can often be used for inner loop of broadcast Mul as
2377 // well as the non-broadcast Mul.
inline void MulElementwise(int size, const ArithmeticParams& params,
2379 const uint8* input1_data, const uint8* input2_data,
2380 uint8* output_data) {
2381 int i = 0;
2382 TFLITE_DCHECK_GT(params.input1_offset, -256);
2383 TFLITE_DCHECK_LT(params.input1_offset, 256);
2384 TFLITE_DCHECK_GT(params.input2_offset, -256);
2385 TFLITE_DCHECK_LT(params.input2_offset, 256);
2386 TFLITE_DCHECK_GT(params.output_offset, -256);
2387 TFLITE_DCHECK_LT(params.output_offset, 256);
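  // output_multiplier / output_shift requantize the raw 32-bit product to the
  // output scale. In the NEON path a positive shift is applied as a left
  // shift before the saturating doubling multiply-high, and a negative shift
  // as a rounding right shift afterwards; the scalar tail lets
  // MultiplyByQuantizedMultiplier handle both cases.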
2388 #ifdef USE_NEON
2389 const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
2390 const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
2391 const auto output_offset_vector = vdupq_n_s16(params.output_offset);
2392 const auto output_activation_min_vector =
2393 vdup_n_u8(params.quantized_activation_min);
2394 const auto output_activation_max_vector =
2395 vdup_n_u8(params.quantized_activation_max);
2396 const int left_shift = std::max(0, params.output_shift);
2397 const int right_shift = std::max(0, -params.output_shift);
2398 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
2399 for (; i <= size - 8; i += 8) {
2400 // We load / store 8 at a time, multiplying as two sets of 4 int32s.
2401 const auto input1_val_original = vld1_u8(input1_data + i);
2402 const auto input2_val_original = vld1_u8(input2_data + i);
2403 const auto input1_val_s16 =
2404 vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
2405 const auto input2_val_s16 =
2406 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2407 const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
2408 const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
2409
2410 const auto input1_val_low = vget_low_s16(input1_val);
2411 const auto input1_val_high = vget_high_s16(input1_val);
2412 const auto input2_val_low = vget_low_s16(input2_val);
2413 const auto input2_val_high = vget_high_s16(input2_val);
2414
2415 auto p1 = vmull_s16(input2_val_low, input1_val_low);
2416 auto p2 = vmull_s16(input2_val_high, input1_val_high);
2417
2418 p1 = vshlq_s32(p1, left_shift_vec);
2419 p2 = vshlq_s32(p2, left_shift_vec);
2420 p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
2421 p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
2422 using gemmlowp::RoundingDivideByPOT;
2423 p1 = RoundingDivideByPOT(p1, right_shift);
2424 p2 = RoundingDivideByPOT(p2, right_shift);
2425
2426 const auto p1_narrowed = vqmovn_s32(p1);
2427 const auto p2_narrowed = vqmovn_s32(p2);
2428 const auto p =
2429 vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
2430 const auto clamped =
2431 vmax_u8(output_activation_min_vector,
2432 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
2433 vst1_u8(output_data + i, clamped);
2434 }
2435 #endif // NEON
2436
2437 for (; i < size; ++i) {
2438 const int32 input1_val = params.input1_offset + input1_data[i];
2439 const int32 input2_val = params.input2_offset + input2_data[i];
2440 const int32 unclamped_result =
2441 params.output_offset +
2442 MultiplyByQuantizedMultiplier(input1_val * input2_val,
2443 params.output_multiplier,
2444 params.output_shift);
2445 const int32 clamped_output =
2446 std::min(params.quantized_activation_max,
2447 std::max(params.quantized_activation_min, unclamped_result));
2448 output_data[i] = static_cast<uint8>(clamped_output);
2449 }
2450 }
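
// Usage sketch (illustrative only, with hypothetical quantization params):
// multiplies two tiny uint8 buffers via MulElementwise above. Real callers
// derive output_multiplier / output_shift from the input and output scales
// (e.g. with QuantizeMultiplier) rather than hard-coding them as done here.
inline void ExampleMulElementwiseUsage() {
  const uint8 input1[4] = {120, 130, 140, 150};
  const uint8 input2[4] = {128, 96, 200, 60};
  uint8 output[4];
  ArithmeticParams params;
  params.input1_offset = -128;  // Zero point of input1 is 128.
  params.input2_offset = -128;
  params.output_offset = 128;
  params.output_multiplier = 1073741824;  // ~0.5 as a Q31 multiplier.
  params.output_shift = -7;               // Hypothetical rescaling shift.
  params.quantized_activation_min = 0;
  params.quantized_activation_max = 255;
  MulElementwise(4, params, input1, input2, output);
}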
2451
2452 // Broadcast mul that can often be used for the inner loop of a broadcast Mul.
2453 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
2454 const uint8 broadcast_value,
2455 const uint8* input2_data, uint8* output_data) {
2456 const int16 input1_val = params.input1_offset + broadcast_value;
2457
2458 int i = 0;
2459 TFLITE_DCHECK_GT(params.input1_offset, -256);
2460 TFLITE_DCHECK_LT(params.input1_offset, 256);
2461 TFLITE_DCHECK_GT(params.input2_offset, -256);
2462 TFLITE_DCHECK_LT(params.input2_offset, 256);
2463 TFLITE_DCHECK_GT(params.output_offset, -256);
2464 TFLITE_DCHECK_LT(params.output_offset, 256);
2465 #ifdef USE_NEON
2466 const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
2467 const auto output_offset_vector = vdupq_n_s16(params.output_offset);
2468 const auto output_activation_min_vector =
2469 vdup_n_u8(params.quantized_activation_min);
2470 const auto output_activation_max_vector =
2471 vdup_n_u8(params.quantized_activation_max);
2472 const int left_shift = std::max(0, params.output_shift);
2473 const int right_shift = std::max(0, -params.output_shift);
2474 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
2475 for (; i <= size - 8; i += 8) {
2476 // We load / store 8 at a time, multiplying as two sets of 4 int32s.
2477 const auto input2_val_original = vld1_u8(input2_data + i);
2478 const auto input2_val_s16 =
2479 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2480 const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
2481
2482 const auto input2_val_low = vget_low_s16(input2_val);
2483 const auto input2_val_high = vget_high_s16(input2_val);
2484
2485 auto p1 = vmull_n_s16(input2_val_low, input1_val);
2486 auto p2 = vmull_n_s16(input2_val_high, input1_val);
2487
2488 p1 = vshlq_s32(p1, left_shift_vec);
2489 p2 = vshlq_s32(p2, left_shift_vec);
2490 p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
2491 p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
2492 using gemmlowp::RoundingDivideByPOT;
2493 p1 = RoundingDivideByPOT(p1, right_shift);
2494 p2 = RoundingDivideByPOT(p2, right_shift);
2495
2496 const auto p1_narrowed = vmovn_s32(p1);
2497 const auto p2_narrowed = vmovn_s32(p2);
2498 const auto p =
2499 vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
2500 const auto clamped =
2501 vmax_u8(output_activation_min_vector,
2502 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
2503 vst1_u8(output_data + i, clamped);
2504 }
2505 #endif // NEON
2506
2507 for (; i < size; ++i) {
2508 const int32 input2_val = params.input2_offset + input2_data[i];
2509 const int32 unclamped_result =
2510 params.output_offset +
2511 MultiplyByQuantizedMultiplier(input1_val * input2_val,
2512 params.output_multiplier,
2513 params.output_shift);
2514 const int32 clamped_output =
2515 std::min(params.quantized_activation_max,
2516 std::max(params.quantized_activation_min, unclamped_result));
2517 output_data[i] = static_cast<uint8>(clamped_output);
2518 }
2519 }
2520
2521 // Broadcast mul that can often be used for the inner loop of a broadcast Mul.
2522 // This function handles scalar_value (LHS) * vector_values (RHS).
2523 // Since it's a float function, the input params do not matter here.
2524 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
2525 const float broadcast_value,
2526 const float* input2_data, float* output_data) {
2527 int i = 0;
2528 #ifdef USE_NEON
2529 const float32x4_t output_activation_min_vector =
2530 vdupq_n_f32(params.float_activation_min);
2531 const float32x4_t output_activation_max_vector =
2532 vdupq_n_f32(params.float_activation_max);
2533 const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
2534 for (; i <= size - 4; i += 4) {
2535 const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
2536
2537 const float32x4_t output =
2538 vmulq_f32(input2_val_original, broadcast_value_dup);
2539
2540 const float32x4_t clamped =
2541 vmaxq_f32(output_activation_min_vector,
2542 vminq_f32(output_activation_max_vector, output));
2543 vst1q_f32(output_data + i, clamped);
2544 }
2545 #endif // NEON
2546
2547 for (; i < size; ++i) {
2548 float x = broadcast_value * input2_data[i];
2549 output_data[i] = ActivationFunctionWithMinMax(
2550 x, params.float_activation_min, params.float_activation_max);
2551 }
2552 }
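
// Usage sketch (illustrative only): scalar * vector float multiply with
// clamping to a hypothetical [-6, 6] activation range.
inline void ExampleMulSimpleBroadcastFloatUsage() {
  const float input2[4] = {0.5f, -1.0f, 2.0f, 8.0f};
  float output[4];
  ArithmeticParams params;
  params.float_activation_min = -6.0f;
  params.float_activation_max = 6.0f;
  MulSimpleBroadcast(4, params, /*broadcast_value=*/2.0f, input2, output);
  // output is now {1.0f, -2.0f, 4.0f, 6.0f}; the last product (16) is clamped.
}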
2553
2554 inline void Mul(const ArithmeticParams& params,
2555 const RuntimeShape& input1_shape, const uint8* input1_data,
2556 const RuntimeShape& input2_shape, const uint8* input2_data,
2557 const RuntimeShape& output_shape, uint8* output_data) {
2558 TFLITE_DCHECK_LE(params.quantized_activation_min,
2559 params.quantized_activation_max);
2560 ruy::profiler::ScopeLabel label("Mul/8bit");
2561 const int flat_size =
2562 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2563
2564 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
2565 }
2566
2567 template <typename T>
2568 inline void BroadcastMulDispatch(
2569 const ArithmeticParams& params, const RuntimeShape& input1_shape,
2570 const T* input1_data, const RuntimeShape& input2_shape,
2571 const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2572 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
2573 return BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
2574 input2_data, output_shape, output_data);
2575 }
2576
2577 BinaryBroadcastFiveFold(
2578 params, input1_shape, input1_data, input2_shape, input2_data,
2579 output_shape, output_data,
2580 static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
2581 T*)>(MulElementwise),
2582 static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
2583 MulSimpleBroadcast));
2584 }
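
// Note (illustrative): the casts above select the overloads of MulElementwise
// and MulSimpleBroadcast for the element type T defined earlier in this file,
// and BinaryBroadcastFiveFold applies them run-by-run over the collapsed
// broadcast dimensions. The broadcast_category and fivefold shape fields of
// params are expected to have been filled in by the caller, typically via
// ProcessBroadcastShapes.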
2585
2586 inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
2587 const RuntimeShape& unswitched_input1_shape,
2588 const uint8* unswitched_input1_data,
2589 const RuntimeShape& unswitched_input2_shape,
2590 const uint8* unswitched_input2_data,
2591 const RuntimeShape& output_shape,
2592 uint8* output_data) {
2593 BroadcastMulDispatch(unswitched_params, unswitched_input1_shape,
2594 unswitched_input1_data, unswitched_input2_shape,
2595 unswitched_input2_data, output_shape, output_data);
2596 }
2597
2598 inline void BroadcastMulFivefold(const ArithmeticParams& params,
2599 const RuntimeShape& unswitched_input1_shape,
2600 const float* unswitched_input1_data,
2601 const RuntimeShape& unswitched_input2_shape,
2602 const float* unswitched_input2_data,
2603 const RuntimeShape& output_shape,
2604 float* output_data) {
2605 BroadcastMulDispatch(params, unswitched_input1_shape, unswitched_input1_data,
2606 unswitched_input2_shape, unswitched_input2_data,
2607 output_shape, output_data);
2608 }
2609
2610 // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
2611 // dimensionality if the runtime code does a single loop over one dimension
2612 // that handles broadcasting as the base case. The code generator would then
2613 // generate max(D1, D2) nested for loops.
2614 // TODO(benoitjacob): BroadcastDiv is intentionally duplicated from
2615 // reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
2616 // is no longer referenced in this file, move NdArrayDesc<T> from types.h to
2617 // reference_ops.h.
2618 template <typename T, int N = 5>
2619 void BroadcastDivSlow(const ArithmeticParams& params,
2620 const RuntimeShape& unextended_input1_shape,
2621 const T* input1_data,
2622 const RuntimeShape& unextended_input2_shape,
2623 const T* input2_data,
2624 const RuntimeShape& unextended_output_shape,
2625 T* output_data) {
2626 ruy::profiler::ScopeLabel label("BroadcastDivSlow");
2627 T output_activation_min;
2628 T output_activation_max;
2629 GetActivationParams(params, &output_activation_min, &output_activation_max);
2630
2631 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
2632 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
2633 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
2634
2635 NdArrayDesc<N> desc1;
2636 NdArrayDesc<N> desc2;
2637 NdArrayDesc<N> output_desc;
2638 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
2639 unextended_input2_shape, &desc1, &desc2);
2640 CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
2641 &output_desc);
2642
2643 // In TensorFlow, the dimensions are canonically named (batch_number, row,
2644 // col, channel), with extents (batches, height, width, depth), with the
2645 // trailing dimension changing most rapidly (channels have the smallest
2646 // stride, typically 1 element).
2647 //
2648 // In generated C code, we store arrays with the dimensions reversed. The
2649 // first dimension has smallest stride.
2650 //
2651 // We name our variables by their TensorFlow convention, but generate C code
2652 // nesting loops such that the innermost loop has the smallest stride for the
2653 // best cache behavior.
2654 auto div_func = [&](int indexes[N]) {
2655 output_data[SubscriptToIndex(output_desc, indexes)] =
2656 ActivationFunctionWithMinMax(
2657 input1_data[SubscriptToIndex(desc1, indexes)] /
2658 input2_data[SubscriptToIndex(desc2, indexes)],
2659 output_activation_min, output_activation_max);
2660 };
2661 NDOpsHelper<N>(output_desc, div_func);
2662 }
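
// Illustrative example: broadcasting input shapes [2, 1, 3] and [2, 4, 3]
// yields output shape [2, 4, 3]; NdArrayDescsForElementwiseBroadcast gives the
// size-1 middle dimension of the first input a stride of 0, so
// SubscriptToIndex reuses the same input1 element for all 4 positions along
// that dimension.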
2663
2664 // TODO: BroadcastDiv is intentionally duplicated from reference_ops.h.
2665 // For more details see the comment above the generic version of
2666 // BroadcastDivSlow.
2667 template <int N = 5>
2668 inline void BroadcastDivSlow(const ArithmeticParams& params,
2669 const RuntimeShape& unextended_input1_shape,
2670 const uint8* input1_data,
2671 const RuntimeShape& unextended_input2_shape,
2672 const uint8* input2_data,
2673 const RuntimeShape& unextended_output_shape,
2674 uint8* output_data) {
2675 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
2676 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
2677 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
2678
2679 NdArrayDesc<N> desc1;
2680 NdArrayDesc<N> desc2;
2681 NdArrayDesc<N> output_desc;
2682 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
2683 unextended_input2_shape, &desc1, &desc2);
2684 CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
2685 &output_desc);
2686
2687 TFLITE_DCHECK_GT(params.input1_offset, -256);
2688 TFLITE_DCHECK_LT(params.input1_offset, 256);
2689 TFLITE_DCHECK_GT(params.input2_offset, -256);
2690 TFLITE_DCHECK_LT(params.input2_offset, 256);
2691 TFLITE_DCHECK_GT(params.output_offset, -256);
2692 TFLITE_DCHECK_LT(params.output_offset, 256);
2693
2694 auto div_func = [&](int indexes[N]) {
2695 const int32 input1_val =
2696 params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)];
2697 const int32 input2_val =
2698 params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)];
2699 TFLITE_DCHECK_NE(input2_val, 0);
2700 int recip_shift;
2701 const int32 input2_inv =
2702 (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift)
2703 : -GetReciprocal(-input2_val, 31, &recip_shift);
2704 const int headroom = CountLeadingSignBits(input1_val);
2705 const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne(
2706 input1_val, input2_inv, headroom);
2707 const int total_shift = params.output_shift - recip_shift - headroom;
2708 const int32 unclamped_result =
2709 params.output_offset +
2710 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2711 unscaled_quotient, params.output_multiplier, total_shift);
2712 const int32 clamped_output =
2713 std::min(params.quantized_activation_max,
2714 std::max(params.quantized_activation_min, unclamped_result));
2715 output_data[SubscriptToIndex(output_desc, indexes)] =
2716 static_cast<uint8>(clamped_output);
2717 };
2718 NDOpsHelper<N>(output_desc, div_func);
2719 }
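
// Sketch of the math above (illustrative): the quantized division is computed
// as a multiplication by a fixed-point reciprocal. GetReciprocal returns a
// normalized reciprocal of |input2_val| together with recip_shift; that shift,
// the headroom used when scaling input1_val up, and params.output_shift are
// then folded into a single total_shift applied in the final requantization.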
2720
2721 template <typename T>
2722 inline void SubWithActivation(
2723 const ArithmeticParams& params, const RuntimeShape& input1_shape,
2724 const T* input1_data, const RuntimeShape& input2_shape,
2725 const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2726 ruy::profiler::ScopeLabel label("SubWithActivation_optimized");
2727 TFLITE_DCHECK_EQ(input1_shape.FlatSize(), input2_shape.FlatSize());
2728 auto input1_map = MapAsVector(input1_data, input1_shape);
2729 auto input2_map = MapAsVector(input2_data, input2_shape);
2730 auto output_map = MapAsVector(output_data, output_shape);
2731 T activation_min, activation_max;
2732 GetActivationParams(params, &activation_min, &activation_max);
2733 output_map.array() = (input1_map.array() - input2_map.array())
2734 .cwiseMin(activation_max)
2735 .cwiseMax(activation_min);
2736 }
2737
2738 inline void SubNonBroadcast(const ArithmeticParams& params,
2739 const RuntimeShape& input1_shape,
2740 const float* input1_data,
2741 const RuntimeShape& input2_shape,
2742 const float* input2_data,
2743 const RuntimeShape& output_shape,
2744 float* output_data) {
2745 ruy::profiler::ScopeLabel label("SubNonBroadcast");
2746 SubWithActivation<float>(params, input1_shape, input1_data, input2_shape,
2747 input2_data, output_shape, output_data);
2748 }
2749
2750 template <typename T>
2751 void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
2752 const T* input1_data, const RuntimeShape& input2_shape,
2753 const T* input2_data, const RuntimeShape& output_shape,
2754 T* output_data) {
2755 ruy::profiler::ScopeLabel label("Sub");
2756
2757 auto input1_map = MapAsVector(input1_data, input1_shape);
2758 auto input2_map = MapAsVector(input2_data, input2_shape);
2759 auto output_map = MapAsVector(output_data, output_shape);
2760 if (input1_shape == input2_shape) {
2761 output_map.array() = input1_map.array() - input2_map.array();
2762 } else if (input1_shape.FlatSize() == 1) {
2763 auto scalar = input1_data[0];
2764 output_map.array() = scalar - input2_map.array();
2765 } else if (input2_shape.FlatSize() == 1) {
2766 auto scalar = input2_data[0];
2767 output_map.array() = input1_map.array() - scalar;
2768 } else {
2769 BroadcastSubSlow(params, input1_shape, input1_data, input2_shape,
2770 input2_data, output_shape, output_data);
2771 }
2772 }
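
// Usage sketch (illustrative only): a float Sub where the second operand is a
// single value takes the input2_shape.FlatSize() == 1 fast path above, so the
// ArithmeticParams are not consulted on this path.
inline void ExampleSubScalarUsage() {
  const RuntimeShape shape({1, 2, 2, 1});
  const RuntimeShape scalar_shape({1, 1, 1, 1});
  const float input1[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  const float input2[1] = {0.5f};
  float output[4];  // Becomes {0.5f, 1.5f, 2.5f, 3.5f}.
  ArithmeticParams params;  // Unused by this path.
  Sub<float>(params, shape, input1, scalar_shape, input2, shape, output);
}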
2773
2774 inline void LstmCell(
2775 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2776 const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
2777 const float* prev_activ_data, const RuntimeShape& weights_shape,
2778 const float* weights_data, const RuntimeShape& unextended_bias_shape,
2779 const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
2780 const float* prev_state_data,
2781 const RuntimeShape& unextended_output_state_shape, float* output_state_data,
2782 const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
2783 const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
2784 const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data,
2785 CpuBackendContext* cpu_backend_context) {
2786 ruy::profiler::ScopeLabel label("LstmCell");
2787 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2788 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2789 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2790 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2791 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2792 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2793 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2794 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
2795 const RuntimeShape input_shape =
2796 RuntimeShape::ExtendedShape(4, unextended_input_shape);
2797 const RuntimeShape prev_activ_shape =
2798 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
2799 const RuntimeShape bias_shape =
2800 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
2801 const RuntimeShape prev_state_shape =
2802 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
2803 const RuntimeShape output_state_shape =
2804 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
2805 const RuntimeShape output_activ_shape =
2806 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
2807 const RuntimeShape concat_temp_shape =
2808 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
2809 const RuntimeShape activ_temp_shape =
2810 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
2811 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2812
2813 const int weights_dim_count = weights_shape.DimensionsCount();
2814 MatchingDim( // batches
2815 input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
2816 output_state_shape, 0, output_activ_shape, 0);
2817 MatchingDim( // height
2818 input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
2819 output_state_shape, 1, output_activ_shape, 1);
2820 MatchingDim( // width
2821 input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
2822 output_state_shape, 2, output_activ_shape, 2);
2823 const int input_depth = input_shape.Dims(3);
2824 const int prev_activ_depth = prev_activ_shape.Dims(3);
2825 const int total_input_depth = prev_activ_depth + input_depth;
2826 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
2827 total_input_depth);
2828 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
2829 const int intern_activ_depth =
2830 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
2831 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
2832 intern_activ_depth * total_input_depth);
2833 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
2834 const int output_depth =
2835 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
2836 3, output_activ_shape, 3);
2837 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
2838
2839 // Concatenate prev_activ and input data together
2840 std::vector<float const*> concat_input_arrays_data;
2841 std::vector<RuntimeShape const*> concat_input_arrays_shapes;
2842 concat_input_arrays_data.push_back(input_data);
2843 concat_input_arrays_data.push_back(prev_activ_data);
2844 concat_input_arrays_shapes.push_back(&input_shape);
2845 concat_input_arrays_shapes.push_back(&prev_activ_shape);
2846 tflite::ConcatenationParams concat_params;
2847 concat_params.axis = 3;
2848 concat_params.inputs_count = concat_input_arrays_data.size();
2849 Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
2850 &(concat_input_arrays_data[0]), concat_temp_shape,
2851 concat_temp_data);
2852
2853 // Fully connected
2854 tflite::FullyConnectedParams fc_params;
2855 fc_params.float_activation_min = std::numeric_limits<float>::lowest();
2856 fc_params.float_activation_max = std::numeric_limits<float>::max();
2857 fc_params.lhs_cacheable = false;
2858 fc_params.rhs_cacheable = false;
2859 FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
2860 weights_data, bias_shape, bias_data, activ_temp_shape,
2861 activ_temp_data, cpu_backend_context);
2862
2863 // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
2864 // operations.
2865 ArrayMap<float> activ_temp_map =
2866 MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
2867 auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
2868 activ_temp_map.cols());
2869 auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
2870 activ_temp_map.cols());
2871 auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
2872 activ_temp_map.cols());
2873 auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
2874 activ_temp_map.cols());
2875 ArrayMap<const float> prev_state_map =
2876 MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
2877 ArrayMap<float> output_state_map =
2878 MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
2879 ArrayMap<float> output_activ_map =
2880 MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
2881
2882 // Combined memory state and final output calculation
2883 ruy::profiler::ScopeLabel label2("MemoryStateAndFinalOutput");
2884 output_state_map =
2885 input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2886 new_input_sm.tanh() +
2887 forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2888 prev_state_map;
2889 output_activ_map =
2890 output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2891 output_state_map.tanh();
2892 }
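
// In equation form, the block above computes (sketch):
//   state_t = logistic(i) * tanh(g) + logistic(f) * state_{t-1}
//   activ_t = logistic(o) * tanh(state_t)
// where i, g, f, o are the four output_depth-row blocks of activ_temp
// (input gate, new input, forget gate, output gate).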
2893
2894 template <int StateIntegerBits>
2895 inline void LstmCell(
2896 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2897 const uint8* input_data_uint8,
2898 const RuntimeShape& unextended_prev_activ_shape,
2899 const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
2900 const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
2901 const int32* bias_data_int32,
2902 const RuntimeShape& unextended_prev_state_shape,
2903 const int16* prev_state_data_int16,
2904 const RuntimeShape& unextended_output_state_shape,
2905 int16* output_state_data_int16,
2906 const RuntimeShape& unextended_output_activ_shape,
2907 uint8* output_activ_data_uint8,
2908 const RuntimeShape& unextended_concat_temp_shape,
2909 uint8* concat_temp_data_uint8,
2910 const RuntimeShape& unextended_activ_temp_shape,
2911 int16* activ_temp_data_int16, CpuBackendContext* cpu_backend_context) {
2912 ruy::profiler::ScopeLabel label(
2913 "LstmCell/quantized (8bit external, 16bit internal)");
2914 int32 weights_zero_point = params.weights_zero_point;
2915 int32 accum_multiplier = params.accum_multiplier;
2916 int accum_shift = params.accum_shift;
2917 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2918 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2919 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2920 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2921 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2922 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2923 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2924 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
2925 const RuntimeShape input_shape =
2926 RuntimeShape::ExtendedShape(4, unextended_input_shape);
2927 const RuntimeShape prev_activ_shape =
2928 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
2929 const RuntimeShape bias_shape =
2930 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
2931 const RuntimeShape prev_state_shape =
2932 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
2933 const RuntimeShape output_state_shape =
2934 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
2935 const RuntimeShape output_activ_shape =
2936 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
2937 const RuntimeShape concat_temp_shape =
2938 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
2939 const RuntimeShape activ_temp_shape =
2940 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
2941 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2942
2943 // Gather dimensions information, and perform consistency checks.
2944 const int weights_dim_count = weights_shape.DimensionsCount();
2945 const int outer_size = MatchingFlatSizeSkipDim(
2946 input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
2947 output_activ_shape);
2948 const int input_depth = input_shape.Dims(3);
2949 const int prev_activ_depth = prev_activ_shape.Dims(3);
2950 const int total_input_depth = prev_activ_depth + input_depth;
2951 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
2952 total_input_depth);
2953 const int intern_activ_depth =
2954 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
2955 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
2956 intern_activ_depth * total_input_depth);
2957 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
2958 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
2959 const int output_depth =
2960 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
2961 3, output_activ_shape, 3);
2962 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
2963 const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
2964 const int fc_output_depth =
2965 MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
2966 const int fc_accum_depth = total_input_depth;
2967 TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
2968
2969 // Depth-concatenate prev_activ and input data together.
2970 uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
2971 prev_activ_data_uint8};
2972 const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
2973 &prev_activ_shape};
2974 tflite::ConcatenationParams concat_params;
2975 concat_params.axis = 3;
2976 concat_params.inputs_count = 2;
2977 Concatenation(concat_params, concat_input_arrays_shapes,
2978 concat_input_arrays_data, concat_temp_shape,
2979 concat_temp_data_uint8);
2980
2981 // Implementation of the fully connected node inside the LSTM cell.
2982 // The operands are 8-bit integers, the accumulators are internally 32-bit
2983 // integers, and the output is 16-bit fixed-point with 3 integer bits so
2984 // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
2985 // is explained in the function comment above.
2986 cpu_backend_gemm::MatrixParams<uint8> lhs_params;
2987 lhs_params.rows = fc_output_depth;
2988 lhs_params.cols = fc_accum_depth;
2989 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
2990 lhs_params.zero_point = weights_zero_point;
2991 cpu_backend_gemm::MatrixParams<uint8> rhs_params;
2992 rhs_params.rows = fc_accum_depth;
2993 rhs_params.cols = fc_batches;
2994 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
2995 rhs_params.zero_point = 128;
2996 cpu_backend_gemm::MatrixParams<int16> dst_params;
2997 dst_params.rows = fc_output_depth;
2998 dst_params.cols = fc_batches;
2999 dst_params.order = cpu_backend_gemm::Order::kColMajor;
3000 dst_params.zero_point = 0;
3001 cpu_backend_gemm::GemmParams<int32, int16> gemm_params;
3002 gemm_params.bias = bias_data_int32;
3003 gemm_params.multiplier_fixedpoint = accum_multiplier;
3004 gemm_params.multiplier_exponent = accum_shift;
3005 cpu_backend_gemm::Gemm(
3006 lhs_params, weights_data_uint8, rhs_params, concat_temp_data_uint8,
3007 dst_params, activ_temp_data_int16, gemm_params, cpu_backend_context);
3008
3009 // Rest of the LSTM cell: tanh and logistic math functions, and some adds
3010 // and muls, all done in 16-bit fixed-point.
3011 const int16* input_gate_input_ptr = activ_temp_data_int16;
3012 const int16* input_modulation_gate_input_ptr =
3013 activ_temp_data_int16 + output_depth;
3014 const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
3015 const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
3016 const int16* prev_state_ptr = prev_state_data_int16;
3017 int16* output_state_data_ptr = output_state_data_int16;
3018 uint8* output_activ_data_ptr = output_activ_data_uint8;
3019
3020 for (int b = 0; b < outer_size; ++b) {
3021 int c = 0;
3022 #ifdef GEMMLOWP_NEON
3023 for (; c <= output_depth - 8; c += 8) {
3024 // Define the fixed-point data types that we will use here. All use
3025 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3026 // They only differ by the number of integral vs. fractional bits,
3027 // determining the range of values that they can represent.
3028 //
3029 // F0 uses 0 integer bits, range [-1, 1].
3030 // This is the return type of math functions such as tanh, logistic,
3031 // whose range is in [-1, 1].
3032 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
3033 // F3 uses 3 integer bits, range [-8, 8].
3034 // This is the range of the previous fully-connected node's output,
3035 // which is our input here.
3036 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
3037 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3038 // 2^StateIntegerBits]. It's used to represent the internal state, whose
3039 // number of integer bits is currently dictated by the model. See comment
3040 // on the StateIntegerBits template parameter above.
3041 using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
3042 // Implementation of input gate, using fixed-point logistic function.
3043 F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
3044 input_gate_input_ptr += 8;
3045 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3046 // Implementation of input modulation gate, using fixed-point tanh
3047 // function.
3048 F3 input_modulation_gate_input =
3049 F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
3050 input_modulation_gate_input_ptr += 8;
3051 F0 input_modulation_gate_output =
3052 gemmlowp::tanh(input_modulation_gate_input);
3053 // Implementation of forget gate, using fixed-point logistic function.
3054 F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
3055 forget_gate_input_ptr += 8;
3056 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3057 // Implementation of output gate, using fixed-point logistic function.
3058 F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
3059 output_gate_input_ptr += 8;
3060 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3061 // Implementation of internal multiplication nodes, still in fixed-point.
3062 F0 input_times_input_modulation =
3063 input_gate_output * input_modulation_gate_output;
3064 FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
3065 prev_state_ptr += 8;
3066 FS prev_state_times_forget_state = forget_gate_output * prev_state;
3067 // Implementation of internal addition node, saturating.
3068 FS new_state = gemmlowp::SaturatingAdd(
3069 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3070 prev_state_times_forget_state);
3071 // Implementation of last internal Tanh node, still in fixed-point.
3072 // Since a Tanh fixed-point implementation is specialized for a given
3073 // number of integer bits, and each specialization can have a substantial
3074 // code size, and we already used above a Tanh on an input with 3 integer
3075 // bits, and per the table in the above function comment there is no
3076 // significant accuracy to be lost by clamping to [-8, +8] for a
3077 // 3-integer-bits representation, let us just do that. This helps people
3078 // porting this to targets where code footprint must be minimized.
3079 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3080 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3081 // Store the new internal state back to memory, as 16-bit integers.
3082 // Note: here we store the original value with StateIntegerBits, not
3083 // the rescaled 3-integer-bits value fed to tanh.
3084 vst1q_s16(output_state_data_ptr, new_state.raw());
3085 output_state_data_ptr += 8;
3086 // Down-scale the output activations to 8-bit integers, saturating,
3087 // and store back to memory.
3088 int16x8_t rescaled_output_activ =
3089 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3090 int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
3091 uint8x8_t uint8_output_activ =
3092 vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
3093 vst1_u8(output_activ_data_ptr, uint8_output_activ);
3094 output_activ_data_ptr += 8;
3095 }
3096 #endif
3097 for (; c < output_depth; ++c) {
3098 // Define the fixed-point data types that we will use here. All use
3099 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3100 // They only differ by the number of integral vs. fractional bits,
3101 // determining the range of values that they can represent.
3102 //
3103 // F0 uses 0 integer bits, range [-1, 1].
3104 // This is the return type of math functions such as tanh, logistic,
3105 // whose range is in [-1, 1].
3106 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3107 // F3 uses 3 integer bits, range [-8, 8].
3108 // This is the range of the previous fully-connected node's output,
3109 // which is our input here.
3110 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
3111 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3112 // 2^StateIntegerBits]. It's used to represent the internal state, whose
3113 // number of integer bits is currently dictated by the model. See comment
3114 // on the StateIntegerBits template parameter above.
3115 using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
3116 // Implementation of input gate, using fixed-point logistic function.
3117 F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
3118 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3119 // Implementation of input modulation gate, using fixed-point tanh
3120 // function.
3121 F3 input_modulation_gate_input =
3122 F3::FromRaw(*input_modulation_gate_input_ptr++);
3123 F0 input_modulation_gate_output =
3124 gemmlowp::tanh(input_modulation_gate_input);
3125 // Implementation of forget gate, using fixed-point logistic function.
3126 F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
3127 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3128 // Implementation of output gate, using fixed-point logistic function.
3129 F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
3130 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3131 // Implementation of internal multiplication nodes, still in fixed-point.
3132 F0 input_times_input_modulation =
3133 input_gate_output * input_modulation_gate_output;
3134 FS prev_state = FS::FromRaw(*prev_state_ptr++);
3135 FS prev_state_times_forget_state = forget_gate_output * prev_state;
3136 // Implementation of internal addition node, saturating.
3137 FS new_state = gemmlowp::SaturatingAdd(
3138 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3139 prev_state_times_forget_state);
3140 // Implementation of last internal Tanh node, still in fixed-point.
3141 // Since a Tanh fixed-point implementation is specialized for a given
3142 // number of integer bits, and each specialization can have a substantial
3143 // code size, and we already used above a Tanh on an input with 3 integer
3144 // bits, and per the table in the above function comment there is no
3145 // significant accuracy to be lost by clamping to [-8, +8] for a
3146 // 3-integer-bits representation, let us just do that. This helps people
3147 // porting this to targets where code footprint must be minimized.
3148 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3149 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3150 // Store the new internal state back to memory, as 16-bit integers.
3151 // Note: here we store the original value with StateIntegerBits, not
3152 // the rescaled 3-integer-bits value fed to tanh.
3153 *output_state_data_ptr++ = new_state.raw();
3154 // Down-scale the output activations to 8-bit integers, saturating,
3155 // and store back to memory.
3156 int16 rescaled_output_activ =
3157 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3158 int16 clamped_output_activ =
3159 std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
3160 *output_activ_data_ptr++ = 128 + clamped_output_activ;
3161 }
3162 input_gate_input_ptr += 3 * output_depth;
3163 input_modulation_gate_input_ptr += 3 * output_depth;
3164 forget_gate_input_ptr += 3 * output_depth;
3165 output_gate_input_ptr += 3 * output_depth;
3166 }
3167 }
3168
3169 inline int NodeOffset(int b, int h, int w, int height, int width) {
3170 return (b * height + h) * width + w;
3171 }
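
// For example, NodeOffset(/*b=*/1, /*h=*/2, /*w=*/3, /*height=*/4,
// /*width=*/5) == (1 * 4 + 2) * 5 + 3 == 33: the column index of that spatial
// position in the (depth x batches*height*width) matrices that the pooling ops
// below build with MapAsMatrixWithLastDimAsRows.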
3172
3173 inline void AveragePool(const PoolParams& params,
3174 const RuntimeShape& input_shape,
3175 const float* input_data,
3176 const RuntimeShape& output_shape, float* output_data) {
3177 ruy::profiler::ScopeLabel label("AveragePool");
3178 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3179 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3180 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3181 const int input_height = input_shape.Dims(1);
3182 const int input_width = input_shape.Dims(2);
3183 const int output_height = output_shape.Dims(1);
3184 const int output_width = output_shape.Dims(2);
3185 const int stride_height = params.stride_height;
3186 const int stride_width = params.stride_width;
3187
3188 // TODO(benoitjacob) make this a proper reference impl without Eigen!
3189 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3190 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3191 // TODO(benoitjacob) get rid of the dynamic memory allocation here!
3192 Eigen::VectorXf out_count(out_mat.cols());
3193 out_count.setZero();
3194 // Prefill the output to 0.
3195 out_mat.setZero();
3196 for (int b = 0; b < batches; ++b) {
3197 for (int h = 0; h < input_height; ++h) {
3198 for (int w = 0; w < input_width; ++w) {
3199 // (h_start, h_end) * (w_start, w_end) is the range that the input
3200 // vector projects to.
3201 int hpad = h + params.padding_values.height;
3202 int wpad = w + params.padding_values.width;
3203 int h_start = (hpad < params.filter_height)
3204 ? 0
3205 : (hpad - params.filter_height) / stride_height + 1;
3206 int h_end = std::min(hpad / stride_height + 1, output_height);
3207 int w_start = (wpad < params.filter_width)
3208 ? 0
3209 : (wpad - params.filter_width) / stride_width + 1;
3210 int w_end = std::min(wpad / stride_width + 1, output_width);
3211 // compute elementwise sum
3212 for (int ph = h_start; ph < h_end; ++ph) {
3213 for (int pw = w_start; pw < w_end; ++pw) {
3214 int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
3215 out_mat.col(out_offset) +=
3216 in_mat.col(NodeOffset(b, h, w, input_height, input_width));
3217 out_count(out_offset)++;
3218 }
3219 }
3220 }
3221 }
3222 }
3223 // Divide the output by the actual number of elements being averaged over
3224 TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
3225 out_mat.array().rowwise() /= out_count.transpose().array();
3226
3227 const int flat_size = output_shape.FlatSize();
3228 for (int i = 0; i < flat_size; ++i) {
3229 output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3230 params.float_activation_min,
3231 params.float_activation_max);
3232 }
3233 }
3234
3235 inline void AveragePool(const PoolParams& params,
3236 const RuntimeShape& input_shape,
3237 const uint8* input_data,
3238 const RuntimeShape& output_shape, uint8* output_data) {
3239 ruy::profiler::ScopeLabel label("AveragePool/8bit");
3240
3241 // Here, and in other pooling ops, in order to maintain locality of reference,
3242 // to minimize some recalculations, and to load into NEON vector registers, we
3243 // use an inner loop down the depth. Since depths can be large, and holding a
3244 // full-depth accumulator would need arbitrarily large temporary storage, we
3245 // divide the work up into depth tranches just within the batch loop.
3246 static constexpr int kPoolingAccTrancheSize = 256;
3247
3248 TFLITE_DCHECK_LE(params.quantized_activation_min,
3249 params.quantized_activation_max);
3250 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3251 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3252 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3253 const int depth = MatchingDim(input_shape, 3, output_shape, 3);
3254 const int input_height = input_shape.Dims(1);
3255 const int input_width = input_shape.Dims(2);
3256 const int output_height = output_shape.Dims(1);
3257 const int output_width = output_shape.Dims(2);
3258 const int stride_height = params.stride_height;
3259 const int stride_width = params.stride_width;
3260
3261 uint32 acc[kPoolingAccTrancheSize];
3262 for (int batch = 0; batch < batches; ++batch) {
3263 // We proceed through the depth in tranches (see comment above). The
3264 // depth_base is the depth at the beginning of the tranche. The
3265 // tranche_depth is the depth dimension of the tranche.
3266 for (int depth_base = 0; depth_base < depth;
3267 depth_base += kPoolingAccTrancheSize) {
3268 const int tranche_depth =
3269 std::min(depth - depth_base, kPoolingAccTrancheSize);
3270 for (int out_y = 0; out_y < output_height; ++out_y) {
3271 for (int out_x = 0; out_x < output_width; ++out_x) {
3272 const int in_x_origin =
3273 (out_x * stride_width) - params.padding_values.width;
3274 const int in_y_origin =
3275 (out_y * stride_height) - params.padding_values.height;
3276 const int filter_x_start = std::max(0, -in_x_origin);
3277 const int filter_x_end =
3278 std::min(params.filter_width, input_width - in_x_origin);
3279 const int filter_y_start = std::max(0, -in_y_origin);
3280 const int filter_y_end =
3281 std::min(params.filter_height, input_height - in_y_origin);
3282 const int filter_count =
3283 (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
3284 memset(acc, 0, tranche_depth * sizeof(acc[0]));
3285 const uint8* input_ptr =
3286 input_data + depth_base +
3287 depth * (in_x_origin +
3288 input_width * (in_y_origin + input_height * batch));
3289 for (int fy = filter_y_start; fy < filter_y_end; fy++) {
3290 const uint8* input_row_ptr =
3291 input_ptr + depth * (fy * input_width + filter_x_start);
3292 for (int fx = filter_x_start; fx < filter_x_end; fx++) {
3293 const uint8* input_channel_ptr = input_row_ptr;
3294 int channel = 0;
3295 #ifdef USE_NEON
3296 for (; channel <= tranche_depth - 16; channel += 16) {
3297 uint16x4_t acc_reg[4];
3298 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
3299 input_channel_ptr += 16;
3300 acc_reg[0] = vget_low_u16(vmovl_u8(vget_low_u8(input_reg)));
3301 acc_reg[1] = vget_high_u16(vmovl_u8(vget_low_u8(input_reg)));
3302 acc_reg[2] = vget_low_u16(vmovl_u8(vget_high_u8(input_reg)));
3303 acc_reg[3] = vget_high_u16(vmovl_u8(vget_high_u8(input_reg)));
3304 for (int i = 0; i < 4; i++) {
3305 vst1q_u32(
3306 acc + channel + 4 * i,
3307 vaddw_u16(vld1q_u32(acc + channel + 4 * i), acc_reg[i]));
3308 }
3309 }
3310 for (; channel <= tranche_depth - 8; channel += 8) {
3311 uint16x4_t acc_reg[2];
3312 uint16x8_t input_reg = vmovl_u8(vld1_u8(input_channel_ptr));
3313 input_channel_ptr += 8;
3314 acc_reg[0] = vget_low_u16(input_reg);
3315 acc_reg[1] = vget_high_u16(input_reg);
3316 for (int i = 0; i < 2; i++) {
3317 vst1q_u32(
3318 acc + channel + 4 * i,
3319 vaddw_u16(vld1q_u32(acc + channel + 4 * i), acc_reg[i]));
3320 }
3321 }
3322 #endif
3323 for (; channel < tranche_depth; ++channel) {
3324 acc[channel] += *input_channel_ptr++;
3325 }
3326 input_row_ptr += depth;
3327 }
3328 }
3329 uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
3330 out_x, depth_base);
3331 int channel = 0;
3332 #ifdef USE_NEON
3333 #define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
3334 if (filter_count == FILTER_COUNT) { \
3335 for (; channel <= tranche_depth - 8; channel += 8) { \
3336 uint16 buf[8]; \
3337 for (int i = 0; i < 8; i++) { \
3338 buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
3339 } \
3340 uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
3341 buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \
3342 buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \
3343 vst1_u8(output_ptr + channel, buf8); \
3344 } \
3345 }
3346 AVGPOOL_DIVIDING_BY(9)
3347 AVGPOOL_DIVIDING_BY(15)
3348 #undef AVGPOOL_DIVIDING_BY
3349 for (; channel <= tranche_depth - 8; channel += 8) {
3350 uint16 buf[8];
3351 for (int i = 0; i < 8; i++) {
3352 buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
3353 }
3354 uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
3355 buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max));
3356 buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min));
3357 vst1_u8(output_ptr + channel, buf8);
3358 }
3359 #endif
3360 for (; channel < tranche_depth; ++channel) {
3361 uint16 a = (acc[channel] + filter_count / 2) / filter_count;
3362 a = std::max<uint16>(a, params.quantized_activation_min);
3363 a = std::min<uint16>(a, params.quantized_activation_max);
3364 output_ptr[channel] = static_cast<uint8>(a);
3365 }
3366 }
3367 }
3368 }
3369 }
3370 }
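
// Worked example for the requantization above (illustrative): with a full 3x3
// window (filter_count == 9) and an accumulated sum of 25, the output value is
// (25 + 9 / 2) / 9 == 29 / 9 == 3, i.e. rounding to nearest rather than
// truncating, before the activation clamp is applied.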
3371
3372 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
3373 const float* input_data, const RuntimeShape& output_shape,
3374 float* output_data) {
3375 ruy::profiler::ScopeLabel label("MaxPool");
3376 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3377 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3378 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3379 const int input_height = input_shape.Dims(1);
3380 const int input_width = input_shape.Dims(2);
3381 const int output_height = output_shape.Dims(1);
3382 const int output_width = output_shape.Dims(2);
3383 const int stride_height = params.stride_height;
3384 const int stride_width = params.stride_width;
3385
3386 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3387 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3388 // Prefill the output to minimum representable float value
3389 out_mat.setConstant(std::numeric_limits<float>::lowest());
3390 for (int b = 0; b < batches; ++b) {
3391 for (int h = 0; h < input_height; ++h) {
3392 for (int w = 0; w < input_width; ++w) {
3393 // (h_start, h_end) * (w_start, w_end) is the range that the input
3394 // vector projects to.
3395 int hpad = h + params.padding_values.height;
3396 int wpad = w + params.padding_values.width;
3397 int h_start = (hpad < params.filter_height)
3398 ? 0
3399 : (hpad - params.filter_height) / stride_height + 1;
3400 int h_end = std::min(hpad / stride_height + 1, output_height);
3401 int w_start = (wpad < params.filter_width)
3402 ? 0
3403 : (wpad - params.filter_width) / stride_width + 1;
3404 int w_end = std::min(wpad / stride_width + 1, output_width);
3405 // compute elementwise sum
3406 for (int ph = h_start; ph < h_end; ++ph) {
3407 for (int pw = w_start; pw < w_end; ++pw) {
3408 int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
3409 out_mat.col(out_offset) =
3410 out_mat.col(out_offset)
3411 .cwiseMax(in_mat.col(
3412 NodeOffset(b, h, w, input_height, input_width)));
3413 }
3414 }
3415 }
3416 }
3417 }
3418 const int flat_size = output_shape.FlatSize();
3419 for (int i = 0; i < flat_size; ++i) {
3420 output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3421 params.float_activation_min,
3422 params.float_activation_max);
3423 }
3424 }
3425
3426 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
3427 const uint8* input_data, const RuntimeShape& output_shape,
3428 uint8* output_data) {
3429 ruy::profiler::ScopeLabel label("MaxPool/8bit");
3430
3431 // Here, and in other pooling ops, in order to maintain locality of reference,
3432 // to minimize some recalculations, and to load into NEON vector registers, we
3433 // use an inner loop down the depth. Since depths can be large, and holding a
3434 // full-depth accumulator would need arbitrarily large temporary storage, we
3435 // divide the work up into depth tranches just within the batch loop.
3436 static constexpr int kPoolingAccTrancheSize = 256;
3437
3438 TFLITE_DCHECK_LE(params.quantized_activation_min,
3439 params.quantized_activation_max);
3440 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3441 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3442 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3443 const int depth = MatchingDim(input_shape, 3, output_shape, 3);
3444 const int input_height = input_shape.Dims(1);
3445 const int input_width = input_shape.Dims(2);
3446 const int output_height = output_shape.Dims(1);
3447 const int output_width = output_shape.Dims(2);
3448 const int stride_height = params.stride_height;
3449 const int stride_width = params.stride_width;
3450
3451 uint8 acc[kPoolingAccTrancheSize];
3452 for (int batch = 0; batch < batches; ++batch) {
3453 // We proceed through the depth in tranches (see comment above). The
3454 // depth_base is the depth at the beginning of the tranche. The
3455 // tranche_depth is the depth dimension of the tranche.
3456 for (int depth_base = 0; depth_base < depth;
3457 depth_base += kPoolingAccTrancheSize) {
3458 const int tranche_depth =
3459 std::min(depth - depth_base, kPoolingAccTrancheSize);
3460 for (int out_y = 0; out_y < output_height; ++out_y) {
3461 for (int out_x = 0; out_x < output_width; ++out_x) {
3462 const int in_x_origin =
3463 (out_x * stride_width) - params.padding_values.width;
3464 const int in_y_origin =
3465 (out_y * stride_height) - params.padding_values.height;
3466 const int filter_x_start = std::max(0, -in_x_origin);
3467 const int filter_x_end =
3468 std::min(params.filter_width, input_width - in_x_origin);
3469 const int filter_y_start = std::max(0, -in_y_origin);
3470 const int filter_y_end =
3471 std::min(params.filter_height, input_height - in_y_origin);
3472 memset(acc, 0, tranche_depth * sizeof(acc[0]));
3473 const uint8* input_ptr =
3474 input_data + depth_base +
3475 depth * (in_x_origin +
3476 input_width * (in_y_origin + input_height * batch));
3477 for (int fy = filter_y_start; fy < filter_y_end; fy++) {
3478 const uint8* input_row_ptr =
3479 input_ptr + depth * (fy * input_width + filter_x_start);
3480 for (int fx = filter_x_start; fx < filter_x_end; fx++) {
3481 const uint8* input_channel_ptr = input_row_ptr;
3482 int channel = 0;
3483 #ifdef USE_NEON
3484 for (; channel <= tranche_depth - 16; channel += 16) {
3485 uint8x16_t acc_reg = vld1q_u8(acc + channel);
3486 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
3487 input_channel_ptr += 16;
3488 acc_reg = vmaxq_u8(acc_reg, input_reg);
3489 vst1q_u8(acc + channel, acc_reg);
3490 }
3491
3492 for (; channel <= tranche_depth - 8; channel += 8) {
3493 uint8x8_t acc_reg = vld1_u8(acc + channel);
3494 uint8x8_t input_reg = vld1_u8(input_channel_ptr);
3495 input_channel_ptr += 8;
3496 acc_reg = vmax_u8(acc_reg, input_reg);
3497 vst1_u8(acc + channel, acc_reg);
3498 }
3499 #endif
3500 for (; channel < tranche_depth; ++channel) {
3501 acc[channel] = std::max(acc[channel], *input_channel_ptr++);
3502 }
3503 input_row_ptr += depth;
3504 }
3505 }
3506 uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
3507 out_x, depth_base);
3508 int channel = 0;
3509 #ifdef USE_NEON
3510 for (; channel <= tranche_depth - 16; channel += 16) {
3511 uint8x16_t a = vld1q_u8(acc + channel);
3512 a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max));
3513 a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min));
3514 vst1q_u8(output_ptr + channel, a);
3515 }
3516 for (; channel <= tranche_depth - 8; channel += 8) {
3517 uint8x8_t a = vld1_u8(acc + channel);
3518 a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max));
3519 a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min));
3520 vst1_u8(output_ptr + channel, a);
3521 }
3522 #endif
3523 for (; channel < tranche_depth; ++channel) {
3524 uint8 a = acc[channel];
3525 a = std::max<uint8>(a, params.quantized_activation_min);
3526 a = std::min<uint8>(a, params.quantized_activation_max);
3527 output_ptr[channel] = static_cast<uint8>(a);
3528 }
3529 }
3530 }
3531 }
3532 }
3533 }
3534
3535 inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
3536 const float* input_data, const RuntimeShape& output_shape,
3537 float* output_data) {
3538 ruy::profiler::ScopeLabel label("L2Pool");
3539 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3540 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3541 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3542 const int input_height = input_shape.Dims(1);
3543 const int input_width = input_shape.Dims(2);
3544 const int output_height = output_shape.Dims(1);
3545 const int output_width = output_shape.Dims(2);
3546 const int stride_height = params.stride_height;
3547 const int stride_width = params.stride_width;
3548 // Actually carry out L2 Pool. The code is written in forward mode: we go
3549 // through the input values once and write to all the pooled regions each value maps to.
3550 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3551 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3552 Eigen::VectorXf in_square(in_mat.rows());
3553 Eigen::VectorXf out_count(out_mat.cols());
3554 out_count.setZero();
3555 // Prefill the output to 0.
3556 out_mat.setZero();
3557 for (int b = 0; b < batches; ++b) {
3558 for (int h = 0; h < input_height; ++h) {
3559 for (int w = 0; w < input_width; ++w) {
3560 // (h_start, h_end) * (w_start, w_end) is the range that the input
3561 // vector projects to.
3562 const int hpad = h + params.padding_values.height;
3563 const int wpad = w + params.padding_values.width;
3564 const int h_start =
3565 (hpad < params.filter_height)
3566 ? 0
3567 : (hpad - params.filter_height) / stride_height + 1;
3568 const int h_end = std::min(hpad / stride_height + 1, output_height);
3569 const int w_start =
3570 (wpad < params.filter_width)
3571 ? 0
3572 : (wpad - params.filter_width) / stride_width + 1;
3573 const int w_end = std::min(wpad / stride_width + 1, output_width);
3574 // pre-compute square
3575 const int in_offset = w + input_width * (h + input_height * b);
3576 in_square =
3577 in_mat.col(in_offset).array() * in_mat.col(in_offset).array();
3578 // compute elementwise sum of squares
3579 for (int ph = h_start; ph < h_end; ++ph) {
3580 for (int pw = w_start; pw < w_end; ++pw) {
3581 const int out_offset = pw + output_width * (ph + output_height * b);
3582 out_mat.col(out_offset) += in_square;
3583 out_count(out_offset)++;
3584 }
3585 }
3586 }
3587 }
3588 }
3589
3590 out_count = out_count.array().inverse();
3591 out_mat =
3592 (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
3593
3594 const int flat_size = output_shape.FlatSize();
3595 for (int i = 0; i < flat_size; ++i) {
3596 output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3597 params.float_activation_min,
3598 params.float_activation_max);
3599 }
3600 }
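
// Editorial illustration, not part of the upstream TFLite API: a minimal
// scalar sketch of what L2Pool computes for a single output element, i.e. the
// square root of the mean of squares over one pooling window, before the
// activation clamp is applied. The helper name below is hypothetical.
inline float L2PoolWindowSketch(const float* window_values, int count) {
  float sum_squares = 0.0f;
  for (int i = 0; i < count; ++i) {
    sum_squares += window_values[i] * window_values[i];
  }
  return std::sqrt(sum_squares / count);
}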
3601
3602 inline void LocalResponseNormalization(
3603 const tflite::LocalResponseNormalizationParams& op_params,
3604 const RuntimeShape& input_shape, const float* input_data,
3605 const RuntimeShape& output_shape, float* output_data) {
3606 ruy::profiler::ScopeLabel label("LocalResponseNormalization");
3607 MatchingFlatSize(input_shape, output_shape);
3608
3609 const auto data_in = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3610 auto data_out = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3611
3612 // Carry out local response normalization, vector by vector.
3613 // Since the data are stored column-major, a row-wise operation would
3614 // probably not be memory efficient anyway, so we do an explicit for loop
3615 // over the columns.
3616 const int double_range = op_params.range * 2;
3617 Eigen::VectorXf padded_square(data_in.rows() + double_range);
3618 padded_square.setZero();
3619 const float bias = op_params.bias;
3620 for (int r = 0; r < data_in.cols(); ++r) {
3621 // Do local response normalization for data_in(:, r).
3622 // First, compute the squares and store them in a buffer for repeated use.
3623 padded_square.block(op_params.range, 0, data_in.rows(), 1) =
3624 data_in.col(r).cwiseProduct(data_in.col(r)) * op_params.alpha;
3625 // Then compute the scale and write it to data_out.
3626 float accumulated_scale = 0;
3627 for (int i = 0; i < double_range; ++i) {
3628 accumulated_scale += padded_square(i);
3629 }
3630 for (int i = 0; i < data_in.rows(); ++i) {
3631 accumulated_scale += padded_square(i + double_range);
3632 data_out(i, r) = bias + accumulated_scale;
3633 accumulated_scale -= padded_square(i);
3634 }
3635 }
3636
3637 // In a few cases, the pow computation could benefit from speedups.
3638 if (op_params.beta == 1) {
3639 data_out.array() = data_in.array() * data_out.array().inverse();
3640 } else if (op_params.beta == 0.5f) {
3641 data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
3642 } else {
3643 data_out.array() = data_in.array() * data_out.array().pow(-op_params.beta);
3644 }
3645 }
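
// Editorial illustration (hypothetical helper, not upstream code): the
// per-element formula that the vectorized LocalResponseNormalization above
// implements,
//   out[i] = in[i] / (bias + alpha * sum_{j = i-range .. i+range} in[j]^2)^beta,
// where neighbours that fall outside the row are treated as zero.
inline float LocalResponseNormalizationElementSketch(const float* input,
                                                     int size, int i, int range,
                                                     float bias, float alpha,
                                                     float beta) {
  float accumulated = 0.0f;
  for (int j = std::max(0, i - range); j <= std::min(size - 1, i + range);
       ++j) {
    accumulated += input[j] * input[j];
  }
  return input[i] * std::pow(bias + alpha * accumulated, -beta);
}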
3646
3647 inline void SoftmaxImpl(const SoftmaxParams& params,
3648 const RuntimeShape& input_shape,
3649 const float* input_data,
3650 const RuntimeShape& output_shape, float* output_data,
3651 int start_batch, int end_batch) {
3652 ruy::profiler::ScopeLabel label("Softmax/Impl");
3653 MatchingFlatSize(input_shape, output_shape);
3654
3655 const int logit_size = input_shape.Dims(input_shape.DimensionsCount() - 1);
3656 const MatrixMap<const float> in_mat(input_data + logit_size * start_batch,
3657 logit_size, end_batch - start_batch);
3658 MatrixMap<float> out_mat(output_data + logit_size * start_batch, logit_size,
3659 end_batch - start_batch);
3660 // Compute the exponential first, removing the max coefficient for numerical
3661 // stability.
3662 out_mat =
3663 (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
3664 // We are separating out the exp function so that exp can be vectorized.
3665 out_mat = out_mat.array().exp();
3666 // Normalize to get the activations.
3667 Eigen::Array<float, 1, Eigen::Dynamic> scale =
3668 out_mat.array().colwise().sum().inverse();
3669 out_mat.array().rowwise() *= scale;
3670 }
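
// Editorial illustration (hypothetical helper): a scalar sketch of the three
// steps SoftmaxImpl performs with Eigen above -- subtract the per-column max
// for numerical stability, exponentiate (scaled by beta), then normalize so
// the outputs sum to one.
inline void ScalarSoftmaxSketch(const float* logits, int size, float beta,
                                float* probs) {
  float max_logit = logits[0];
  for (int i = 1; i < size; ++i) max_logit = std::max(max_logit, logits[i]);
  float sum = 0.0f;
  for (int i = 0; i < size; ++i) {
    probs[i] = std::exp(beta * (logits[i] - max_logit));
    sum += probs[i];
  }
  const float inv_sum = 1.0f / sum;
  for (int i = 0; i < size; ++i) probs[i] *= inv_sum;
}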
3671
3672 struct SoftmaxWorkerTask : cpu_backend_threadpool::Task {
3673 SoftmaxWorkerTask(const SoftmaxParams& params,
3674 const RuntimeShape& input_shape, const float* input_data,
3675 const RuntimeShape& output_shape, float* output_data,
3676 int start_batch, int end_batch)
3677 : params(params),
3678 input_shape(input_shape),
3679 input_data(input_data),
3680 output_shape(output_shape),
3681 output_data(output_data),
3682 start_batch(start_batch),
3683 end_batch(end_batch) {}
3684 void Run() override {
3685 SoftmaxImpl(params, input_shape, input_data, output_shape, output_data,
3686 start_batch, end_batch);
3687 }
3688
3689 private:
3690 const tflite::SoftmaxParams& params;
3691 const RuntimeShape& input_shape;
3692 const float* input_data;
3693 const RuntimeShape& output_shape;
3694 float* output_data;
3695 int start_batch;
3696 int end_batch;
3697 };
3698
3699 inline void Softmax(const SoftmaxParams& params,
3700 const RuntimeShape& input_shape, const float* input_data,
3701 const RuntimeShape& output_shape, float* output_data,
3702 CpuBackendContext* cpu_backend_context = nullptr) {
3703 ruy::profiler::ScopeLabel label("Softmax");
3704
3705 // We picture the softmax input as a 2-D matrix whose last dim is the logit
3706 // dim; all remaining dims are folded into the batch dim of the 2-D matrix.
3707 const int batch_size =
3708 FlatSizeSkipDim(input_shape, input_shape.DimensionsCount() - 1);
3709 constexpr int kMinBatchPerThread = 8;
3710 int thread_count = batch_size / kMinBatchPerThread;
3711 thread_count = thread_count > 0 ? thread_count : 1;
3712 const int capped_thread_count =
3713 cpu_backend_context == nullptr
3714 ? 1
3715 : std::min(thread_count, cpu_backend_context->max_num_threads());
3716 if (capped_thread_count == 1) {
3717 SoftmaxImpl(params, input_shape, input_data, output_shape, output_data, 0,
3718 batch_size);
3719 } else {
3720 std::vector<SoftmaxWorkerTask> tasks;
3721 // TODO(b/131746020) don't create new heap allocations every time.
3722 // At least we make it a single heap allocation by using reserve().
3723 tasks.reserve(capped_thread_count);
3724 int batch_start = 0;
3725 for (int i = 0; i < capped_thread_count; ++i) {
3726 // Try to distribute the tasks as evenly as possible.
3727 int batch_end =
3728 batch_start + (batch_size - batch_start) / (capped_thread_count - i);
3729 tasks.emplace_back(params, input_shape, input_data, output_shape,
3730 output_data, batch_start, batch_end);
3731 batch_start = batch_end;
3732 }
3733 cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
3734 cpu_backend_context);
3735 }
3736 }
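
// Editorial illustration (hypothetical helper): the batch-splitting rule used
// above. At each step the remaining batches are divided by the number of
// remaining workers, so per-task sizes differ by at most one; e.g. 10 batches
// over 4 workers produces the ranges [0,2), [2,4), [4,7), [7,10).
inline void SoftmaxBatchSplitSketch(int batch_size, int thread_count,
                                    int* batch_starts, int* batch_ends) {
  int batch_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    batch_starts[i] = batch_start;
    batch_start += (batch_size - batch_start) / (thread_count - i);
    batch_ends[i] = batch_start;
  }
}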
3737
3738 template <typename T>
3739 inline int32_t QuantizeSoftmaxOutput(float prob_rescaled, int32_t zero_point) {
3740 const int32_t prob_rnd = static_cast<int32_t>(std::round(prob_rescaled));
3741 return prob_rnd + zero_point;
3742 }
3743
3744 #if !__aarch64__
3745 // On ARM64, rounding is faster than add + truncation, so this specialization is skipped there.
3746 template <>
3747 inline int32_t QuantizeSoftmaxOutput<uint8_t>(float prob_rescaled,
3748 int32_t zero_point) {
3749 return static_cast<int32_t>(prob_rescaled + 0.5f);
3750 }
3751 #endif
3752
3753 inline void PopulateSoftmaxLookupTable(SoftmaxParams* data, float input_scale,
3754 float beta) {
3755 const float scale = -input_scale * beta;
3756 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3757 for (int32_t val = 0; val <= max_uint8; ++val) {
3758 data->table[max_uint8 - val] = expf(scale * val);
3759 }
3760 }
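
// Editorial illustration (hypothetical helper): how the reversed table filled
// above is read by the quantized Softmax below. Because
// table[255 - v] == exp(-input_scale * beta * v), indexing the table at
// (255 - max_val) + input_q yields exp(input_scale * beta * (input_q - max_val)),
// i.e. exp applied to the max-subtracted logit.
// (Assumes 0 <= input_q <= max_val <= 255, as in the uint8 path.)
inline float SoftmaxLookupSketch(const SoftmaxParams& params, int32_t input_q,
                                 int32_t max_val) {
  const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
  return params.table[max_uint8 - max_val + input_q];
}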
3761
3762 template <typename In, typename Out>
3763 inline void Softmax(const SoftmaxParams& params,
3764 const RuntimeShape& input_shape, const In* input_data,
3765 const RuntimeShape& output_shape, Out* output_data) {
3766 const int trailing_dim = input_shape.DimensionsCount() - 1;
3767 const int excluding_last_dim =
3768 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3769 const int last_dim =
3770 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3771
3772 const int32_t clamp_max = std::numeric_limits<Out>::max();
3773 const int32_t clamp_min = std::numeric_limits<Out>::min();
3774 for (int i = 0; i < excluding_last_dim; ++i) {
3775 int32_t max_val = std::numeric_limits<In>::min();
3776 // Find max quantized value.
3777 for (int j = 0; j < last_dim; ++j) {
3778 max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
3779 }
3780
3781 float sum_exp = 0.0f;
3782 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3783 const float* table_offset = &params.table[max_uint8 - max_val];
3784 // Calculate normalizer sum(exp(x)).
3785 for (int j = 0; j < last_dim; ++j) {
3786 sum_exp += table_offset[input_data[j]];
3787 }
3788
3789 const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
3790 // Normalize and quantize probabilities.
3791 for (int j = 0; j < last_dim; ++j) {
3792 const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
3793 const int32_t prob_quantized =
3794 QuantizeSoftmaxOutput<Out>(prob_rescaled, params.zero_point);
3795 output_data[j] = static_cast<Out>(
3796 std::max(std::min(clamp_max, prob_quantized), clamp_min));
3797 }
3798 input_data += last_dim;
3799 output_data += last_dim;
3800 }
3801 }
3802
3803 // Here's the softmax LUT optimization strategy:
3804 // For softmax, we can apply a mathematically equivalent transformation:
3805 //
3806 // softmax(x) = e^x / sum(e^x, 0...n), which equals
3807 // softmax(x) = e^(x - CONST) / sum(e^(x - CONST), 0...n)
3808 //
3809 // For quantization, `x` in our case is (input_q - input_zp) * input_s
3810 // For uint8 case (int8 can be handled similarly), the range is [0, 255]
3811 //
3812 // so if we let
3813 // CONST = (255 - input_zp) * input_s
3814 // then we will have:
3815 // softmax(x) = e^((input_q - 255) * input_s) --------- (1)
3816 // /
3817 // sum(e^((input_q - 255) * input_s), 0...n) -------- (2)
3818 //
3819 // the good thing about (1) is it's within the range of (0, 1), so we can
3820 // approximate its result with uint16.
3821 // (1) = uint8_out * 1 / 2^16.
3822 //
3823 // so (1) is lookup_uint8_table(input_q) * 1 / 2^16.
3824 // then (2) is essentially the following:
3825 // sum(lookup_uint8_table(input_q), 0...n) / 2^16.
3826 //
3827 // since (output_q - output_zp) * output_s = softmax(x)
3828 // output_q = lookup_uint8_table(input_q)
3829 // /
3830 // (sum(lookup_uint8_table(input_q), 0...n) * output_s)
3831 // +
3832 // output_zp
3833 //
3834 // We could actually further improve the performance by using uint8 instead
3835 // of uint16, but that may lose some accuracy, so we need to pay attention
3836 // to that.
3837 inline void PopulateSoftmaxUInt8LookupTable(SoftmaxParams* data,
3838 float input_scale, float beta) {
3839 const float scale = input_scale * beta;
3840 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3841 const int32_t max_uint16 = std::numeric_limits<uint16_t>::max();
3842
3843 for (int32_t val = 0; val <= max_uint8; ++val) {
3844 float input_to_exp = scale * (val - max_uint8);
3845 int32_t temp = static_cast<int>(expf(input_to_exp) * max_uint16 + 0.5);
3846 temp = std::min(max_uint16, temp);
3847 uint8_t part1 = temp >> 8;
3848 uint8_t part2 = temp & 0xff;
3849 data->uint8_table1[val] = static_cast<uint8_t>(part1);
3850 data->uint8_table2[val] = static_cast<uint8_t>(part2);
3851 }
3852 }
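
// Editorial illustration (hypothetical helper): the two uint8 tables filled
// above jointly encode a 16-bit fixed-point value,
// round(exp(input_scale * beta * (val - 255)) * 65535), split into a high and
// a low byte. SoftmaxInt8LUT below recombines and sums exactly this quantity.
inline int32_t CombinedUint8TableValueSketch(const SoftmaxParams& params,
                                             uint8_t val) {
  return (static_cast<int32_t>(params.uint8_table1[val]) << 8) +
         static_cast<int32_t>(params.uint8_table2[val]);
}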
3853
3854 inline int FindMaxValue(int size, const uint8_t* input_data, uint8_t offset) {
3855 int32_t max_val = std::numeric_limits<uint8_t>::min();
3856 int j = 0;
3857 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3858 uint8x16_t max_val_dup = vdupq_n_u8(max_val);
3859 uint8x16_t offset_dup = vdupq_n_u8(offset);
3860 for (; j <= size - 16; j += 16) {
3861 uint8x16_t input_value = vld1q_u8(input_data + j);
3862 input_value = veorq_u8(input_value, offset_dup);
3863 max_val_dup = vmaxq_u8(input_value, max_val_dup);
3864 }
3865 max_val = std::max(max_val, static_cast<int32>(vmaxvq_u8(max_val_dup)));
3866 #endif
3867
3868 for (; j < size; ++j) {
3869 max_val = std::max(max_val, static_cast<int32_t>(input_data[j] ^ offset));
3870 }
3871 return max_val;
3872 }
3873
3874 #ifdef USE_NEON
3875 // Value_to_store layout:
3876 // [high_high, high_low, low_high, low_low].
3877 inline void StoreValue(int32x4x4_t value_to_store, int8_t* output) {
3878 const int16x8_t result_1 = vcombine_s16(vqmovn_s32(value_to_store.val[1]),
3879 vqmovn_s32(value_to_store.val[0]));
3880 const int16x8_t result_2 = vcombine_s16(vqmovn_s32(value_to_store.val[3]),
3881 vqmovn_s32(value_to_store.val[2]));
3882 const int8x16_t result =
3883 vcombine_s8(vqmovn_s16(result_2), vqmovn_s16(result_1));
3884 vst1q_s8(output, result);
3885 }
3886
3887 // Value_to_store layout:
3888 // [high_high, high_low, low_high, low_low].
3889 inline void StoreValue(int32x4x4_t value_to_store, uint8_t* output) {
3890 const uint16x8_t result_1 =
3891 vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[1])),
3892 vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[0])));
3893 const uint16x8_t result_2 =
3894 vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[3])),
3895 vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[2])));
3896 const uint8x16_t result =
3897 vcombine_u8(vqmovn_u16(result_2), vqmovn_u16(result_1));
3898 vst1q_u8(output, result);
3899 }
3900
3901 #endif
3902
3903 template <typename In, typename Out>
3904 inline void SoftmaxInt8LUT(const SoftmaxParams& params,
3905 const RuntimeShape& input_shape,
3906 const In* input_data,
3907 const RuntimeShape& output_shape, Out* output_data) {
3908 const int trailing_dim = input_shape.DimensionsCount() - 1;
3909 const int excluding_last_dim =
3910 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3911 const int last_dim =
3912 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3913
3914 const int32_t clamp_max = std::numeric_limits<Out>::max();
3915 const int32_t clamp_min = std::numeric_limits<Out>::min();
3916
3917 // The offset is used to interpret the input data "correctly".
3918 // If the input is uint8, the data is left unchanged.
3919 // If the input is int8, it is reinterpreted as uint8, so the offset is
3920 // applied to keep the values ordered; e.g.,
3921 // int8 127 has the "offset" applied to become 255 in uint8.
3922 uint8_t offset = 0;
3923 if (std::is_same<In, int8>::value) {
3924 offset = 0x80;
3925 }
3926
3927 const uint8_t* input_data_uint = reinterpret_cast<const uint8_t*>(input_data);
3928
3929 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3930 // This code uses ARM64-only instructions.
3931 // TODO(b/143709993): Port to ARMv7
3932
3933 // Load the tables into registers. (4*4 128-bit registers)
3934 uint8x16x4_t table1[4];
3935 table1[0] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 0);
3936 table1[1] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 1);
3937 table1[2] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 2);
3938 table1[3] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 3);
3939
3940 uint8x16x4_t table2[4];
3941 table2[0] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 0);
3942 table2[1] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 1);
3943 table2[2] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 2);
3944 table2[3] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 3);
3945 #endif
3946
3947 for (int i = 0; i < excluding_last_dim; ++i) {
3948 // Find max quantized value.
3949 int32_t max_val = FindMaxValue(last_dim, input_data_uint, offset);
3950
3951 int32 sum_exp = 0;
3952 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3953 const uint8_t table_offset = max_uint8 - max_val;
3954
3955 // Calculate normalizer sum(exp(x)).
3956 int sum_j = 0;
3957 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3958 uint8x16_t table_offset_dup = vdupq_n_u8(table_offset);
3959 uint8x16_t offset_dup = vdupq_n_u8(offset);
3960 uint32x4_t sum_4 = vdupq_n_u32(0);
3961 const int multiplier_shift = 8;
3962 for (; sum_j <= last_dim - 16; sum_j += 16) {
3963 uint8x16_t input_value = vld1q_u8(input_data_uint + sum_j);
3964 input_value = veorq_u8(input_value, offset_dup);
3965 input_value = vaddq_u8(input_value, table_offset_dup);
3966
3967 const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
3968 const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
3969
3970 uint16x8_t exp_value1 =
3971 vshll_n_u8(vget_high_u8(output1), multiplier_shift);
3972 uint16x8_t exp_value2 =
3973 vshll_n_u8(vget_low_u8(output1), multiplier_shift);
3974
3975 exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
3976 exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
3977
3978 sum_4 = vpadalq_u16(sum_4, exp_value1);
3979 sum_4 = vpadalq_u16(sum_4, exp_value2);
3980 }
3981 int temp = vgetq_lane_u32(sum_4, 0) + vgetq_lane_u32(sum_4, 1) +
3982 vgetq_lane_u32(sum_4, 2) + vgetq_lane_u32(sum_4, 3);
3983 sum_exp += temp;
3984
3985 #endif
3986 for (; sum_j < last_dim; ++sum_j) {
3987 const uint8_t index = (input_data_uint[sum_j] ^ offset) + table_offset;
3988
3989 uint8_t part1 = params.uint8_table1[index];
3990 uint8_t part2 = params.uint8_table2[index];
3991 sum_exp += ((part1 << 8) + part2);
3992 }
3993
3994 const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
3995
3996 int32 multiplier, shift;
3997 QuantizeMultiplier(inv_sum_exp, &multiplier, &shift);
3998
3999 // Normalize and quantize probabilities.
4000 int j = 0;
4001 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
4002 const int32x4_t output_zp_dup = vdupq_n_s32(params.zero_point);
4003 const int32x4_t max_val_dup = vdupq_n_s32(clamp_max);
4004 const int32x4_t min_val_dup = vdupq_n_s32(clamp_min);
4005
4006 for (; j <= last_dim - 16; j += 16) {
4007 uint8x16_t input_value = vld1q_u8(input_data_uint + j);
4008 input_value = veorq_u8(input_value, offset_dup);
4009 input_value = vaddq_u8(input_value, table_offset_dup);
4010
4011 const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
4012 const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
4013
4014 uint16x8_t exp_value1 =
4015 vshll_n_u8(vget_high_u8(output1), multiplier_shift);
4016 uint16x8_t exp_value2 =
4017 vshll_n_u8(vget_low_u8(output1), multiplier_shift);
4018
4019 exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
4020 exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
4021
4022 int32x4x4_t output_value;
4023 output_value.val[0] =
4024 vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value1)));
4025 output_value.val[1] =
4026 vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value1)));
4027 output_value.val[2] =
4028 vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value2)));
4029 output_value.val[3] =
4030 vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value2)));
4031
4032 int32x4x4_t temp_val =
4033 MultiplyByQuantizedMultiplier4Rows(output_value, multiplier, shift);
4034
4035 temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
4036 temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
4037 temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
4038 temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
4039
4040 temp_val.val[0] =
4041 vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
4042 temp_val.val[1] =
4043 vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
4044 temp_val.val[2] =
4045 vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
4046 temp_val.val[3] =
4047 vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
4048
4049 StoreValue(temp_val, output_data + j);
4050 }
4051 #endif
4052 for (; j < last_dim; ++j) {
4053 const uint8_t index = (input_data_uint[j] ^ offset) + table_offset;
4054 const uint8_t part1 = params.uint8_table1[index];
4055 const uint8_t part2 = params.uint8_table2[index];
4056 const int32_t exp_value = (part1 << 8) + part2;
4057 const int32_t output_value =
4058 MultiplyByQuantizedMultiplier(exp_value, multiplier, shift);
4059
4060 output_data[j] = static_cast<Out>(std::max(
4061 std::min(clamp_max, output_value + params.zero_point), clamp_min));
4062 }
4063 input_data_uint += last_dim;
4064 output_data += last_dim;
4065 }
4066 }
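
// Editorial illustration (hypothetical helper): the XOR-with-0x80 trick that
// SoftmaxInt8LUT uses above. Reinterpreting int8 as uint8 and XOR-ing with
// 0x80 maps the signed range [-128, 127] monotonically onto [0, 255]
// (-128 -> 0, 0 -> 128, 127 -> 255), so the same unsigned max/lookup code
// serves both input types.
inline uint8_t SignedToUnsignedWithOffsetSketch(int8_t value) {
  return static_cast<uint8_t>(static_cast<uint8_t>(value) ^ 0x80);
}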
4067
4068 inline void LogSoftmax(const SoftmaxParams& params,
4069 const RuntimeShape& input_shape, const float* input_data,
4070 const RuntimeShape& output_shape, float* output_data) {
4071 ruy::profiler::ScopeLabel label("LogSoftmax");
4072 const int trailing_dim = input_shape.DimensionsCount() - 1;
4073 const int outer_size =
4074 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4075 const int depth =
4076 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4077
4078 for (int i = 0; i < outer_size; ++i) {
4079 VectorMap<const float> block_input(input_data + i * depth, depth, 1);
4080 VectorMap<float> block_output(output_data + i * depth, depth, 1);
4081 // Find max element value which we'll use to ensure numerical stability
4082 // taking advantage of the following equality:
4083 // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
4084 const float max = block_input.maxCoeff();
4085 const float log_sum = std::log((block_input.array() - max).exp().sum());
4086 block_output = block_input.array() - max - log_sum;
4087 }
4088 }
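
// Editorial illustration (hypothetical helper): a scalar sketch of the
// identity used by the float LogSoftmax above,
//   log_softmax(x_i) = (x_i - max) - log(sum_j exp(x_j - max)),
// which is what the Eigen expression computes block by block.
inline float ScalarLogSoftmaxSketch(const float* logits, int size, int index) {
  float max_val = logits[0];
  for (int i = 1; i < size; ++i) max_val = std::max(max_val, logits[i]);
  float sum = 0.0f;
  for (int i = 0; i < size; ++i) sum += std::exp(logits[i] - max_val);
  return (logits[index] - max_val) - std::log(sum);
}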
4089
4090 // Backwards compatibility. Less optimized than below version.
4091 inline void LogSoftmax(const SoftmaxParams& params,
4092 const RuntimeShape& input_shape, const uint8* input_data,
4093 const RuntimeShape& output_shape, uint8* output_data) {
4094 reference_ops::LogSoftmax(params, input_shape, input_data, output_shape,
4095 output_data);
4096 }
4097
4098 // Compute LogSoftmax as (x - x_max) - ln(sum(e^(x_i - x_max)))
4099 // as done in tf.nn.log_softmax to prevent underflow and overflow.
4100 // This is in contrast to just log(softmax(x))
4101 //
4102 // To handle quantization, first dequantize the inputs (by computing
4103 // e^(input_scale * val), where we can ignore the zero point since it cancels
4104 // out during the subtraction due to the ln) and rescale to int8 at the end.
4105 //
4106 // Notably this makes use of float and is intended as the optimized
4107 // form for quantized execution on CPU. For a fully integer version,
4108 // see the reference op.
4109 //
4110 // TODO(tflite): notes for optimization:
4111 // 1) See if e^ is also bottleneck in the reference fully-integer
4112 // version and apply lookup there and compare.
4113 template <typename T>
4114 inline void LogSoftmax(const SoftmaxParams& params, float input_scale,
4115 const RuntimeShape& input_shape, const T* input_data,
4116 const RuntimeShape& output_shape, T* output_data) {
4117 ruy::profiler::ScopeLabel label("LogSoftmax");
4118 const int trailing_dim = input_shape.DimensionsCount() - 1;
4119 const int excluding_last_dim =
4120 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4121 const int last_dim =
4122 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4123
4124 const int32_t clamp_max = std::numeric_limits<T>::max();
4125 const int32_t clamp_min = std::numeric_limits<T>::min();
4126
4127 int32_t zero_point_offset = 0;
4128 if (std::is_same<T, int8_t>::value) {
4129 zero_point_offset = 128;
4130 }
4131 for (int i = 0; i < excluding_last_dim; ++i) {
4132 T max_val = std::numeric_limits<T>::min();
4133 // Find max quantized value.
4134 for (int j = 0; j < last_dim; ++j) {
4135 max_val = std::max(max_val, input_data[j]);
4136 }
4137
4138 float sum_exp = 0.0f;
4139 const int32_t max_q8 = std::numeric_limits<T>::max();
4140 // Offset into table to compute exp(scale*(x - xmax)) instead of
4141 // exp(scale*(x)) to prevent overflow.
4142 const float* table_offset = &params.table[max_q8 - max_val];
4143 // Calculate sum(exp(scale*(x - x_max))).
4144 for (int j = 0; j < last_dim; ++j) {
4145 sum_exp += table_offset[input_data[j]];
4146 }
4147 const float log_sum_exp = std::log(sum_exp);
4148
4149 // params.scale is the output scale.
4150 const float scale = input_scale / params.scale;
4151 const float precomputed =
4152 (input_scale * (max_val + zero_point_offset) + log_sum_exp) /
4153 params.scale;
4154 for (int j = 0; j < last_dim; ++j) {
4155 // Equivalent to (input_scale * (input_data[j] - max_val) - log_sum_exp) /
4156 // output_scale.
4157 const float log_prob = scale * input_data[j] - precomputed;
4158
4159 // TODO(tflite): look into better solution.
4160 // Use std::rint over std::round (which is used in
4161 // FakeQuant) since it's multiple times faster on tested arm32.
4162 const int32_t prob_quantized = std::rint(log_prob) + params.zero_point;
4163 output_data[j] = static_cast<T>(
4164 std::max(std::min(clamp_max, prob_quantized), clamp_min));
4165 }
4166 input_data += last_dim;
4167 output_data += last_dim;
4168 }
4169 }
4170
4171 inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
4172 const RuntimeShape& output_shape, float* output_data) {
4173 ruy::profiler::ScopeLabel label("Logistic");
4174 auto input_map = MapAsVector(input_data, input_shape);
4175 auto output_map = MapAsVector(output_data, output_shape);
4176 output_map.array() =
4177 input_map.array().unaryExpr(Eigen::internal::scalar_logistic_op<float>());
4178 }
4179
4180 // Convenience version that allows, for example, generated-code calls to be
4181 // uniform between data types.
4182 inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
4183 const float* input_data, const RuntimeShape& output_shape,
4184 float* output_data) {
4185 // Drop params: not needed.
4186 Logistic(input_shape, input_data, output_shape, output_data);
4187 }
4188
4189 inline void Logistic(const LogisticParams& params,
4190 const RuntimeShape& input_shape, const int16* input_data,
4191 const RuntimeShape& output_shape, int16* output_data) {
4192 ruy::profiler::ScopeLabel label("Logistic/Int16");
4193 const int flat_size = MatchingFlatSize(input_shape, output_shape);
4194
4197
4198 int c = 0;
4199 const int16* input_data_ptr = input_data;
4200 int16* output_data_ptr = output_data;
4201 #ifdef GEMMLOWP_NEON
4202 {
4203 // F0 uses 0 integer bits, range [-1, 1].
4204 // This is the return type of math functions such as tanh, logistic,
4205 // whose range is in [-1, 1].
4206 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
4207 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4208 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
4209
4210 for (; c <= flat_size - 16; c += 16) {
4211 F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
4212 F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
4213 F0 output0 = gemmlowp::logistic(input0);
4214 F0 output1 = gemmlowp::logistic(input1);
4215 vst1q_s16(output_data_ptr, output0.raw());
4216 vst1q_s16(output_data_ptr + 8, output1.raw());
4217
4218 input_data_ptr += 16;
4219 output_data_ptr += 16;
4220 }
4221 for (; c <= flat_size - 8; c += 8) {
4222 F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
4223 F0 output = gemmlowp::logistic(input);
4224 vst1q_s16(output_data_ptr, output.raw());
4225
4226 input_data_ptr += 8;
4227 output_data_ptr += 8;
4228 }
4229 }
4230 #endif
4231 #ifdef GEMMLOWP_SSE4
4232 {
4233 // F0 uses 0 integer bits, range [-1, 1].
4234 // This is the return type of math functions such as tanh, logistic,
4235 // whose range is in [-1, 1].
4236 using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
4237 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4238 using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
4239
4240 for (; c <= flat_size - 16; c += 16) {
4241 F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4242 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4243 F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4244 reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
4245 F0 output0 = gemmlowp::logistic(input0);
4246 F0 output1 = gemmlowp::logistic(input1);
4247 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4248 output0.raw().v);
4249 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4250 output1.raw().v);
4251 input_data_ptr += 16;
4252 output_data_ptr += 16;
4253 }
4254 for (; c <= flat_size - 8; c += 8) {
4255 F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4256 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4257 F0 output = gemmlowp::logistic(input);
4258 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4259 output.raw().v);
4260 input_data_ptr += 8;
4261 output_data_ptr += 8;
4262 }
4263 }
4264 #endif
4265
4266 {
4267 // F0 uses 0 integer bits, range [-1, 1].
4268 // This is the return type of math functions such as tanh, logistic,
4269 // whose range is in [-1, 1].
4270 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
4271 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4272 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
4273
4274 for (; c < flat_size; ++c) {
4275 F3 input = F3::FromRaw(*input_data_ptr);
4276 F0 output = gemmlowp::logistic(input);
4277 *output_data_ptr = output.raw();
4278
4279 ++input_data_ptr;
4280 ++output_data_ptr;
4281 }
4282 }
4283 }
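
// Editorial illustration (assuming gemmlowp's Q-format convention): how the
// raw int16 values in the fixed-point Logistic/Tanh paths map to real numbers.
// With IntegerBits integer bits, a raw value r represents
// r / 2^(15 - IntegerBits); the inputs here use 3 integer bits (range [-8, 8))
// and the outputs use 0 integer bits (range [-1, 1)).
inline float FixedPointInt16ToFloatSketch(int16 raw, int integer_bits) {
  return static_cast<float>(raw) /
         static_cast<float>(1 << (15 - integer_bits));
}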
4284
4285 inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
4286 const RuntimeShape& output_shape, float* output_data) {
4287 ruy::profiler::ScopeLabel label("Tanh");
4288 auto input_map = MapAsVector(input_data, input_shape);
4289 auto output_map = MapAsVector(output_data, output_shape);
4290 output_map.array() = input_map.array().tanh();
4291 }
4292
4293 // Convenience version that allows, for example, generated-code calls to be
4294 // uniform between data types.
4295 inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
4296 const float* input_data, const RuntimeShape& output_shape,
4297 float* output_data) {
4298 // Drop params: not needed.
4299 Tanh(input_shape, input_data, output_shape, output_data);
4300 }
4301
4302 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
4303 const int16* input_data, const RuntimeShape& output_shape,
4304 int16* output_data) {
4305 ruy::profiler::ScopeLabel label("Tanh/Int16");
4306 const int input_left_shift = params.input_left_shift;
4307 // Support for shifts is limited until we have a parameterized version of
4308 // SaturatingRoundingMultiplyByPOT().
4309 TFLITE_DCHECK_GE(input_left_shift, 0);
4310 TFLITE_DCHECK_LE(input_left_shift, 1);
4311
4312 const int flat_size = MatchingFlatSize(input_shape, output_shape);
4313
4314 int c = 0;
4315 const int16* input_data_ptr = input_data;
4316 int16* output_data_ptr = output_data;
4317 #ifdef GEMMLOWP_NEON
4318 {
4319 // F0 uses 0 integer bits, range [-1, 1].
4320 // This is the return type of math functions such as tanh, logistic,
4321 // whose range is in [-1, 1].
4322 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
4323 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4324 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
4325
4326 if (input_left_shift == 0) {
4327 for (; c <= flat_size - 16; c += 16) {
4328 F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
4329 F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
4330 F0 output0 = gemmlowp::tanh(input0);
4331 F0 output1 = gemmlowp::tanh(input1);
4332 vst1q_s16(output_data_ptr, output0.raw());
4333 vst1q_s16(output_data_ptr + 8, output1.raw());
4334
4335 input_data_ptr += 16;
4336 output_data_ptr += 16;
4337 }
4338 for (; c <= flat_size - 8; c += 8) {
4339 F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
4340 F0 output = gemmlowp::tanh(input);
4341 vst1q_s16(output_data_ptr, output.raw());
4342
4343 input_data_ptr += 8;
4344 output_data_ptr += 8;
4345 }
4346 } else {
4347 for (; c <= flat_size - 16; c += 16) {
4348 F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4349 vld1q_s16(input_data_ptr)));
4350 F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4351 vld1q_s16(input_data_ptr + 8)));
4352 F0 output0 = gemmlowp::tanh(input0);
4353 F0 output1 = gemmlowp::tanh(input1);
4354 vst1q_s16(output_data_ptr, output0.raw());
4355 vst1q_s16(output_data_ptr + 8, output1.raw());
4356
4357 input_data_ptr += 16;
4358 output_data_ptr += 16;
4359 }
4360 for (; c <= flat_size - 8; c += 8) {
4361 F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4362 vld1q_s16(input_data_ptr)));
4363 F0 output = gemmlowp::tanh(input);
4364 vst1q_s16(output_data_ptr, output.raw());
4365
4366 input_data_ptr += 8;
4367 output_data_ptr += 8;
4368 }
4369 }
4370 }
4371 #endif
4372 #ifdef GEMMLOWP_SSE4
4373 {
4374 // F0 uses 0 integer bits, range [-1, 1].
4375 // This is the return type of math functions such as tanh, logistic,
4376 // whose range is in [-1, 1].
4377 using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
4378 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4379 using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
4380
4381 if (input_left_shift == 0) {
4382 for (; c <= flat_size - 16; c += 16) {
4383 F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4384 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4385 F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4386 reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
4387 F0 output0 = gemmlowp::tanh(input0);
4388 F0 output1 = gemmlowp::tanh(input1);
4389 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4390 output0.raw().v);
4391 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4392 output1.raw().v);
4393
4394 input_data_ptr += 16;
4395 output_data_ptr += 16;
4396 }
4397 for (; c <= flat_size - 8; c += 8) {
4398 F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4399 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4400 F0 output = gemmlowp::tanh(input);
4401 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4402 output.raw().v);
4403 input_data_ptr += 8;
4404 output_data_ptr += 8;
4405 }
4406 } else {
4407 for (; c <= flat_size - 16; c += 16) {
4408 F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4409 gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4410 reinterpret_cast<const __m128i*>(input_data_ptr)))));
4411 F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4412 gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4413 reinterpret_cast<const __m128i*>(input_data_ptr + 8)))));
4414 F0 output0 = gemmlowp::tanh(input0);
4415 F0 output1 = gemmlowp::tanh(input1);
4416 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4417 output0.raw().v);
4418 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4419 output1.raw().v);
4420
4421 input_data_ptr += 16;
4422 output_data_ptr += 16;
4423 }
4424 for (; c <= flat_size - 8; c += 8) {
4425 F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4426 gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4427 reinterpret_cast<const __m128i*>(input_data_ptr)))));
4428 F0 output = gemmlowp::tanh(input);
4429 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4430 output.raw().v);
4431 input_data_ptr += 8;
4432 output_data_ptr += 8;
4433 }
4434 }
4435 }
4436 #endif
4437
4438 {
4439 // F0 uses 0 integer bits, range [-1, 1].
4440 // This is the return type of math functions such as tanh, logistic,
4441 // whose range is in [-1, 1].
4442 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
4443 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4444 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
4445
4446 if (input_left_shift == 0) {
4447 for (; c < flat_size; ++c) {
4448 F3 input = F3::FromRaw(*input_data_ptr);
4449 F0 output = gemmlowp::tanh(input);
4450 *output_data_ptr = output.raw();
4451
4452 ++input_data_ptr;
4453 ++output_data_ptr;
4454 }
4455 } else {
4456 for (; c < flat_size; ++c) {
4457 F3 input = F3::FromRaw(
4458 gemmlowp::SaturatingRoundingMultiplyByPOT<1>(*input_data_ptr));
4459 F0 output = gemmlowp::tanh(input);
4460 *output_data_ptr = output.raw();
4461
4462 ++input_data_ptr;
4463 ++output_data_ptr;
4464 }
4465 }
4466 }
4467 }
4468
4469 template <typename SrcT, typename DstT>
4470 inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
4471 const RuntimeShape& output_shape, DstT* output_data) {
4472 ruy::profiler::ScopeLabel label("Cast");
4473 auto input_map = MapAsVector(input_data, input_shape);
4474 auto output_map = MapAsVector(output_data, output_shape);
4475 output_map.array() = input_map.array().template cast<DstT>();
4476 }
4477
4478 inline void Floor(const RuntimeShape& input_shape, const float* input_data,
4479 const RuntimeShape& output_shape, float* output_data) {
4480 ruy::profiler::ScopeLabel label("Floor");
4481 auto input_map = MapAsVector(input_data, input_shape);
4482 auto output_map = MapAsVector(output_data, output_shape);
4483 output_map.array() = Eigen::floor(input_map.array());
4484 }
4485
4486 inline void Ceil(const RuntimeShape& input_shape, const float* input_data,
4487 const RuntimeShape& output_shape, float* output_data) {
4488 ruy::profiler::ScopeLabel label("Ceil");
4489 auto input_map = MapAsVector(input_data, input_shape);
4490 auto output_map = MapAsVector(output_data, output_shape);
4491 output_map.array() = Eigen::ceil(input_map.array());
4492 }
4493
4494 #ifdef USE_NEON
4495 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
4496 float scale, float* output_ptr) {
4497 int ic = 0;
4498 // Handle 32 input channels at a time.
4499 for (; ic <= depth - 32; ic += 32) {
4500 float32x4x2_t input[4];
4501 for (int i = 0; i < 4; i++) {
4502 input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
4503 input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
4504 }
4505 float32x4x2_t acc[4];
4506 for (int i = 0; i < 4; i++) {
4507 acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
4508 acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
4509 }
4510 for (int i = 0; i < 4; i++) {
4511 acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
4512 acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
4513 }
4514 for (int i = 0; i < 4; i++) {
4515 vst1q_f32(output_ptr, acc[i].val[0]);
4516 vst1q_f32(output_ptr + 4, acc[i].val[1]);
4517 output_ptr += 8;
4518 }
4519 input_ptr += 32;
4520 }
4521 // Handle 16 input channels at a time.
4522 for (; ic <= depth - 16; ic += 16) {
4523 float32x4x2_t input[2];
4524 for (int i = 0; i < 2; i++) {
4525 input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
4526 input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
4527 }
4528 float32x4x2_t acc[2];
4529 for (int i = 0; i < 2; i++) {
4530 acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
4531 acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
4532 }
4533 for (int i = 0; i < 2; i++) {
4534 acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
4535 acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
4536 }
4537 for (int i = 0; i < 2; i++) {
4538 vst1q_f32(output_ptr, acc[i].val[0]);
4539 vst1q_f32(output_ptr + 4, acc[i].val[1]);
4540 output_ptr += 8;
4541 }
4542 input_ptr += 16;
4543 }
4544 // Handle 8 input channels at a time.
4545 for (; ic <= depth - 8; ic += 8) {
4546 float32x4x2_t input;
4547 input.val[0] = vld1q_f32(input_ptr);
4548 input.val[1] = vld1q_f32(input_ptr + 4);
4549
4550 float32x4x2_t acc;
4551 acc.val[0] = vld1q_f32(output_ptr);
4552 acc.val[1] = vld1q_f32(output_ptr + 4);
4553 acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale);
4554 acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale);
4555
4556 vst1q_f32(output_ptr, acc.val[0]);
4557 vst1q_f32(output_ptr + 4, acc.val[1]);
4558
4559 input_ptr += 8;
4560 output_ptr += 8;
4561 }
4562 // Handle 4 input channels at a time.
4563 for (; ic <= depth - 4; ic += 4) {
4564 float32x4_t input = vld1q_f32(input_ptr);
4565 float32x4_t acc = vld1q_f32(output_ptr);
4566
4567 acc = vmlaq_n_f32(acc, input, scale);
4568 vst1q_f32(output_ptr, acc);
4569
4570 input_ptr += 4;
4571 output_ptr += 4;
4572 }
4573 // Handle 1 input channel at a time.
4574 for (; ic < depth; ic++) {
4575 *output_ptr += *input_ptr * scale;
4576 output_ptr++;
4577 input_ptr++;
4578 }
4579 }
4580 #else
4581 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
4582 float scale, float* output_ptr) {
4583 for (int32 i = 0; i < depth; i++) {
4584 *output_ptr += *input_ptr * scale;
4585 output_ptr++;
4586 input_ptr++;
4587 }
4588 }
4589 #endif
4590
4591 inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
4592 int32 x, int32 y, int32 depth, int32 batch,
4593 const RuntimeShape& input_shape,
4594 const float* input_data,
4595 const RuntimeShape& output_shape,
4596 float* output_data) {
4597 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4598 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
4599 const int32 input_width = input_shape.Dims(2);
4600 const int32 output_width = output_shape.Dims(2);
4601
4602 const int32 input_x_offset = (x1 - x0) * depth;
4603 const int32 input_y_offset = (y1 - y0) * depth * input_width;
4604 const int32 output_x_offset = depth;
4605 const int32 output_y_offset = depth * output_width;
4606
4607 #ifdef USE_NEON
4608 TFLITE_DCHECK(x1 >= x0);
4609 TFLITE_DCHECK(y1 >= y0);
4610
4611 int ic = 0;
4612 // Handle 8 input channels at a time.
4613 for (; ic <= depth - 8; ic += 8) {
4614 const float* input_ptr = nullptr;
4615
4616 float32x4x2_t x0y0;
4617 input_ptr = &input_data[Offset(input_shape, batch, y0, x0, ic)];
4618 x0y0.val[0] = vld1q_f32(input_ptr);
4619 x0y0.val[1] = vld1q_f32(input_ptr + 4);
4620
4621 float32x4x2_t x1y0;
4622 input_ptr += input_x_offset;
4623 x1y0.val[0] = vld1q_f32(input_ptr);
4624 x1y0.val[1] = vld1q_f32(input_ptr + 4);
4625
4626 float32x4x2_t x0y1;
4627 input_ptr += -input_x_offset + input_y_offset;
4628 x0y1.val[0] = vld1q_f32(input_ptr);
4629 x0y1.val[1] = vld1q_f32(input_ptr + 4);
4630
4631 float32x4x2_t x1y1;
4632 input_ptr += input_x_offset;
4633 x1y1.val[0] = vld1q_f32(input_ptr);
4634 x1y1.val[1] = vld1q_f32(input_ptr + 4);
4635
4636 // Top left corner.
4637 float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
4638 vst1q_f32(output_ptr, x0y0.val[0]);
4639 vst1q_f32(output_ptr + 4, x0y0.val[1]);
4640
4641 // Top right corner.
4642 output_ptr += output_x_offset;
4643 float32x4x2_t tr;
4644 tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]);
4645 tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]);
4646 tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f);
4647 tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f);
4648
4649 vst1q_f32(output_ptr, tr.val[0]);
4650 vst1q_f32(output_ptr + 4, tr.val[1]);
4651
4652 // Bottom left corner.
4653 output_ptr += -output_x_offset + output_y_offset;
4654 float32x4x2_t bl;
4655 bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]);
4656 bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]);
4657 bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f);
4658 bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f);
4659 vst1q_f32(output_ptr, bl.val[0]);
4660 vst1q_f32(output_ptr + 4, bl.val[1]);
4661
4662 // Bottom right corner.
4663 output_ptr += output_x_offset;
4664 float32x4x2_t br;
4665 br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]);
4666 br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]);
4667 br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f);
4668 br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f);
4669 br.val[0] = vmulq_n_f32(br.val[0], 0.5f);
4670 br.val[1] = vmulq_n_f32(br.val[1], 0.5f);
4671 vst1q_f32(output_ptr, br.val[0]);
4672 vst1q_f32(output_ptr + 4, br.val[1]);
4673 }
4674 // Handle 4 input channels at a time.
4675 for (; ic <= depth - 4; ic += 4) {
4676 const float* input_ptr =
4677 &input_data[Offset(input_shape, batch, y0, x0, ic)];
4678 float32x4_t x0y0 = vld1q_f32(input_ptr);
4679 float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset);
4680 float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset);
4681 float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset);
4682
4683 // Top left corner.
4684 float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
4685 vst1q_f32(output_ptr, x0y0);
4686
4687 // Top right corner.
4688 output_ptr += output_x_offset;
4689 float32x4_t tr = vaddq_f32(x0y0, x1y0);
4690 tr = vmulq_n_f32(tr, 0.5f);
4691 vst1q_f32(output_ptr, tr);
4692
4693 // Bottom left corner.
4694 output_ptr += -output_x_offset + output_y_offset;
4695 float32x4_t bl = vaddq_f32(x0y0, x0y1);
4696 bl = vmulq_n_f32(bl, 0.5f);
4697 vst1q_f32(output_ptr, bl);
4698
4699 // Bottom right corner.
4700 output_ptr += output_x_offset;
4701 float32x4_t br = vaddq_f32(x1y0, x1y1);
4702 br = vmlaq_n_f32(bl, br, 0.5f);
4703 br = vmulq_n_f32(br, 0.5f);
4704 vst1q_f32(output_ptr, br);
4705 }
4706 // Handle one input channel at a time.
4707 for (; ic < depth; ic++) {
4708 const int32 input_offset = Offset(input_shape, batch, y0, x0, ic);
4709
4710 float x0y0 = input_data[input_offset];
4711 float x1y0 = input_data[input_offset + input_x_offset];
4712 float x0y1 = input_data[input_offset + input_y_offset];
4713 float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
4714
4715 // Top left corner.
4716 const int32 output_offset = Offset(output_shape, batch, y, x, ic);
4717 output_data[output_offset] = x0y0;
4718
4719 // Top right corner.
4720 output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
4721
4722 // Bottom left corner.
4723 float output = (x0y0 + x0y1) / 2;
4724 output_data[output_offset + output_y_offset] = output;
4725
4726 // Bottom right corner.
4727 output_data[output_offset + output_x_offset + output_y_offset] =
4728 (output + ((x1y0 + x1y1) / 2)) / 2;
4729 }
4730 #else
4731 for (int ch = 0; ch < depth; ch++) {
4732 const int32 input_offset = Offset(input_shape, batch, y0, x0, ch);
4733
4734 float x0y0 = input_data[input_offset];
4735 float x1y0 = input_data[input_offset + input_x_offset];
4736 float x0y1 = input_data[input_offset + input_y_offset];
4737 float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
4738
4739 // Top left corner.
4740 const int32 output_offset = Offset(output_shape, batch, y, x, ch);
4741 output_data[output_offset] = x0y0;
4742
4743 // Top right corner.
4744 output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
4745
4746 // Bottom left corner.
4747 float output = (x0y0 + x0y1) / 2;
4748 output_data[output_offset + output_y_offset] = output;
4749
4750 // Bottom right corner.
4751 output_data[output_offset + output_x_offset + output_y_offset] =
4752 (output + ((x1y0 + x1y1) / 2)) / 2;
4753 }
4754 #endif
4755 }
4756
4757 inline void ResizeBilinear2x2(int32 batches, int32 input_height,
4758 int32 input_width, int32 depth,
4759 int32 output_height, int32 output_width,
4760 const RuntimeShape& input_shape,
4761 const float* input_data,
4762 const RuntimeShape& output_shape,
4763 float* output_data) {
4764 for (int b = 0; b < batches; b++) {
4765 for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) {
4766 for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) {
4767 int32 x1 = std::min(x0 + 1, input_width - 1);
4768 int32 y1 = std::min(y0 + 1, input_height - 1);
4769 ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_shape,
4770 input_data, output_shape, output_data);
4771 }
4772 }
4773 }
4774 }
4775
4776 inline void ResizeBilinearGeneric(
4777 int32 batches, int32 input_height, int32 input_width, int32 depth,
4778 int32 output_height, int32 output_width, float height_scale,
4779 float width_scale, const RuntimeShape& input_shape, const float* input_data,
4780 const RuntimeShape& output_shape, float* output_data,
4781 const bool half_pixel_centers) {
4782 memset(output_data, 0,
4783 batches * output_height * output_width * depth * sizeof(float));
4784
4785 int32 output_offset = 0;
4786 for (int b = 0; b < batches; ++b) {
4787 for (int y = 0; y < output_height; ++y) {
4788 float input_y;
4789 int32 y0, y1;
4790 reference_ops::ComputeInterpolationValues(
4791 y, height_scale, half_pixel_centers, input_height, &input_y, &y0,
4792 &y1);
4793 for (int x = 0; x < output_width; ++x) {
4794 float input_x;
4795 int32 x0, x1;
4796 reference_ops::ComputeInterpolationValues(
4797 x, width_scale, half_pixel_centers, input_width, &input_x, &x0,
4798 &x1);
4799 float* output_ptr = &output_data[output_offset];
4800
4801 // Run kernel on the 4 corners of the bilinear resize algorithm.
4802 int32 input_offset = Offset(input_shape, b, y0, x0, 0);
4803 float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
4804 const float* input_ptr = &input_data[input_offset];
4805 ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
4806
4807 input_offset = Offset(input_shape, b, y0, x1, 0);
4808 scale = (1 - (input_y - y0)) * (input_x - x0);
4809 input_ptr = &input_data[input_offset];
4810 ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
4811
4812 input_offset = Offset(input_shape, b, y1, x0, 0);
4813 scale = (input_y - y0) * (1 - (input_x - x0));
4814 input_ptr = &input_data[input_offset];
4815 ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
4816
4817 input_offset = Offset(input_shape, b, y1, x1, 0);
4818 scale = (input_y - y0) * (input_x - x0);
4819 input_ptr = &input_data[input_offset];
4820 ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
4821
4822 output_offset += depth;
4823 }
4824 }
4825 }
4826 }
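
// Editorial illustration (hypothetical helper): the four-corner weighting that
// ResizeBilinearGeneric accumulates through ResizeBilinearKernel. Given the
// fractional source coordinates (input_y, input_x) and their floor corners
// (y0, x0), the corner weights are the usual bilinear products and sum to 1.
inline float BilinearInterpolateSketch(float x0y0, float x1y0, float x0y1,
                                       float x1y1, float input_y, int32 y0,
                                       float input_x, int32 x0) {
  const float dy = input_y - y0;
  const float dx = input_x - x0;
  return x0y0 * (1 - dy) * (1 - dx) + x1y0 * (1 - dy) * dx +
         x0y1 * dy * (1 - dx) + x1y1 * dy * dx;
}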
4827
4828 template <typename T>
4829 inline void ResizeBilinearGenericSmallChannel(
4830 int32 batches, int32 input_height, int32 input_width, int32 depth,
4831 int32 output_height, int32 output_width, float height_scale,
4832 float width_scale, const RuntimeShape& input_shape, const T* input_data,
4833 const RuntimeShape& output_shape, T* output_data,
4834 const bool half_pixel_centers) {
4835 T* output_ptr = &output_data[0];
4836 for (int b = 0; b < batches; ++b) {
4837 for (int y = 0; y < output_height; ++y) {
4838 float input_y;
4839 int32 y0, y1;
4840 reference_ops::ComputeInterpolationValues(
4841 y, height_scale, half_pixel_centers, input_height, &input_y, &y0,
4842 &y1);
4843 for (int x = 0; x < output_width; ++x) {
4844 float input_x;
4845 int32 x0, x1;
4846 reference_ops::ComputeInterpolationValues(
4847 x, width_scale, half_pixel_centers, input_width, &input_x, &x0,
4848 &x1);
4849
4850 int32 input_offset[4] = {Offset(input_shape, b, y0, x0, 0),
4851 Offset(input_shape, b, y0, x1, 0),
4852 Offset(input_shape, b, y1, x0, 0),
4853 Offset(input_shape, b, y1, x1, 0)};
4854 float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
4855 (1 - (input_y - y0)) * (input_x - x0),
4856 (input_y - y0) * (1 - (input_x - x0)),
4857 (input_y - y0) * (input_x - x0)};
4858
4859 for (int d = 0; d < depth; d++) {
4860 const T* input_ptr = &input_data[d];
4861 *output_ptr++ = static_cast<T>(input_ptr[input_offset[0]] * scale[0] +
4862 input_ptr[input_offset[1]] * scale[1] +
4863 input_ptr[input_offset[2]] * scale[2] +
4864 input_ptr[input_offset[3]] * scale[3]);
4865 }
4866 }
4867 }
4868 }
4869 }
4870
4871 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
4872 const RuntimeShape& unextended_input_shape,
4873 const float* input_data,
4874 const RuntimeShape& output_size_shape,
4875 const int32* output_size_data,
4876 const RuntimeShape& unextended_output_shape,
4877 float* output_data) {
4878 ruy::profiler::ScopeLabel label("ResizeBilinear");
4879 // If half_pixel_centers is True, align_corners must be False.
4880 TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
4881 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
4882 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
4883 const RuntimeShape input_shape =
4884 RuntimeShape::ExtendedShape(4, unextended_input_shape);
4885 const RuntimeShape output_shape =
4886 RuntimeShape::ExtendedShape(4, unextended_output_shape);
4887
4888 int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
4889 int32 input_height = input_shape.Dims(1);
4890 int32 input_width = input_shape.Dims(2);
4891 int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
4892
4893 TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
4894 int32 output_height = output_size_data[0];
4895 int32 output_width = output_size_data[1];
4896
4897 // Specialize for 2x2 upsample.
4898 if (!op_params.align_corners && !op_params.half_pixel_centers &&
4899 output_height == 2 * input_height && output_width == 2 * input_width) {
4900 ResizeBilinear2x2(batches, input_height, input_width, depth, output_height,
4901 output_width, input_shape, input_data, output_shape,
4902 output_data);
4903 } else {
4904 float height_scale = static_cast<float>(input_height) / output_height;
4905 float width_scale = static_cast<float>(input_width) / output_width;
4906 if (op_params.align_corners && output_height > 1) {
4907 height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
4908 }
4909 if (op_params.align_corners && output_width > 1) {
4910 width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
4911 }
4912
4913 ResizeBilinearGeneric(batches, input_height, input_width, depth,
4914 output_height, output_width, height_scale,
4915 width_scale, input_shape, input_data, output_shape,
4916 output_data, op_params.half_pixel_centers);
4917 }
4918 }
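// Example of the scale computation in ResizeBilinear above (illustrative
// numbers only): for input_height = 4 and output_height = 8, the default
// scale is 4 / 8 = 0.5, while with align_corners it becomes
// (4 - 1) / (8 - 1) = 3 / 7, so the last output row maps exactly onto the
// last input row (7 * 3 / 7 = 3).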
4919
4920 // TODO(prabhumk): This is not a real quantized bilinear. It does not use int8
4921 // or int16 arithmetic.
4922 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
4923 const RuntimeShape& unextended_input_shape,
4924 const uint8* input_data,
4925 const RuntimeShape& output_size_shape,
4926 const int32* output_size_data,
4927 const RuntimeShape& unextended_output_shape,
4928 uint8* output_data) {
4929 ruy::profiler::ScopeLabel label("ResizeBilinear");
4930 // If half_pixel_centers is True, align_corners must be False.
4931 TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
4932 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
4933 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
4934 const RuntimeShape input_shape =
4935 RuntimeShape::ExtendedShape(4, unextended_input_shape);
4936 const RuntimeShape output_shape =
4937 RuntimeShape::ExtendedShape(4, unextended_output_shape);
4938
4939 int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
4940 int32 input_height = input_shape.Dims(1);
4941 int32 input_width = input_shape.Dims(2);
4942 int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
4943
4944 TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
4945 int32 output_height = output_size_data[0];
4946 int32 output_width = output_size_data[1];
4947
4948 float height_scale =
4949 (op_params.align_corners && output_height > 1)
4950 ? (static_cast<float>(input_height - 1) / (output_height - 1))
4951 : (static_cast<float>(input_height) / output_height);
4952
4953 float width_scale =
4954 (op_params.align_corners && output_width > 1)
4955 ? (static_cast<float>(input_width - 1) / (output_width - 1))
4956 : (static_cast<float>(input_width) / output_width);
4957
4958 ResizeBilinearGenericSmallChannel<uint8>(
4959 batches, input_height, input_width, depth, output_height, output_width,
4960 height_scale, width_scale, input_shape, input_data, output_shape,
4961 output_data, op_params.half_pixel_centers);
4962 }
4963
4964 // Helper methods for BatchToSpaceND.
4965 // `spatial_index_dim` specifies post-crop offset index in this spatial
4966 // dimension, i.e. spatial offset introduced by flattening batch to spatial
4967 // dimension minus the crop size at beginning. `block_shape_dim` is the block
4968 // size in current dimension. `input_dim` and `output_dim` are input and output
4969 // size of BatchToSpaceND operation in current dimension.
4970 // Output start index is inclusive and end index is exclusive.
4971 inline void GetIndexRange(int spatial_index_dim, int block_shape_dim,
4972 int input_dim, int output_dim, int* start_index,
4973 int* end_index) {
4974 // (*start_index) * block_shape_dim is effectively rounded up to the next
4975 // multiple of block_shape_dim by the integer division.
4976 *start_index =
4977 std::max(0, (-spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
4978 // Similarly, (*end_index) * block_shape_dim is rounded up too (note that
4979 // end_index is exclusive).
4980 *end_index = std::min(
4981 input_dim,
4982 (output_dim - spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
4983 }
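// Worked example for GetIndexRange above (hypothetical values):
// GetIndexRange(-1, 2, 3, 5, &start, &end) gives
//   start = max(0, (1 + 2 - 1) / 2) = 1
//   end   = min(3, (5 + 1 + 2 - 1) / 2) = 3
// so input indices 1 and 2 are used; they map to output positions
// 1 * 2 - 1 = 1 and 2 * 2 - 1 = 3, both inside [0, 5), while input index 0
// (output position -1) is correctly skipped.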
4984
4985 template <typename T>
4986 inline void BatchToSpaceND(
4987 const RuntimeShape& unextended_input1_shape, const T* input1_data,
4988 const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
4989 const RuntimeShape& unextended_input3_shape, const int32* crops_data,
4990 const RuntimeShape& unextended_output_shape, T* output_data) {
4991 ruy::profiler::ScopeLabel label("BatchToSpaceND");
4992
4993 TFLITE_DCHECK_GE(unextended_input1_shape.DimensionsCount(), 3);
4994 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
4995 TFLITE_DCHECK_EQ(unextended_input1_shape.DimensionsCount(),
4996 unextended_output_shape.DimensionsCount());
4997
4998 // Extends the input/output shape from 3D to 4D if needed, NHC -> NH1C.
4999 auto extend_shape = [](const RuntimeShape& shape) {
5000 if (shape.DimensionsCount() == 4) {
5001 return shape;
5002 }
5003 RuntimeShape new_shape(4, 1);
5004 new_shape.SetDim(0, shape.Dims(0));
5005 new_shape.SetDim(1, shape.Dims(1));
5006 new_shape.SetDim(3, shape.Dims(2));
5007 return new_shape;
5008 };
5009 const RuntimeShape input1_shape = extend_shape(unextended_input1_shape);
5010 const RuntimeShape output_shape = extend_shape(unextended_output_shape);
5011
5012 const int output_width = output_shape.Dims(2);
5013 const int output_height = output_shape.Dims(1);
5014 const int output_batch_size = output_shape.Dims(0);
5015
5016 const int depth = input1_shape.Dims(3);
5017 const int input_width = input1_shape.Dims(2);
5018 const int input_height = input1_shape.Dims(1);
5019 const int input_batch_size = input1_shape.Dims(0);
5020
5021 const int block_shape_height = block_shape_data[0];
5022 const int block_shape_width =
5023 unextended_input1_shape.DimensionsCount() == 4 ? block_shape_data[1] : 1;
5024 const int crops_top = crops_data[0];
5025 const int crops_left =
5026 unextended_input1_shape.DimensionsCount() == 4 ? crops_data[2] : 0;
5027
5028 for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
5029 const int out_batch = in_batch % output_batch_size;
5030 const int spatial_offset = in_batch / output_batch_size;
5031
5032 int in_h_start = 0;
5033 int in_h_end = 0;
5034 // GetIndexRange ensures start and end indices are in [0, output_height).
5035 GetIndexRange(spatial_offset / block_shape_width - crops_top,
5036 block_shape_height, input_height, output_height, &in_h_start,
5037 &in_h_end);
5038
5039 for (int in_h = in_h_start; in_h < in_h_end; ++in_h) {
5040 const int out_h = in_h * block_shape_height +
5041 spatial_offset / block_shape_width - crops_top;
5042 TFLITE_DCHECK_GE(out_h, 0);
5043 TFLITE_DCHECK_LT(out_h, output_height);
5044
5045 int in_w_start = 0;
5046 int in_w_end = 0;
5047 // GetIndexRange ensures start and end indices are in [0, output_width).
5048 GetIndexRange(spatial_offset % block_shape_width - crops_left,
5049 block_shape_width, input_width, output_width, &in_w_start,
5050 &in_w_end);
5051
5052 for (int in_w = in_w_start; in_w < in_w_end; ++in_w) {
5053 const int out_w = in_w * block_shape_width +
5054 spatial_offset % block_shape_width - crops_left;
5055 TFLITE_DCHECK_GE(out_w, 0);
5056 TFLITE_DCHECK_LT(out_w, output_width);
5057 T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
5058 const T* in =
5059 input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
5060 memcpy(out, in, depth * sizeof(T));
5061 }
5062 }
5063 }
5064 }
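// Illustrative mapping for BatchToSpaceND above (values are hypothetical):
// with input_batch_size = 4, output_batch_size = 1, block_shape = {2, 2} and
// no crops, input batch 3 has spatial_offset = 3, so
//   out_h = in_h * 2 + 3 / 2 = in_h * 2 + 1
//   out_w = in_w * 2 + 3 % 2 = in_w * 2 + 1
// i.e. batch 3 fills the bottom-right pixel of every 2x2 output block.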
5065
5066 template <typename T>
5067 void TypedMemset(void* ptr, T value, size_t num) {
5068 // Optimization for common cases where memset() will suffice.
5069 if (value == 0 || std::is_same<T, uint8_t>::value) {
5070 memset(ptr, value, num * sizeof(T));
5071 } else {
5072 // Default implementation for cases where memset() will not preserve the
5073 // bytes, e.g., typically when sizeof(T) > sizeof(uint8_t).
5074 char* pos = static_cast<char*>(ptr);
5075 for (size_t i = 0; i < num; ++i) {
5076 memcpy(pos, &value, sizeof(T));
5077 pos = pos + sizeof(T);
5078 }
5079 }
5080 }
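// For example, TypedMemset<float>(dst, 0.0f, n) can use plain memset because
// a zero float is all-zero bytes, whereas TypedMemset<float>(dst, 1.5f, n)
// must fall back to the per-element memcpy loop above.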
5081
5082 // This makes heavy use of Offset, along with conditional branches. There may be
5083 // opportunities for improvement.
5084 //
5085 // There are two versions of pad: Pad and PadV2. In PadV2 there is a second
5086 // scalar input that provides the padding value. Therefore pad_value_ptr can be
5087 // equivalent to a simple input1_data. For Pad, it should point to a zero
5088 // value.
5089 //
5090 // Note that two typenames are required, so that T=P=int32 is considered a
5091 // specialization distinct from P=int32.
5092 template <typename T, typename P>
5093 inline void PadImpl(const tflite::PadParams& op_params,
5094 const RuntimeShape& input_shape, const T* input_data,
5095 const P* pad_value_ptr, const RuntimeShape& output_shape,
5096 T* output_data) {
5097 ruy::profiler::ScopeLabel label("Pad4DSlowImpl");
5098 const RuntimeShape ext_input_shape =
5099 RuntimeShape::ExtendedShape(4, input_shape);
5100 const RuntimeShape ext_output_shape =
5101 RuntimeShape::ExtendedShape(4, output_shape);
5102 TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
5103 TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
5104
5105 // Pad kernels are limited to max 4 dimensions. Copy inputs so we can pad them
5106 // to 4 dims (yes, we are "padding the padding").
5107 std::vector<int> left_padding_copy(4, 0);
5108 const int left_padding_extend = 4 - op_params.left_padding_count;
5109 for (int i = 0; i < op_params.left_padding_count; ++i) {
5110 left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
5111 }
5112 std::vector<int> right_padding_copy(4, 0);
5113 const int right_padding_extend = 4 - op_params.right_padding_count;
5114 for (int i = 0; i < op_params.right_padding_count; ++i) {
5115 right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
5116 }
5117
5118 const int output_batch = ext_output_shape.Dims(0);
5119 const int output_height = ext_output_shape.Dims(1);
5120 const int output_width = ext_output_shape.Dims(2);
5121 const int output_depth = ext_output_shape.Dims(3);
5122
5123 const int left_b_padding = left_padding_copy[0];
5124 const int left_h_padding = left_padding_copy[1];
5125 const int left_w_padding = left_padding_copy[2];
5126 const int left_d_padding = left_padding_copy[3];
5127
5128 const int right_b_padding = right_padding_copy[0];
5129 const int right_h_padding = right_padding_copy[1];
5130 const int right_w_padding = right_padding_copy[2];
5131 const int right_d_padding = right_padding_copy[3];
5132
5133 const int input_depth = ext_input_shape.Dims(3);
5134 const T pad_value = *pad_value_ptr;
5135
5136 if (left_b_padding != 0) {
5137 TypedMemset<T>(
5138 output_data, pad_value,
5139 left_b_padding * output_height * output_width * output_depth);
5140 }
5141 for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
5142 ++out_b) {
5143 if (left_h_padding != 0) {
5144 TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, 0, 0, 0),
5145 pad_value, left_h_padding * output_width * output_depth);
5146 }
5147 for (int out_h = left_h_padding; out_h < output_height - right_h_padding;
5148 ++out_h) {
5149 if (left_w_padding != 0) {
5150 TypedMemset<T>(
5151 output_data + Offset(ext_output_shape, out_b, out_h, 0, 0),
5152 pad_value, left_w_padding * output_depth);
5153 }
5154 for (int out_w = left_w_padding; out_w < output_width - right_w_padding;
5155 ++out_w) {
5156 if (left_d_padding != 0) {
5157 TypedMemset<T>(
5158 output_data + Offset(ext_output_shape, out_b, out_h, out_w, 0),
5159 pad_value, left_d_padding);
5160 }
5161
5162 T* out = output_data +
5163 Offset(ext_output_shape, out_b, out_h, out_w, left_d_padding);
5164 const T* in = input_data +
5165 Offset(ext_input_shape, out_b - left_b_padding,
5166 out_h - left_h_padding, out_w - left_w_padding, 0);
5167 memcpy(out, in, input_depth * sizeof(T));
5168
5169 if (right_d_padding != 0) {
5170 TypedMemset<T>(
5171 output_data + Offset(ext_output_shape, out_b, out_h, out_w,
5172 output_depth - right_d_padding),
5173 pad_value, right_d_padding);
5174 }
5175 }
5176 if (right_w_padding != 0) {
5177 TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, out_h,
5178 output_width - right_w_padding, 0),
5179 pad_value, right_w_padding * output_depth);
5180 }
5181 }
5182 if (right_h_padding != 0) {
5183 TypedMemset<T>(
5184 output_data + Offset(ext_output_shape, out_b,
5185 output_height - right_h_padding, 0, 0),
5186 pad_value, right_h_padding * output_width * output_depth);
5187 }
5188 }
5189 if (right_b_padding != 0) {
5190 TypedMemset<T>(
5191 output_data +
5192 Offset(ext_output_shape, output_batch - right_b_padding, 0, 0, 0),
5193 pad_value,
5194 right_b_padding * output_height * output_width * output_depth);
5195 }
5196 }
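// Shape example for PadImpl above (hypothetical): padding an input of shape
// {1, 2, 3, 1} with left_padding = {0, 1, 1, 0} and right_padding =
// {0, 1, 1, 0} produces an output of shape {1, 4, 5, 1}; the TypedMemset
// calls fill the one-row top/bottom and one-column left/right borders with
// pad_value, and each memcpy copies one interior row of 3 elements.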
5197
5198 template <typename T, typename P>
5199 inline void Pad(const tflite::PadParams& op_params,
5200 const RuntimeShape& input_shape, const T* input_data,
5201 const P* pad_value_ptr, const RuntimeShape& output_shape,
5202 T* output_data) {
5203 PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
5204 output_data);
5205 }
5206
5207 // The second (pad-value) input can be int32 when, say, the first is uint8.
5208 template <typename T>
5209 inline void Pad(const tflite::PadParams& op_params,
5210 const RuntimeShape& input_shape, const T* input_data,
5211 const int32* pad_value_ptr, const RuntimeShape& output_shape,
5212 T* output_data) {
5213 const T converted_pad_value = static_cast<T>(*pad_value_ptr);
5214 PadImpl(op_params, input_shape, input_data, &converted_pad_value,
5215 output_shape, output_data);
5216 }
5217
5218 // This version avoids conflicting template matching.
5219 template <>
5220 inline void Pad(const tflite::PadParams& op_params,
5221 const RuntimeShape& input_shape, const int32* input_data,
5222 const int32* pad_value_ptr, const RuntimeShape& output_shape,
5223 int32* output_data) {
5224 PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
5225 output_data);
5226 }
5227
5228 // TODO(b/117643175): Optimize. (This is an introductory copy of standard Pad.)
5229 //
5230 // This pad requires that (a) left and right paddings are in the 4D patterns
5231 // {0, h_pad, w_pad, 0}, and (b) memset can be used: *pad_value_ptr == 0 and/or
5232 // T is uint8.
5233 //
5234 // There are two versions of pad: Pad and PadV2. In PadV2 there is a second
5235 // scalar input that provides the padding value. Therefore pad_value_ptr can be
5236 // equivalent to a simple input1_data. For Pad, it should point to a zero
5237 // value.
5238 //
5239 // Note that two typenames are required, so that T=P=int32 is considered a
5240 // specialization distinct from P=int32.
5241 template <typename T, typename P>
5242 inline void PadImageStyleMemset(const tflite::PadParams& op_params,
5243 const RuntimeShape& input_shape,
5244 const T* input_data, const P* pad_value_ptr,
5245 const RuntimeShape& output_shape,
5246 T* output_data) {
5247 ruy::profiler::ScopeLabel label("PadImageStyle");
5248 const RuntimeShape ext_input_shape =
5249 RuntimeShape::ExtendedShape(4, input_shape);
5250 const RuntimeShape ext_output_shape =
5251 RuntimeShape::ExtendedShape(4, output_shape);
5252 TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
5253 TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
5254
5255 // Pad kernels are limited to max 4 dimensions. Copy inputs so we can pad them
5256 // to 4 dims (yes, we are "padding the padding").
5257 std::vector<int> left_padding_copy(4, 0);
5258 const int left_padding_extend = 4 - op_params.left_padding_count;
5259 for (int i = 0; i < op_params.left_padding_count; ++i) {
5260 left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
5261 }
5262 std::vector<int> right_padding_copy(4, 0);
5263 const int right_padding_extend = 4 - op_params.right_padding_count;
5264 for (int i = 0; i < op_params.right_padding_count; ++i) {
5265 right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
5266 }
5267 // The following padding restrictions are contractual requirements, and
5268 // embody what it means for a padding op to be "image-style".
5269 TFLITE_DCHECK_EQ(left_padding_copy[0], 0);
5270 TFLITE_DCHECK_EQ(left_padding_copy[3], 0);
5271 TFLITE_DCHECK_EQ(right_padding_copy[0], 0);
5272 TFLITE_DCHECK_EQ(right_padding_copy[3], 0);
5273
5274 const int batch = MatchingDim(ext_input_shape, 0, ext_output_shape, 0);
5275 const int output_height = ext_output_shape.Dims(1);
5276 const int output_width = ext_output_shape.Dims(2);
5277 const int input_height = ext_input_shape.Dims(1);
5278 const int input_width = ext_input_shape.Dims(2);
5279 const int depth = MatchingDim(ext_input_shape, 3, ext_output_shape, 3);
5280
5281 const int left_h_padding = left_padding_copy[1];
5282 const int left_w_padding = left_padding_copy[2];
5283 const int right_h_padding = right_padding_copy[1];
5284 const int right_w_padding = right_padding_copy[2];
5285
5286 TFLITE_DCHECK_EQ(output_height,
5287 input_height + left_h_padding + right_h_padding);
5288 TFLITE_DCHECK_EQ(output_width,
5289 input_width + left_w_padding + right_w_padding);
5290
5291 const T pad_value = *pad_value_ptr;
5292 const int top_block_size = left_h_padding * output_width * depth;
5293 const size_t num_top_block_bytes = top_block_size * sizeof(T);
5294 const int bottom_block_size = right_h_padding * output_width * depth;
5295 const size_t num_bottom_block_bytes = bottom_block_size * sizeof(T);
5296 const int left_blocks_size = left_w_padding * depth;
5297 const size_t num_left_block_bytes = left_blocks_size * sizeof(T);
5298 const int right_blocks_size = right_w_padding * depth;
5299 const size_t num_right_block_bytes = right_blocks_size * sizeof(T);
5300 const int inner_line_size = input_width * depth;
5301 const size_t num_inner_line_bytes = inner_line_size * sizeof(T);
5302
5303 if (input_height == 0) {
5304 memset(output_data, pad_value,
5305 num_top_block_bytes + num_bottom_block_bytes);
5306 } else {
5307 for (int i = 0; i < batch; ++i) {
5308 // For each image in the batch, apply the top padding, then iterate
5309 // through rows, then apply the bottom padding.
5310 //
5311 // By unwinding one iteration, we can combine the first left-margin
5312 // padding with the top padding, and the last right-margin padding with
5313 // the bottom padding.
5314 memset(output_data, pad_value,
5315 num_top_block_bytes + num_left_block_bytes);
5316 output_data += top_block_size + left_blocks_size;
5317 memcpy(output_data, input_data, num_inner_line_bytes);
5318 input_data += inner_line_size;
5319 output_data += inner_line_size;
5320 // One iteration unwound.
5321 // Unwinding this loop affords the opportunity to reorder the loop work
5322 // and hence combine memset() calls.
5323 //
5324 // Before unwinding:
5325 // for (int j = 0; j < input_height; ++j) {
5326 // // Pad on left, copy central data, pad on right.
5327 // memset(output_data, pad_value, num_left_block_bytes);
5328 // output_data += left_blocks_size;
5329 // memcpy(output_data, input_data, num_inner_line_bytes);
5330 // input_data += inner_line_size;
5331 // output_data += inner_line_size;
5332 // memset(output_data, pad_value, num_right_block_bytes);
5333 // output_data += right_blocks_size;
5334 // }
5335 for (int j = 1; j < input_height; ++j) {
5336 memset(output_data, pad_value,
5337 num_right_block_bytes + num_left_block_bytes);
5338 output_data += right_blocks_size + left_blocks_size;
5339 memcpy(output_data, input_data, num_inner_line_bytes);
5340 input_data += inner_line_size;
5341 output_data += inner_line_size;
5342 }
5343 memset(output_data, pad_value,
5344 num_right_block_bytes + num_bottom_block_bytes);
5345 output_data += right_blocks_size + bottom_block_size;
5346 }
5347 }
5348 }
5349
5350 template <typename T, typename P>
5351 inline void PadImageStyle(const tflite::PadParams& op_params,
5352 const RuntimeShape& input_shape, const T* input_data,
5353 const P* pad_value_ptr,
5354 const RuntimeShape& output_shape, T* output_data) {
5355 reference_ops::PadImageStyle(op_params, input_shape, input_data,
5356 pad_value_ptr, output_shape, output_data);
5357 }
5358
5359 template <typename P>
5360 inline void PadImageStyle(const tflite::PadParams& op_params,
5361 const RuntimeShape& input_shape,
5362 const uint8* input_data, const P* pad_value_ptr,
5363 const RuntimeShape& output_shape,
5364 uint8* output_data) {
5365 PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
5366 output_shape, output_data);
5367 }
5368
5369 template <typename P>
5370 inline void PadImageStyle(const tflite::PadParams& op_params,
5371 const RuntimeShape& input_shape,
5372 const float* input_data, const P* pad_value_ptr,
5373 const RuntimeShape& output_shape,
5374 float* output_data) {
5375 const float converted_pad_value = static_cast<float>(*pad_value_ptr);
5376 if (converted_pad_value == 0.0f) {
5377 PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
5378 output_shape, output_data);
5379 } else {
5380 PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
5381 output_data);
5382 }
5383 }
5384
5385 template <typename T>
5386 inline void Slice(const tflite::SliceParams& op_params,
5387 const RuntimeShape& input_shape,
5388 const RuntimeShape& output_shape,
5389 SequentialTensorWriter<T>* writer) {
5390 ruy::profiler::ScopeLabel label("Slice");
5391 const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
5392 TFLITE_DCHECK_LE(op_params.begin_count, 5);
5393 TFLITE_DCHECK_LE(op_params.size_count, 5);
5394 const int begin_count = op_params.begin_count;
5395 const int size_count = op_params.size_count;
5396 // We front-pad the begin and size vectors.
5397 std::array<int, 5> start;
5398 std::array<int, 5> stop;
5399 for (int i = 0; i < 5; ++i) {
5400 int padded_i = 5 - i;
5401 start[i] =
5402 begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
5403 stop[i] =
5404 (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
5405 ? ext_shape.Dims(i)
5406 : start[i] + op_params.size[size_count - padded_i];
5407 }
5408
5409 for (int i0 = start[0]; i0 < stop[0]; ++i0) {
5410 for (int i1 = start[1]; i1 < stop[1]; ++i1) {
5411 for (int i2 = start[2]; i2 < stop[2]; ++i2) {
5412 for (int i3 = start[3]; i3 < stop[3]; ++i3) {
5413 const int len = stop[4] - start[4];
5414 if (len > 0)
5415 writer->WriteN(Offset(ext_shape, i0, i1, i2, i3, start[4]), len);
5416 }
5417 }
5418 }
5419 }
5420 }
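// Example of the front-padding in Slice above (illustrative): for a 3-D slice
// with begin = {1, 0, 2} and size = {2, -1, 3} over an input extended to 5-D,
// the effective ranges become start = {0, 0, 1, 0, 2} and
// stop = {dims[0], dims[1], 3, dims[3], 5}, i.e. the two leading (padded)
// dimensions are traversed in full.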
5421
5422 template <typename T>
5423 inline void Slice(const tflite::SliceParams& op_params,
5424 const RuntimeShape& input_shape, const T* input_data,
5425 const RuntimeShape& output_shape, T* output_data) {
5426 SequentialTensorWriter<T> writer(input_data, output_data);
5427 return Slice(op_params, input_shape, output_shape, &writer);
5428 }
5429
5430 template <typename T>
5431 inline void Slice(const tflite::SliceParams& op_params,
5432 const RuntimeShape& input_shape, const TfLiteTensor* input,
5433 const RuntimeShape& output_shape, TfLiteTensor* output) {
5434 SequentialTensorWriter<T> writer(input, output);
5435 return Slice(op_params, input_shape, output_shape, &writer);
5436 }
5437
5438 // Note: This implementation is only optimized for the case where the inner
5439 // stride == 1.
5440 template <typename T>
5441 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
5442 const RuntimeShape& unextended_input_shape,
5443 const RuntimeShape& unextended_output_shape,
5444 SequentialTensorWriter<T>* writer) {
5445 using strided_slice::LoopCondition;
5446 using strided_slice::StartForAxis;
5447 using strided_slice::StopForAxis;
5448
5449 ruy::profiler::ScopeLabel label("StridedSlice");
5450
5451 // Note that the output_shape is not used herein.
5452 tflite::StridedSliceParams params_copy = op_params;
5453
5454 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
5455 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
5456 const RuntimeShape input_shape =
5457 RuntimeShape::ExtendedShape(5, unextended_input_shape);
5458 const RuntimeShape output_shape =
5459 RuntimeShape::ExtendedShape(5, unextended_output_shape);
5460
5461 // Reverse and pad to 5 dimensions because that is what the runtime code
5462 // requires (i.e., all shapes must be 5D and are given backwards).
5463 strided_slice::StridedSlicePadIndices(&params_copy, 5);
5464
5465 const int start_0 = StartForAxis(params_copy, input_shape, 0);
5466 const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
5467 const int start_1 = StartForAxis(params_copy, input_shape, 1);
5468 const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
5469 const int start_2 = StartForAxis(params_copy, input_shape, 2);
5470 const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
5471 const int start_3 = StartForAxis(params_copy, input_shape, 3);
5472 const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
5473 const int start_4 = StartForAxis(params_copy, input_shape, 4);
5474 const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
5475 const bool inner_stride_is_1 = params_copy.strides[4] == 1;
5476
5477 for (int offset_0 = start_0 * input_shape.Dims(1),
5478 end_0 = stop_0 * input_shape.Dims(1),
5479 step_0 = params_copy.strides[0] * input_shape.Dims(1);
5480 !LoopCondition(offset_0, end_0, params_copy.strides[0]);
5481 offset_0 += step_0) {
5482 for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
5483 end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
5484 step_1 = params_copy.strides[1] * input_shape.Dims(2);
5485 !LoopCondition(offset_1, end_1, params_copy.strides[1]);
5486 offset_1 += step_1) {
5487 for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
5488 end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
5489 step_2 = params_copy.strides[2] * input_shape.Dims(3);
5490 !LoopCondition(offset_2, end_2, params_copy.strides[2]);
5491 offset_2 += step_2) {
5492 for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
5493 end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
5494 step_3 = params_copy.strides[3] * input_shape.Dims(4);
5495 !LoopCondition(offset_3, end_3, params_copy.strides[3]);
5496 offset_3 += step_3) {
5497 // When the stride is 1, the inner loop is equivalent to the
5498 // optimized slice inner loop. Otherwise, it is identical to the
5499 // strided_slice reference implementation inner loop.
5500 if (inner_stride_is_1) {
5501 const int len = stop_4 - start_4;
5502 if (len > 0) {
5503 writer->WriteN(offset_3 + start_4, len);
5504 }
5505 } else {
5506 for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
5507 !LoopCondition(offset_4, end_4, params_copy.strides[4]);
5508 offset_4 += params_copy.strides[4]) {
5509 writer->Write(offset_4);
5510 }
5511 }
5512 }
5513 }
5514 }
5515 }
5516 }
5517
5518 template <typename T>
5519 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
5520 const RuntimeShape& unextended_input_shape,
5521 const T* input_data,
5522 const RuntimeShape& unextended_output_shape,
5523 T* output_data) {
5524 SequentialTensorWriter<T> writer(input_data, output_data);
5525 StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
5526 &writer);
5527 }
5528
5529 template <typename T>
5530 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
5531 const RuntimeShape& unextended_input_shape,
5532 const TfLiteTensor* input,
5533 const RuntimeShape& unextended_output_shape,
5534 TfLiteTensor* output) {
5535 SequentialTensorWriter<T> writer(input, output);
5536 StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
5537 &writer);
5538 }
5539
5540 template <typename T>
5541 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
5542 const T* input2_data, const RuntimeShape& output_shape,
5543 T* output_data) {
5544 ruy::profiler::ScopeLabel label("TensorFlowMinimum");
5545 auto input1_map = MapAsVector(input1_data, input1_shape);
5546 auto output_map = MapAsVector(output_data, output_shape);
5547 auto min_value = input2_data[0];
5548 output_map.array() = input1_map.array().min(min_value);
5549 }
5550
5551 // Convenience version that allows, for example, generated-code calls to be
5552 // the same as other binary ops.
5553 template <typename T>
5554 inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
5555 const RuntimeShape&, const T* input2_data,
5556 const RuntimeShape& output_shape, T* output_data) {
5557 // Drop shape of second input: not needed.
5558 Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
5559 }
5560
5561 template <typename T>
5562 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
5563 const T* input2_data, const RuntimeShape& output_shape,
5564 T* output_data) {
5565 ruy::profiler::ScopeLabel label("TensorFlowMaximum");
5566 auto input1_map = MapAsVector(input1_data, input1_shape);
5567 auto output_map = MapAsVector(output_data, output_shape);
5568 auto max_value = input2_data[0];
5569 output_map.array() = input1_map.array().max(max_value);
5570 }
5571
5572 // Convenience version that allows, for example, generated-code calls to be
5573 // the same as other binary ops.
5574 template <typename T>
5575 inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
5576 const RuntimeShape&, const T* input2_data,
5577 const RuntimeShape& output_shape, T* output_data) {
5578 // Drop shape of second input: not needed.
5579 Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
5580 }
5581
5582 template <typename T>
5583 void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
5584 const RuntimeShape& input_shape, const T* input_data,
5585 const RuntimeShape& filter_shape,
5586 const RuntimeShape& output_shape, T* im2col_data) {
5587 ruy::profiler::ScopeLabel label("TransposeIm2col");
5588 const int stride_width = params.stride_width;
5589 const int stride_height = params.stride_height;
5590 const int pad_width = params.padding_values.width;
5591 const int pad_height = params.padding_values.height;
5592 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5593 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
5594 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
5595 TFLITE_DCHECK(im2col_data);
5596
5597 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
5598 const int input_height = input_shape.Dims(1);
5599 const int input_width = input_shape.Dims(2);
5600 const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
5601 const int filter_height = filter_shape.Dims(1);
5602 const int filter_width = filter_shape.Dims(2);
5603 const int output_height = output_shape.Dims(1);
5604 const int output_width = output_shape.Dims(2);
5605 MatchingDim(output_shape, 3, filter_shape, 0); // output_depth
5606
5607 // Construct the MxN sized im2col matrix.
5608 // The rows, M, are sub-ordered B x H x W
5609 const RuntimeShape row_shape({1, batches, output_height, output_width});
5610 // The columns, N, are sub-ordered Kh x Kw x Din
5611 const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
5612 // Use dimensions M and N to construct dims for indexing directly into im2col
5613 const RuntimeShape im2col_shape(
5614 {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
5615
5616 // Build the im2col matrix by looping through all the input pixels,
5617 // computing their influence on the output, rather than looping through all
5618 // the output pixels. We therefore must initialize the im2col array to zero.
5619 // This is potentially inefficient because we subsequently overwrite bytes
5620 // set here. However, in practice memset is very fast and its cost is negligible.
5621 memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
5622
5623 // Loop through the output batches
5624 for (int batch = 0; batch < batches; ++batch) {
5625 // Loop through input pixels one at a time.
5626 for (int in_y = 0; in_y < input_height; ++in_y) {
5627 for (int in_x = 0; in_x < input_width; ++in_x) {
5628 // Loop through the output pixels it will influence
5629 const int out_x_origin = (in_x * stride_width) - pad_width;
5630 const int out_y_origin = (in_y * stride_height) - pad_height;
5631 for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
5632 const int out_y = out_y_origin + filter_y;
5633 // Is output pixel within height bounds?
5634 if ((out_y >= 0) && (out_y < output_height)) {
5635 for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
5636 const int out_x = out_x_origin + filter_x;
5637 // Is output pixel within width bounds?
5638 if ((out_x >= 0) && (out_x < output_width)) {
5639 // Copy the input elements of this pixel
5640 T const* src =
5641 input_data + Offset(input_shape, batch, in_y, in_x, 0);
5642 int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
5643 int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
5644 T* dst = im2col_data +
5645 Offset(im2col_shape, 0, 0, row_offset, col_offset);
5646 memcpy(dst, src, input_depth * sizeof(T));
5647 }
5648 }
5649 }
5650 }
5651 }
5652 }
5653 }
5654 }
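// Size example for TransposeIm2col above (hypothetical): for batches = 1,
// output = 5x5, filter = 3x3 and input_depth = 8, the im2col matrix has
// M = 1 * 5 * 5 = 25 rows and N = 3 * 3 * 8 = 72 columns, and each memcpy
// writes one 8-element depth slice into a single (row, column-block) cell.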
5655
5656 // Returns in 'im_data' (assumed to be zero-initialized) the image patch in storage
5657 // order (height, width, depth), constructed from patches in 'col_data', which
5658 // is required to be in storage order (out_height * out_width, filter_height,
5659 // filter_width, in_depth). Implementation by Yangqing Jia (jiayq).
5660 // Copied from //tensorflow/core/kernels/conv_grad_input_ops.cc
5661 template <typename T>
5662 void Col2im(const T* col_data, const int depth, const int height,
5663 const int width, const int filter_h, const int filter_w,
5664 const int pad_t, const int pad_l, const int pad_b, const int pad_r,
5665 const int stride_h, const int stride_w, T* im_data) {
5666 ruy::profiler::ScopeLabel label("Col2im");
5667 int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
5668 int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
5669 int h_pad = -pad_t;
5670 for (int h = 0; h < height_col; ++h) {
5671 int w_pad = -pad_l;
5672 for (int w = 0; w < width_col; ++w) {
5673 T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
5674 for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
5675 for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
5676 if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
5677 // TODO(andydavis) Vectorize this loop (if compiler does not).
5678 for (int i = 0; i < depth; ++i) {
5679 im_patch_data[i] += col_data[i];
5680 }
5681 }
5682 im_patch_data += depth;
5683 col_data += depth;
5684 }
5685 // Jump over the remaining depth values of this row.
5686 im_patch_data += depth * (width - filter_w);
5687 }
5688 w_pad += stride_w;
5689 }
5690 h_pad += stride_h;
5691 }
5692 }
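// Patch-count example for Col2im above (illustrative): with height = 5,
// pad_t = pad_b = 1, filter_h = 3 and stride_h = 2,
// height_col = (5 + 2 - 3) / 2 + 1 = 3, so three overlapping patches per
// column are accumulated back into 'im_data'.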
5693
5694 template <typename T>
5695 void BiasAdd(T* im_data, const T* bias_data, const int batch_size,
5696 const int height, const int width, const int depth) {
5697 if (bias_data) {
5698 for (int n = 0; n < batch_size; ++n) {
5699 for (int h = 0; h < height; ++h) {
5700 for (int w = 0; w < width; ++w) {
5701 for (int d = 0; d < depth; ++d) {
5702 im_data[d] += bias_data[d];
5703 }
5704 im_data += depth;
5705 }
5706 }
5707 }
5708 }
5709 }
5710
5711 // TransposeConvV2 expects the weights in HWOI order.
5712 inline void TransposeConvV2(
5713 const ConvParams& params, const RuntimeShape& input_shape,
5714 const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
5715 const float* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
5716 const float* bias_data, const RuntimeShape& output_shape,
5717 float* const output_data, const RuntimeShape& col2im_shape,
5718 float* col2im_data, CpuBackendContext* cpu_backend_context) {
5719 ruy::profiler::ScopeLabel label("TransposeConvV2/float");
5720 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5721 TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
5722 TFLITE_DCHECK(col2im_data);
5723 TFLITE_DCHECK(hwoi_ordered_filter_data);
5724
5725 const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
5726 const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
5727 const int output_height = output_shape.Dims(1);
5728 const int output_width = output_shape.Dims(2);
5729 const int output_image_size = output_height * output_width;
5730 const int input_depth =
5731 MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
5732 const int output_depth =
5733 MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
5734 const int input_offset = input_image_size * input_depth;
5735 const int output_offset = output_image_size * output_depth;
5736
5737 const int filter_height = hwoi_ordered_filter_shape.Dims(0);
5738 const int filter_width = hwoi_ordered_filter_shape.Dims(1);
5739 const int padding_top = params.padding_values.height;
5740 const int padding_bottom =
5741 params.padding_values.height + params.padding_values.height_offset;
5742 const int padding_left = params.padding_values.width;
5743 const int padding_right =
5744 params.padding_values.width + params.padding_values.width_offset;
5745 const int stride_height = params.stride_height;
5746 const int stride_width = params.stride_width;
5747
5748 const int hwoi_ordered_filter_total_size =
5749 filter_height * filter_width * output_depth;
5750
5751 cpu_backend_gemm::MatrixParams<float> lhs_params;
5752 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
5753 lhs_params.rows = hwoi_ordered_filter_total_size;
5754 lhs_params.cols = input_depth;
5755 float* output_data_p = output_data;
5756 std::fill_n(output_data, output_offset * batch_size, 0.0f);
5757 for (int i = 0; i < batch_size; ++i) {
5758 cpu_backend_gemm::MatrixParams<float> rhs_params;
5759 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
5760 rhs_params.rows = input_depth;
5761 rhs_params.cols = input_image_size;
5762 cpu_backend_gemm::MatrixParams<float> dst_params;
5763 dst_params.order = cpu_backend_gemm::Order::kColMajor;
5764 dst_params.rows = hwoi_ordered_filter_total_size;
5765 dst_params.cols = input_image_size;
5766 cpu_backend_gemm::GemmParams<float, float> gemm_params;
5767 cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
5768 input_data + input_offset * i, dst_params,
5769 col2im_data, gemm_params, cpu_backend_context);
5770
5771 Col2im(col2im_data, output_depth, output_height, output_width,
5772 filter_height, filter_width, padding_top, padding_left,
5773 padding_bottom, padding_right, stride_height, stride_width,
5774 output_data_p);
5775 output_data_p += output_offset;
5776 }
5777 output_data_p = output_data;
5778 BiasAdd(output_data_p, bias_data, batch_size, output_height, output_width,
5779 output_depth);
5780 }
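// GEMM shape sketch for TransposeConvV2 above (dimensions only): per batch,
// lhs is (filter_height * filter_width * output_depth) x input_depth and
// rhs is input_depth x (input_h * input_w), so col2im_data receives a
// (filter_height * filter_width * output_depth) x (input_h * input_w) matrix
// that Col2im then scatters and accumulates into the output image.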
5781
5782 inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size,
5783 int32_t output_zp, int32_t* scratch, uint8_t* output) {
5784 ruy::profiler::ScopeLabel label("Quantize/uint8");
5785 int i = 0;
5786 const int32_t output_min = std::numeric_limits<uint8_t>::min();
5787 const int32_t output_max = std::numeric_limits<uint8_t>::max();
5788
5789 #ifdef USE_NEON
5790 const int32x4_t output_zp_dup = vdupq_n_s32(output_zp);
5791 const int32x4_t max_val_dup = vdupq_n_s32(output_max);
5792 const int32x4_t min_val_dup = vdupq_n_s32(output_min);
5793
5794 using gemmlowp::RoundingDivideByPOT;
5795 using gemmlowp::SaturatingRoundingDoublingHighMul;
5796
5797 for (; i <= total_size - 16; i += 16) {
5798 int32x4x4_t scratch_val;
5799 scratch_val.val[0] = vld1q_s32(scratch + i);
5800 scratch_val.val[1] = vld1q_s32(scratch + i + 4);
5801 scratch_val.val[2] = vld1q_s32(scratch + i + 8);
5802 scratch_val.val[3] = vld1q_s32(scratch + i + 12);
5803
5804 int32x4x4_t temp_val =
5805 MultiplyByQuantizedMultiplier4Rows(scratch_val, multiplier, shift);
5806
5807 temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
5808 temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
5809 temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
5810 temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
5811
5812 temp_val.val[0] =
5813 vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
5814 temp_val.val[1] =
5815 vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
5816 temp_val.val[2] =
5817 vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
5818 temp_val.val[3] =
5819 vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
5820
5821 const uint16x8_t result_1 =
5822 vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[0])),
5823 vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[1])));
5824 const uint16x8_t result_2 =
5825 vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[2])),
5826 vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[3])));
5827 const uint8x16_t result =
5828 vcombine_u8(vqmovn_u16(result_1), vqmovn_u16(result_2));
5829 vst1q_u8(output + i, result);
5830 }
5831 #endif
5832 for (; i < total_size; ++i) {
5833 int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
5834 temp += output_zp;
5835 if (temp > output_max) {
5836 temp = output_max;
5837 }
5838 if (temp < output_min) {
5839 temp = output_min;
5840 }
5841 output[i] = static_cast<uint8_t>(temp);
5842 }
5843 }
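// Requantization example for the uint8 Quantize above (hypothetical values):
// with multiplier = 1 << 30 (i.e. 0.5 in Q31) and shift = -1, the effective
// scale is 0.5 * 2^-1 = 0.25, so an accumulator of 1000 becomes
// round(1000 * 0.25) + output_zp = 250 + output_zp before clamping to
// [0, 255].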
5844
5845 inline void Quantize(const int32_t* multiplier, const int32_t* shift,
5846 int32_t channel_size, int32_t total_size,
5847 int32_t output_zp, int32_t output_min, int32_t output_max,
5848 int32_t* scratch, int8_t* output) {
5849 ruy::profiler::ScopeLabel label("Quantize/int8");
5850
5851 // Here we're trying to quantize the raw accumulators:
5852 // output_channels
5853 // data data data data data
5854 // rows data data data data data
5855 // data data data data data
5856 // ....
5857 //
5858 // In order to minimize the reload of the multipliers & shifts, once we load
5859 // the multipliers & shifts, we load & quantize the raw accumulators for every
5860 // row.
5861 #ifdef USE_NEON
5862 const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
5863 const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
5864 const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
5865 const int32x4_t zeros = vdupq_n_s32(0);
5866 #endif
5867
5868 TFLITE_DCHECK_EQ(total_size % channel_size, 0);
5869 const int32_t rows = total_size / channel_size;
5870
5871 int c = 0;
5872
5873 #ifdef USE_NEON
5874 using gemmlowp::RoundingDivideByPOT;
5875 for (; c <= channel_size - 8; c += 8) {
5876 int32x4_t out_shift_1 = vld1q_s32(shift + c);
5877 int32x4_t out_shift_2 = vld1q_s32(shift + c + 4);
5878 int32x4_t left_shift_1 = vmaxq_s32(out_shift_1, zeros);
5879 int32x4_t left_shift_2 = vmaxq_s32(out_shift_2, zeros);
5880
5881 // Right shift will be performed as left shift with negative values.
5882 int32x4_t right_shift_1 = vminq_s32(out_shift_1, zeros);
5883 int32x4_t right_shift_2 = vminq_s32(out_shift_2, zeros);
5884
5885 int32x4_t out_mul_1 = vld1q_s32(multiplier + c);
5886 int32x4_t out_mul_2 = vld1q_s32(multiplier + c + 4);
5887 for (int n = 0; n < rows; ++n) {
5888 int loc = n * channel_size + c;
5889 int32x4_t acc_1 = vld1q_s32(scratch + loc);
5890 int32x4_t acc_2 = vld1q_s32(scratch + loc + 4);
5891
5892 // Saturating Rounding Doubling High Mul.
5893 acc_1 = vshlq_s32(acc_1, left_shift_1);
5894 acc_1 = vqrdmulhq_s32(acc_1, out_mul_1);
5895 acc_2 = vshlq_s32(acc_2, left_shift_2);
5896 acc_2 = vqrdmulhq_s32(acc_2, out_mul_2);
5897
5898 // Rounding Dividing By POT.
5899 acc_1 = vrshlq_s32(acc_1, right_shift_1);
5900 acc_2 = vrshlq_s32(acc_2, right_shift_2);
5901
5902 // Add the output offset.
5903 acc_1 = vaddq_s32(acc_1, output_offset_vec);
5904 acc_2 = vaddq_s32(acc_2, output_offset_vec);
5905
5906 // Apply the activation function.
5907 acc_1 = vmaxq_s32(acc_1, output_activation_min_vec);
5908 acc_1 = vminq_s32(acc_1, output_activation_max_vec);
5909 acc_2 = vmaxq_s32(acc_2, output_activation_min_vec);
5910 acc_2 = vminq_s32(acc_2, output_activation_max_vec);
5911
5912 // Saturating cast to int8 and store to destination.
5913 const int16x4_t acc_s16_1 = vqmovn_s32(acc_1);
5914 const int16x4_t acc_s16_2 = vqmovn_s32(acc_2);
5915 const int16x8_t res_s16 = vcombine_s16(acc_s16_1, acc_s16_2);
5916 const int8x8_t res_s8 = vqmovn_s16(res_s16);
5917 vst1_s8(output + loc, res_s8);
5918 }
5919 }
5920
5921 #endif // USE_NEON
5922 // Handle leftover values, one by one. This is very slow.
5923 for (; c < channel_size; c++) {
5924 for (int n = 0; n < rows; ++n) {
5925 int loc = n * channel_size + c;
5926 int32 acc = scratch[loc];
5927 acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
5928 acc += output_zp;
5929 acc = std::max(acc, output_min);
5930 acc = std::min(acc, output_max);
5931 output[loc] = static_cast<int8>(acc);
5932 }
5933 }
5934 }
5935
5936 // TransposeConvV2 expects the weights in HWOI order.
5937 inline void TransposeConvV2(
5938 const ConvParams& params, const RuntimeShape& input_shape,
5939 const uint8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
5940 const uint8_t* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
5941 const int32* bias_data, const RuntimeShape& output_shape,
5942 uint8_t* output_data, const RuntimeShape& col2im_shape,
5943 int32_t* col2im_data, int32_t* scratch_data,
5944 CpuBackendContext* cpu_backend_context) {
5945 ruy::profiler::ScopeLabel label("TransposeConvV2/uint8");
5946 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5947 TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
5948 TFLITE_DCHECK(col2im_data);
5949 TFLITE_DCHECK(hwoi_ordered_filter_data);
5950
5951 const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
5952 const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
5953 const int output_height = output_shape.Dims(1);
5954 const int output_width = output_shape.Dims(2);
5955 const int output_image_size = output_height * output_width;
5956 const int input_depth =
5957 MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
5958 const int output_depth =
5959 MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
5960 const int input_offset = input_image_size * input_depth;
5961 const int output_offset = output_image_size * output_depth;
5962
5963 const int filter_height = hwoi_ordered_filter_shape.Dims(0);
5964 const int filter_width = hwoi_ordered_filter_shape.Dims(1);
5965 const int padding_top = params.padding_values.height;
5966 const int padding_bottom =
5967 params.padding_values.height + params.padding_values.height_offset;
5968 const int padding_left = params.padding_values.width;
5969 const int padding_right =
5970 params.padding_values.width + params.padding_values.width_offset;
5971 const int stride_height = params.stride_height;
5972 const int stride_width = params.stride_width;
5973
5974 const int hwoi_ordered_filter_total_size =
5975 filter_height * filter_width * output_depth;
5976
5977 cpu_backend_gemm::MatrixParams<uint8_t> lhs_params;
5978 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
5979 lhs_params.rows = hwoi_ordered_filter_total_size;
5980 lhs_params.cols = input_depth;
5981 lhs_params.zero_point = -params.weights_offset;
5982
5983 int32_t* scratch_data_p = scratch_data;
5984 std::fill_n(scratch_data, output_offset * batch_size, static_cast<int32>(0));
5985 for (int i = 0; i < batch_size; ++i) {
5986 cpu_backend_gemm::MatrixParams<uint8_t> rhs_params;
5987 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
5988 rhs_params.rows = input_depth;
5989 rhs_params.cols = input_image_size;
5990 rhs_params.zero_point = -params.input_offset;
5991
5992 cpu_backend_gemm::MatrixParams<int32_t> dst_params;
5993 dst_params.order = cpu_backend_gemm::Order::kColMajor;
5994 dst_params.rows = hwoi_ordered_filter_total_size;
5995 dst_params.cols = input_image_size;
5996
5997 cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
5998 cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
5999 input_data + input_offset * i, dst_params,
6000 col2im_data, gemm_params, cpu_backend_context);
6001
6002 Col2im(col2im_data, output_depth, output_height, output_width,
6003 filter_height, filter_width, padding_top, padding_left,
6004 padding_bottom, padding_right, stride_height, stride_width,
6005 scratch_data_p);
6006
6007 scratch_data_p += output_offset;
6008 }
6009 scratch_data_p = scratch_data;
6010 BiasAdd(scratch_data_p, bias_data, batch_size, output_height, output_width,
6011 output_depth);
6012
6013 Quantize(params.output_multiplier, params.output_shift,
6014 output_shape.FlatSize(), params.output_offset, scratch_data,
6015 output_data);
6016 }
6017
6018 // Integer-only version of ResizeNearestNeighbor. Since scales are represented
6019 // in fixed-point and thus approximated, |in_x| or |in_y| may differ from the
6020 // reference version. Debug checks are in place to test if this occurs.
6021 // NOTE: If align_corners or half_pixel_centers is true, we use the reference
6022 // version.
6023 inline void ResizeNearestNeighbor(
6024 const tflite::ResizeNearestNeighborParams& op_params,
6025 const RuntimeShape& unextended_input_shape, const uint8* input_data,
6026 const RuntimeShape& output_size_shape, const int32* output_size_data,
6027 const RuntimeShape& unextended_output_shape, uint8* output_data) {
6028 if (op_params.align_corners || op_params.half_pixel_centers) {
6029 // TODO(b/149823713): Add support for align_corners & half_pixel_centers in
6030 // this kernel.
6031 reference_ops::ResizeNearestNeighbor(
6032 op_params, unextended_input_shape, input_data, output_size_shape,
6033 output_size_data, unextended_output_shape, output_data);
6034 return;
6035 }
6036 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
6037 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
6038
6039 const RuntimeShape input_shape =
6040 RuntimeShape::ExtendedShape(4, unextended_input_shape);
6041 const RuntimeShape output_shape =
6042 RuntimeShape::ExtendedShape(4, unextended_output_shape);
6043
6044 int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
6045 int32 input_height = input_shape.Dims(1);
6046 int32 input_width = input_shape.Dims(2);
6047 int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
6048
6049   // The TensorFlow version of this op allows resizing along the width and
6050   // height axes only.
6051 TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
6052 int32 output_height = output_size_data[0];
6053 int32 output_width = output_size_data[1];
6054
6055   // Convert scales to fixed-point with 16 fractional bits. We add 1 as an
6056   // error factor and to avoid zero scales: it compensates for the truncation
6057   // of the fixed-point division, which could otherwise make the computed
6058   // source index fall one below the reference floor() value (see the DCHECKs).
6059 int32 height_scale = (input_height << 16) / output_height + 1;
6060 int32 width_scale = (input_width << 16) / output_width + 1;
6061
6062 const int col_offset = input_shape.Dims(3);
6063 const int row_offset = input_shape.Dims(2) * col_offset;
6064 const int batch_offset = input_shape.Dims(1) * row_offset;
6065
6066 const uint8* input_ptr = input_data;
6067 uint8* output_ptr = output_data;
6068 for (int b = 0; b < batches; ++b) {
6069 for (int y = 0; y < output_height; ++y) {
6070 int32 in_y = std::min((y * height_scale) >> 16, input_height - 1);
6071 // Check offset calculation is the same as the reference version. See
6072 // function comment for details. We check using a non-float version of:
6073 // TFLITE_DCHECK_EQ(in_y, std::floor(y * (static_cast<float>(input_height)
6074 // / output_height)));
6075 TFLITE_DCHECK_LT(y * input_height, output_height + in_y * output_height);
6076 TFLITE_DCHECK_GE(y * input_height, in_y * output_height);
6077 const uint8* y_input_ptr = input_ptr + in_y * row_offset;
6078 for (int x = 0; x < output_width; ++x) {
6079 int32 in_x = std::min((x * width_scale) >> 16, input_width - 1);
6080 // Check offset calculation is the same as the reference version. See
6081 // function comment for details. We check using a non-float version of:
6082         // TFLITE_DCHECK_EQ(in_x,
6083         //                  std::floor(x * (static_cast<float>(input_width)
6084         //                                      / output_width)));
6085 TFLITE_DCHECK_LT(x * input_width, output_width + in_x * output_width);
6086 TFLITE_DCHECK_GE(x * input_width, in_x * output_width);
6087 const uint8* x_input_ptr = y_input_ptr + in_x * col_offset;
6088 memcpy(output_ptr, x_input_ptr, depth);
6089 output_ptr += depth;
6090 }
6091 }
6092 input_ptr += batch_offset;
6093 }
6094 }
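// Worked example of the fixed-point scale above, with arbitrary sizes: for
// input_height = 3 and output_height = 9, the truncated scale
// (3 << 16) / 9 = 21845 would give (6 * 21845) >> 16 = 1 for output row y = 6,
// whereas the reference computes floor(6 * 3 / 9) = 2. With the +1,
// height_scale = 21846 and (6 * 21846) >> 16 = 2, matching the reference and
// keeping the DCHECKs above satisfied.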
6095
6096 template <typename input_type, typename output_type>
6097 inline void Requantize(const input_type* input_data, int32_t size,
6098 int32_t effective_scale_multiplier,
6099 int32_t effective_scale_shift, int32_t input_zeropoint,
6100 int32_t output_zeropoint, output_type* output_data) {
6101 reference_ops::Requantize(input_data, size, effective_scale_multiplier,
6102 effective_scale_shift, input_zeropoint,
6103 output_zeropoint, output_data);
6104 }
6105
6106 template <>
6107 inline void Requantize<int8_t, uint8_t>(const int8_t* input_data, int32_t size,
6108 int32_t effective_scale_multiplier,
6109 int32_t effective_scale_shift,
6110 int32_t input_zeropoint,
6111 int32_t output_zeropoint,
6112 uint8_t* output_data) {
6113 ruy::profiler::ScopeLabel label("Requantize/Int8ToUint8");
6114
6115 static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
6116 static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
6117
6118 int i = 0;
6119 #ifdef USE_NEON
6120 // Constants.
6121 const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
6122 const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
6123 const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
6124 const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
6125
6126 for (; i <= size - 16; i += 16) {
6127 const int8x16_t input_vec = vld1q_s8(input_data + i);
6128 const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
6129 const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
6130 int32x4x4_t input;
6131 input.val[0] = vmovl_s16(vget_low_s16(first_half));
6132 input.val[1] = vmovl_s16(vget_high_s16(first_half));
6133 input.val[2] = vmovl_s16(vget_low_s16(second_half));
6134 input.val[3] = vmovl_s16(vget_high_s16(second_half));
6135 input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
6136 input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
6137 input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
6138 input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
6139
6140 int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
6141 input, effective_scale_multiplier, effective_scale_shift);
6142
6143 result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
6144 result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
6145 result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
6146 result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
6147 result.val[0] =
6148 vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
6149 result.val[1] =
6150 vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
6151 result.val[2] =
6152 vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
6153 result.val[3] =
6154 vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
6155
6156 const uint32x4_t result_val_1_unsigned =
6157 vreinterpretq_u32_s32(result.val[0]);
6158 const uint32x4_t result_val_2_unsigned =
6159 vreinterpretq_u32_s32(result.val[1]);
6160 const uint32x4_t result_val_3_unsigned =
6161 vreinterpretq_u32_s32(result.val[2]);
6162 const uint32x4_t result_val_4_unsigned =
6163 vreinterpretq_u32_s32(result.val[3]);
6164
6165 const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
6166 const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
6167 const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
6168 const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
6169 const uint16x8_t output_first_half =
6170 vcombine_u16(narrowed_val_1, narrowed_val_2);
6171 const uint16x8_t output_second_half =
6172 vcombine_u16(narrowed_val_3, narrowed_val_4);
6173 const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
6174 const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
6175 const uint8x16_t narrowed_result =
6176 vcombine_u8(narrowed_first_half, narrowed_second_half);
6177 vst1q_u8(output_data + i, narrowed_result);
6178 }
6179
6180 #endif
6181 for (; i < size; ++i) {
6182 const int32_t input = input_data[i] - input_zeropoint;
6183 const int32_t output =
6184 MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
6185 effective_scale_shift) +
6186 output_zeropoint;
6187 const int32_t clamped_output =
6188 std::max(std::min(output, kMaxOutput), kMinOutput);
6189 output_data[i] = static_cast<uint8_t>(clamped_output);
6190 }
6191 }
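// Usage sketch for the int8 -> uint8 specialization above; the quantization
// parameters are made up for illustration. A multiplier of 1 << 30 with a
// shift of 1 represents an effective scale of exactly 1.0, so this call only
// re-centers the data from zero point 0 to zero point 128:
//
//   const int8_t input[4] = {-128, -1, 0, 127};
//   uint8_t output[4];
//   Requantize<int8_t, uint8_t>(input, /*size=*/4,
//                               /*effective_scale_multiplier=*/1 << 30,
//                               /*effective_scale_shift=*/1,
//                               /*input_zeropoint=*/0,
//                               /*output_zeropoint=*/128, output);
//   // output now holds {0, 127, 128, 255}.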
6192
6193 template <>
6194 inline void Requantize<uint8_t, int8_t>(const uint8_t* input_data, int32_t size,
6195 int32_t effective_scale_multiplier,
6196 int32_t effective_scale_shift,
6197 int32_t input_zeropoint,
6198 int32_t output_zeropoint,
6199 int8_t* output_data) {
6200 ruy::profiler::ScopeLabel label("Requantize/Uint8ToInt8");
6201
6202 static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
6203 static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
6204
6205 int i = 0;
6206 #ifdef USE_NEON
6207 // Constants.
6208 const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
6209 const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
6210 const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
6211 const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
6212
6213 for (; i <= size - 16; i += 16) {
6214 const uint8x16_t input_vec = vld1q_u8(input_data + i);
6215 const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
6216 const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
6217 int32x4x4_t input;
6218 input.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
6219 input.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
6220 input.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
6221 input.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
6222 input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
6223 input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
6224 input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
6225 input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
6226
6227 int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
6228 input, effective_scale_multiplier, effective_scale_shift);
6229
6230 result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
6231 result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
6232 result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
6233 result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
6234 result.val[0] =
6235 vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
6236 result.val[1] =
6237 vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
6238 result.val[2] =
6239 vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
6240 result.val[3] =
6241 vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
6242
6243 const int16x4_t narrowed_val_1 = vqmovn_s32(result.val[0]);
6244 const int16x4_t narrowed_val_2 = vqmovn_s32(result.val[1]);
6245 const int16x4_t narrowed_val_3 = vqmovn_s32(result.val[2]);
6246 const int16x4_t narrowed_val_4 = vqmovn_s32(result.val[3]);
6247 const int16x8_t output_first_half =
6248 vcombine_s16(narrowed_val_1, narrowed_val_2);
6249 const int16x8_t output_second_half =
6250 vcombine_s16(narrowed_val_3, narrowed_val_4);
6251 const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
6252 const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
6253 const int8x16_t narrowed_result =
6254 vcombine_s8(narrowed_first_half, narrowed_second_half);
6255 vst1q_s8(output_data + i, narrowed_result);
6256 }
6257
6258 #endif
6259 for (; i < size; ++i) {
6260 const int32_t input = input_data[i] - input_zeropoint;
6261 const int32_t output =
6262 MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
6263 effective_scale_shift) +
6264 output_zeropoint;
6265 const int32_t clamped_output =
6266 std::max(std::min(output, kMaxOutput), kMinOutput);
6267 output_data[i] = static_cast<int8_t>(clamped_output);
6268 }
6269 }
6270
6271 template <>
6272 inline void Requantize<int8_t, int8_t>(const int8_t* input_data, int32_t size,
6273 int32_t effective_scale_multiplier,
6274 int32_t effective_scale_shift,
6275 int32_t input_zeropoint,
6276 int32_t output_zeropoint,
6277 int8_t* output_data) {
6278 ruy::profiler::ScopeLabel label("Requantize/Int8ToInt8");
6279
6280 static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
6281 static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
6282
6283 int i = 0;
6284 #ifdef USE_NEON
6285 // Constants.
6286 const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
6287 const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
6288 const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
6289 const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
6290
6291 for (; i <= size - 16; i += 16) {
6292 const int8x16_t input_vec = vld1q_s8(input_data + i);
6293 const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
6294 const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
6295 int32x4x4_t input;
6296 input.val[0] = vmovl_s16(vget_low_s16(first_half));
6297 input.val[1] = vmovl_s16(vget_high_s16(first_half));
6298 input.val[2] = vmovl_s16(vget_low_s16(second_half));
6299 input.val[3] = vmovl_s16(vget_high_s16(second_half));
6300
6301 input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
6302 input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
6303 input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
6304 input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
6305
6306 int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
6307 input, effective_scale_multiplier, effective_scale_shift);
6308
6309 result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
6310 result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
6311 result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
6312 result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
6313 result.val[0] =
6314 vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
6315 result.val[1] =
6316 vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
6317 result.val[2] =
6318 vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
6319 result.val[3] =
6320 vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
6321
6322 const int16x4_t narrowed_val_1 = vqmovn_s32(result.val[0]);
6323 const int16x4_t narrowed_val_2 = vqmovn_s32(result.val[1]);
6324 const int16x4_t narrowed_val_3 = vqmovn_s32(result.val[2]);
6325 const int16x4_t narrowed_val_4 = vqmovn_s32(result.val[3]);
6326 const int16x8_t output_first_half =
6327 vcombine_s16(narrowed_val_1, narrowed_val_2);
6328 const int16x8_t output_second_half =
6329 vcombine_s16(narrowed_val_3, narrowed_val_4);
6330 const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
6331 const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
6332 const int8x16_t narrowed_result =
6333 vcombine_s8(narrowed_first_half, narrowed_second_half);
6334 vst1q_s8(output_data + i, narrowed_result);
6335 }
6336
6337 #endif
6338 for (; i < size; ++i) {
6339 const int32_t input = input_data[i] - input_zeropoint;
6340 const int32_t output =
6341 MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
6342 effective_scale_shift) +
6343 output_zeropoint;
6344 const int32_t clamped_output =
6345 std::max(std::min(output, kMaxOutput), kMinOutput);
6346 output_data[i] = static_cast<int8_t>(clamped_output);
6347 }
6348 }
6349
6350 template <>
6351 inline void Requantize<uint8_t, uint8_t>(
6352 const uint8_t* input_data, int32_t size, int32_t effective_scale_multiplier,
6353 int32_t effective_scale_shift, int32_t input_zeropoint,
6354 int32_t output_zeropoint, uint8_t* output_data) {
6355 ruy::profiler::ScopeLabel label("Requantize/Uint8ToUint8");
6356
6357 static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
6358 static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
6359
6360 int i = 0;
6361 #ifdef USE_NEON
6362 // Constants.
6363 const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
6364 const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
6365 const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
6366 const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
6367
6368 for (; i <= size - 16; i += 16) {
6369 const uint8x16_t input_vec = vld1q_u8(input_data + i);
6370 const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
6371 const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
6372 int32x4x4_t input;
6373 input.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
6374 input.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
6375 input.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
6376 input.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
6377 input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
6378 input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
6379 input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
6380 input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
6381
6382 int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
6383 input, effective_scale_multiplier, effective_scale_shift);
6384
6385 result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
6386 result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
6387 result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
6388 result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
6389 result.val[0] =
6390 vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
6391 result.val[1] =
6392 vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
6393 result.val[2] =
6394 vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
6395 result.val[3] =
6396 vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
6397
6398 const uint32x4_t result_val_1_unsigned =
6399 vreinterpretq_u32_s32(result.val[0]);
6400 const uint32x4_t result_val_2_unsigned =
6401 vreinterpretq_u32_s32(result.val[1]);
6402 const uint32x4_t result_val_3_unsigned =
6403 vreinterpretq_u32_s32(result.val[2]);
6404 const uint32x4_t result_val_4_unsigned =
6405 vreinterpretq_u32_s32(result.val[3]);
6406
6407 const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
6408 const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
6409 const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
6410 const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
6411 const uint16x8_t output_first_half =
6412 vcombine_u16(narrowed_val_1, narrowed_val_2);
6413 const uint16x8_t output_second_half =
6414 vcombine_u16(narrowed_val_3, narrowed_val_4);
6415 const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
6416 const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
6417 const uint8x16_t narrowed_result =
6418 vcombine_u8(narrowed_first_half, narrowed_second_half);
6419 vst1q_u8(output_data + i, narrowed_result);
6420 }
6421
6422 #endif
6423 for (; i < size; ++i) {
6424 const int32_t input = input_data[i] - input_zeropoint;
6425 const int32_t output =
6426 MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
6427 effective_scale_shift) +
6428 output_zeropoint;
6429 const int32_t clamped_output =
6430 std::max(std::min(output, kMaxOutput), kMinOutput);
6431 output_data[i] = static_cast<uint8_t>(clamped_output);
6432 }
6433 }
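// All four specializations above share the same scalar tail. As a worked
// example with arbitrary parameters: for input_zeropoint = 10,
// output_zeropoint = -5 and an effective scale of 0.5 (multiplier = 1 << 30,
// shift = 0), an input value of 30 becomes (30 - 10) * 0.5 + (-5) = 5 before
// clamping to the output type's range.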
6434
6435 inline void HardSwish(const RuntimeShape& input_shape, const float* input_data,
6436 const RuntimeShape& output_shape, float* output_data) {
6437 ruy::profiler::ScopeLabel label("HardSwish/Float");
6438 auto size = MatchingFlatSize(input_shape, output_shape);
6439 int i = 0;
6440 #ifdef USE_NEON
6441 const float32x4_t zero = vdupq_n_f32(0.0f);
6442 const float32x4_t three = vdupq_n_f32(3.0f);
6443 const float32x4_t six = vdupq_n_f32(6.0f);
6444 const float32x4_t one_sixth = vdupq_n_f32(1.0f / 6.0f);
6445
6446 for (; i <= size - 16; i += 16) {
6447 // 4x partially unrolled version of the loop below. Refer to its comments.
6448 const float32x4_t in_0 = vld1q_f32(input_data + i + 0);
6449 const float32x4_t in_1 = vld1q_f32(input_data + i + 4);
6450 const float32x4_t in_2 = vld1q_f32(input_data + i + 8);
6451 const float32x4_t in_3 = vld1q_f32(input_data + i + 12);
6452 const float32x4_t in_scaled_0 = vmulq_f32(in_0, one_sixth);
6453 const float32x4_t in_scaled_1 = vmulq_f32(in_1, one_sixth);
6454 const float32x4_t in_scaled_2 = vmulq_f32(in_2, one_sixth);
6455 const float32x4_t in_scaled_3 = vmulq_f32(in_3, one_sixth);
6456 const float32x4_t in_reluish_0 =
6457 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_0, three)));
6458 const float32x4_t in_reluish_1 =
6459 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_1, three)));
6460 const float32x4_t in_reluish_2 =
6461 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_2, three)));
6462 const float32x4_t in_reluish_3 =
6463 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_3, three)));
6464 const float32x4_t product_0 = vmulq_f32(in_scaled_0, in_reluish_0);
6465 const float32x4_t product_1 = vmulq_f32(in_scaled_1, in_reluish_1);
6466 const float32x4_t product_2 = vmulq_f32(in_scaled_2, in_reluish_2);
6467 const float32x4_t product_3 = vmulq_f32(in_scaled_3, in_reluish_3);
6468 vst1q_f32(output_data + i + 0, product_0);
6469 vst1q_f32(output_data + i + 4, product_1);
6470 vst1q_f32(output_data + i + 8, product_2);
6471 vst1q_f32(output_data + i + 12, product_3);
6472 }
6473 for (; i <= size - 4; i += 4) {
6474 // The expression to be computed is:
6475 // out = one_sixth * in * min(six, max(zero, (in + three)))
6476 // We structure the AST to have two roughly balanced, independent branches:
6477 // - Multiplication: in_scaled = one_sixth * in.
6478 // - Addition and clamping: in_reluish = min(six, max(zero, (in + three))).
6479 // Then the remaining multiplication at the root of the tree.
6480 const float32x4_t in = vld1q_f32(input_data + i);
6481 const float32x4_t in_scaled = vmulq_f32(in, one_sixth);
6482 const float32x4_t in_reluish =
6483 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in, three)));
6484 const float32x4_t product = vmulq_f32(in_scaled, in_reluish);
6485 vst1q_f32(output_data + i, product);
6486 }
6487 #endif
6488 for (; i < size; i++) {
6489 const float in = input_data[i];
6490 output_data[i] =
6491 in * std::min(6.0f, std::max(0.0f, in + 3.0f)) * (1.0f / 6.0f);
6492 }
6493 }
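// The scalar tail above computes the standard hard-swish,
// out = x * min(6, max(0, x + 3)) / 6. For example, x = 1.0f yields
// 1.0f * 4.0f / 6.0f = 0.6667f, and any x <= -3.0f yields 0.0f.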
6494
6495 #ifdef USE_NEON
6496 inline void SaturateAndStore(int16x8_t src, std::uint8_t* dst) {
6497 // Narrow values down to 8 bit unsigned, saturating.
6498 uint8x8_t res8 = vqmovun_s16(src);
6499 // Store results to destination.
6500 vst1_u8(dst, res8);
6501 }
6502
6503 inline void SaturateAndStore(int16x8_t src, std::int8_t* dst) {
6504   // Narrow values down to 8 bit signed, saturating.
6505 int8x8_t res8 = vqmovn_s16(src);
6506 // Store results to destination.
6507 vst1_s8(dst, res8);
6508 }
6509 #endif
6510
6511 template <typename T>
6512 inline void HardSwish(const HardSwishParams& params,
6513 const RuntimeShape& input_shape, const T* input_data,
6514 const RuntimeShape& output_shape, T* output_data) {
6515 ruy::profiler::ScopeLabel label("HardSwish/Quantized");
6516
6517 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6518
6519 int i = 0;
6520 // This code heavily uses NEON saturating left shifts (vqshl*) with shift
6521 // amounts that can be zero, in which case we rely on the correct behavior
6522 // of a left shift by zero returning just its first operand unmodified.
6523 // Unfortunately, the Intel arm_neon_sse.h implementation of vqshl* is
6524 // buggy in the case of zero shift amounts, see b/137199585. That is why
6525 // this NEON code path is restricted to true ARM NEON, excluding
6526 // arm_neon_sse.h. Anyway, the arm_neon_sse.h implementation of saturating
6527 // left shifts is slow scalar code, so there may not be much benefit in
6528 // running that over just plain reference code.
6529 //
6530 // TODO(b/137199585): revisit when this is fixed.
6531 #ifdef __ARM_NEON
6532 const int16x8_t positive_reluish_multiplier_exponent_minus_one =
6533 vdupq_n_s16(std::max(0, params.reluish_multiplier_exponent - 1));
6534 const int16x8_t positive_reluish_multiplier_exponent_last_bit =
6535 vdupq_n_s16(params.reluish_multiplier_exponent > 0 ? 1 : 0);
6536 const int16x8_t negative_reluish_multiplier_exponent =
6537 vdupq_n_s16(std::min(0, params.reluish_multiplier_exponent));
6538 const int16x8_t constant_32767 = vdupq_n_s16(32767);
6539 const int16x8_t output_multiplier_exponent =
6540 vdupq_n_s16(params.output_multiplier_exponent);
6541 const int16x8_t output_zero_point = vdupq_n_s16(params.output_zero_point);
6542 // 4x unrolled version of the below NEON loop. Read that first.
6543 for (; i <= flat_size - 32; i += 32) {
6544 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6545 const int16x8x2_t input_value_0_1 =
6546 Load16AndSubtractZeroPoint(input_data + i, params.input_zero_point);
6547 const int16x8x2_t input_value_2_3 = Load16AndSubtractZeroPoint(
6548 input_data + i + 16, params.input_zero_point);
6549 const int16x8_t input_value_on_hires_input_scale_0 =
6550 vshlq_n_s16(input_value_0_1.val[0], 7);
6551 const int16x8_t input_value_on_hires_input_scale_1 =
6552 vshlq_n_s16(input_value_0_1.val[1], 7);
6553 const int16x8_t input_value_on_hires_input_scale_2 =
6554 vshlq_n_s16(input_value_2_3.val[0], 7);
6555 const int16x8_t input_value_on_hires_input_scale_3 =
6556 vshlq_n_s16(input_value_2_3.val[1], 7);
6557 const int16x8_t input_value_on_preshift_output_scale_0 =
6558 vqrdmulhq_n_s16(input_value_on_hires_input_scale_0,
6559 params.output_multiplier_fixedpoint_int16);
6560 const int16x8_t input_value_on_preshift_output_scale_1 =
6561 vqrdmulhq_n_s16(input_value_on_hires_input_scale_1,
6562 params.output_multiplier_fixedpoint_int16);
6563 const int16x8_t input_value_on_preshift_output_scale_2 =
6564 vqrdmulhq_n_s16(input_value_on_hires_input_scale_2,
6565 params.output_multiplier_fixedpoint_int16);
6566 const int16x8_t input_value_on_preshift_output_scale_3 =
6567 vqrdmulhq_n_s16(input_value_on_hires_input_scale_3,
6568 params.output_multiplier_fixedpoint_int16);
6569 int16x8_t reluish_value_0 = input_value_on_hires_input_scale_0;
6570 int16x8_t reluish_value_1 = input_value_on_hires_input_scale_1;
6571 int16x8_t reluish_value_2 = input_value_on_hires_input_scale_2;
6572 int16x8_t reluish_value_3 = input_value_on_hires_input_scale_3;
6573 reluish_value_0 = vqshlq_s16(
6574 reluish_value_0, positive_reluish_multiplier_exponent_minus_one);
6575 reluish_value_1 = vqshlq_s16(
6576 reluish_value_1, positive_reluish_multiplier_exponent_minus_one);
6577 reluish_value_2 = vqshlq_s16(
6578 reluish_value_2, positive_reluish_multiplier_exponent_minus_one);
6579 reluish_value_3 = vqshlq_s16(
6580 reluish_value_3, positive_reluish_multiplier_exponent_minus_one);
6581 reluish_value_0 = vqrdmulhq_n_s16(
6582 reluish_value_0, params.reluish_multiplier_fixedpoint_int16);
6583 reluish_value_1 = vqrdmulhq_n_s16(
6584 reluish_value_1, params.reluish_multiplier_fixedpoint_int16);
6585 reluish_value_2 = vqrdmulhq_n_s16(
6586 reluish_value_2, params.reluish_multiplier_fixedpoint_int16);
6587 reluish_value_3 = vqrdmulhq_n_s16(
6588 reluish_value_3, params.reluish_multiplier_fixedpoint_int16);
6589 reluish_value_0 = vqshlq_s16(reluish_value_0,
6590 positive_reluish_multiplier_exponent_last_bit);
6591 reluish_value_1 = vqshlq_s16(reluish_value_1,
6592 positive_reluish_multiplier_exponent_last_bit);
6593 reluish_value_2 = vqshlq_s16(reluish_value_2,
6594 positive_reluish_multiplier_exponent_last_bit);
6595 reluish_value_3 = vqshlq_s16(reluish_value_3,
6596 positive_reluish_multiplier_exponent_last_bit);
6597 reluish_value_0 =
6598 vrshlq_s16(reluish_value_0, negative_reluish_multiplier_exponent);
6599 reluish_value_1 =
6600 vrshlq_s16(reluish_value_1, negative_reluish_multiplier_exponent);
6601 reluish_value_2 =
6602 vrshlq_s16(reluish_value_2, negative_reluish_multiplier_exponent);
6603 reluish_value_3 =
6604 vrshlq_s16(reluish_value_3, negative_reluish_multiplier_exponent);
6605 reluish_value_0 = vrhaddq_s16(reluish_value_0, constant_32767);
6606 reluish_value_1 = vrhaddq_s16(reluish_value_1, constant_32767);
6607 reluish_value_2 = vrhaddq_s16(reluish_value_2, constant_32767);
6608 reluish_value_3 = vrhaddq_s16(reluish_value_3, constant_32767);
6609 const int16x8_t preshift_output_value_0 =
6610 vqdmulhq_s16(reluish_value_0, input_value_on_preshift_output_scale_0);
6611 const int16x8_t preshift_output_value_1 =
6612 vqdmulhq_s16(reluish_value_1, input_value_on_preshift_output_scale_1);
6613 const int16x8_t preshift_output_value_2 =
6614 vqdmulhq_s16(reluish_value_2, input_value_on_preshift_output_scale_2);
6615 const int16x8_t preshift_output_value_3 =
6616 vqdmulhq_s16(reluish_value_3, input_value_on_preshift_output_scale_3);
6617 int16x8_t output_value_0 =
6618 vrshlq_s16(preshift_output_value_0, output_multiplier_exponent);
6619 int16x8_t output_value_1 =
6620 vrshlq_s16(preshift_output_value_1, output_multiplier_exponent);
6621 int16x8_t output_value_2 =
6622 vrshlq_s16(preshift_output_value_2, output_multiplier_exponent);
6623 int16x8_t output_value_3 =
6624 vrshlq_s16(preshift_output_value_3, output_multiplier_exponent);
6625 output_value_0 = vaddq_s16(output_value_0, output_zero_point);
6626 output_value_1 = vaddq_s16(output_value_1, output_zero_point);
6627 output_value_2 = vaddq_s16(output_value_2, output_zero_point);
6628 output_value_3 = vaddq_s16(output_value_3, output_zero_point);
6629 SaturateAndStore(output_value_0, output_data + i);
6630 SaturateAndStore(output_value_1, output_data + i + 8);
6631 SaturateAndStore(output_value_2, output_data + i + 16);
6632 SaturateAndStore(output_value_3, output_data + i + 24);
6633 }
6634 // NEON version of reference_ops::HardSwish. Read that first.
6635 for (; i <= flat_size - 8; i += 8) {
6636 using cpu_backend_gemm::detail::Load8AndSubtractZeroPoint;
6637 const int16x8_t input_value =
6638 Load8AndSubtractZeroPoint(input_data + i, params.input_zero_point);
6639 const int16x8_t input_value_on_hires_input_scale =
6640 vshlq_n_s16(input_value, 7);
6641 const int16x8_t input_value_on_preshift_output_scale =
6642 vqrdmulhq_n_s16(input_value_on_hires_input_scale,
6643 params.output_multiplier_fixedpoint_int16);
6644 int16x8_t reluish_value = input_value_on_hires_input_scale;
6645 reluish_value = vqshlq_s16(reluish_value,
6646 positive_reluish_multiplier_exponent_minus_one);
6647 reluish_value = vqrdmulhq_n_s16(reluish_value,
6648 params.reluish_multiplier_fixedpoint_int16);
6649 reluish_value = vqshlq_s16(reluish_value,
6650 positive_reluish_multiplier_exponent_last_bit);
6651 reluish_value =
6652 vrshlq_s16(reluish_value, negative_reluish_multiplier_exponent);
6653 reluish_value = vrhaddq_s16(reluish_value, constant_32767);
6654 const int16x8_t preshift_output_value =
6655 vqdmulhq_s16(reluish_value, input_value_on_preshift_output_scale);
6656 int16x8_t output_value =
6657 vrshlq_s16(preshift_output_value, output_multiplier_exponent);
6658 output_value = vaddq_s16(output_value, output_zero_point);
6659 SaturateAndStore(output_value, output_data + i);
6660 }
6661 #endif
6662 // TODO(b/137208495): revisit when unit tests cover reference code.
6663 // Fall back to reference_ops::HardSwish. In general we have preferred
6664 // to duplicate such scalar code rather than call reference code to handle
6665 // leftovers, thinking that code duplication was not a big concern.
6666 // However, most of our unit tests happen to test only optimized code,
6667 // and the quantized HardSwish implementation is nontrivial enough that
6668 // I really want test coverage for the reference code.
6669 if (i < flat_size) {
6670 const RuntimeShape leftover_shape{flat_size - i};
6671 reference_ops::HardSwish(params, leftover_shape, input_data + i,
6672 leftover_shape, output_data + i);
6673 }
6674 }
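// Sketch of the quantized pipeline above: the zero-point-centered input is
// promoted to a finer 16-bit scale (<< 7); one copy is rescaled toward the
// output scale, while the other goes through the saturating "reluish"
// rescale, whose saturation implements the clamp at x = +/-3 and whose
// rounding-halving add with 32767 maps [-1, 1] to [0, 1], approximating
// min(6, max(0, x + 3)) / 6. The two are then multiplied, shifted to the
// output scale, and offset by the output zero point.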
6675
6676 template <typename T>
6677 inline void IntegerExponentPow(const ArithmeticParams& params,
6678 const RuntimeShape& unextended_base_shape,
6679 const T* base_data, const int exponent,
6680 const RuntimeShape& unextended_output_shape,
6681 T* output_data) {
6682 TFLITE_DCHECK_GE(exponent, 1);
6683 if (exponent == 1) {
6684 // copy data over.
6685 std::memcpy(output_data, base_data,
6686 unextended_base_shape.FlatSize() * sizeof(T));
6687 } else {
6688 IntegerExponentPow(params, unextended_base_shape, base_data, exponent / 2,
6689 unextended_output_shape, output_data);
6690 Mul(params, unextended_base_shape, output_data, unextended_base_shape,
6691 output_data, unextended_output_shape, output_data);
6692 if (exponent % 2 == 1) {
6693 Mul(params, unextended_base_shape, base_data, unextended_base_shape,
6694 output_data, unextended_output_shape, output_data);
6695 }
6696 }
6697 }
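// Illustrative trace of the recursion above for exponent = 5: the call chain
// is pow(5) -> pow(2) -> pow(1). pow(1) copies the base, pow(2) squares it
// with one Mul, and pow(5) squares that result into base^4 and, because 5 is
// odd, multiplies once more by the base, producing base^5 with three Mul
// calls instead of four.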
6698
6699 template <typename T>
6700 inline void BroadcastPow4D(const RuntimeShape& unextended_input1_shape,
6701 const T* input1_data,
6702 const RuntimeShape& unextended_input2_shape,
6703 const T* input2_data,
6704 const RuntimeShape& unextended_output_shape,
6705 T* output_data) {
6706 ruy::profiler::ScopeLabel label("PowBroadcast");
6707
6708 if (unextended_input2_shape.FlatSize() == 1) {
6709 static const float epsilon = 1e-5;
6710 const T exponent = input2_data[0];
6711 const int int_exponent = static_cast<int>(std::round(exponent));
6712 if ((std::abs(input2_data[0] - int_exponent) < epsilon) &&
6713 (int_exponent >= 1)) {
6714 ArithmeticParams params;
6715 if (std::is_same<T, float>::value) {
6716 params.float_activation_max = std::numeric_limits<float>::max();
6717 params.float_activation_min = std::numeric_limits<float>::lowest();
6718 } else if (std::is_same<T, int>::value) {
6719 params.quantized_activation_max = std::numeric_limits<int>::max();
6720 params.quantized_activation_min = std::numeric_limits<int>::lowest();
6721 }
6722 IntegerExponentPow(params, unextended_input1_shape, input1_data,
6723 int_exponent, unextended_output_shape, output_data);
6724 return;
6725 }
6726 }
6727 reference_ops::BroadcastPow4DSlow(unextended_input1_shape, input1_data,
6728 unextended_input2_shape, input2_data,
6729 unextended_output_shape, output_data);
6730 }
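// Note that the fast path above only triggers when the exponent tensor is a
// single scalar within 1e-5 of an integer that is at least 1; fractional,
// zero, or negative exponents all fall through to the element-wise reference
// implementation.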
6731
6732 #ifdef USE_NEON
6733
6734 inline void ScaleWithNewZeroPoint(const int32x4_t input,
6735 const float32x4_t scale_dup,
6736 const float32x4_t zero_times_scale_dup,
6737 float32x4_t* output) {
6738 #ifdef __ARM_FEATURE_FMA
6739 *output = vfmaq_f32(zero_times_scale_dup, vcvtq_f32_s32(input), scale_dup);
6740 #else
6741 *output = vaddq_f32(vmulq_f32(vcvtq_f32_s32(input), scale_dup),
6742 zero_times_scale_dup);
6743 #endif
6744 }
6745
6746 #endif // USE_NEON
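// ScaleWithNewZeroPoint evaluates the dequantization formula
//   output = (input - zero_point) * scale
// in the rearranged form input * scale + (-zero_point * scale), so the
// callers below can hoist the constant term into zero_times_scale_dup and,
// where available, fold the remaining work into one fused multiply-add.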
6747
6748 inline void Dequantize(const tflite::DequantizationParams& op_params,
6749 const RuntimeShape& input_shape,
6750 const uint8_t* input_data,
6751 const RuntimeShape& output_shape, float* output_data) {
6752 ruy::profiler::ScopeLabel label("Dequantize/Uint8");
6753 const int32 zero_point = op_params.zero_point;
6754 const double scale = op_params.scale;
6755 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6756
6757 int i = 0;
6758 #ifdef USE_NEON
6759 const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6760 const float32x4_t zero_times_scale_dup =
6761 vdupq_n_f32(static_cast<float>(-zero_point * scale));
6762 for (; i <= flat_size - 8; i += 8) {
6763 const uint8x8_t input_u8 = vld1_u8(input_data + i);
6764 const uint16x8_t input_u16 = vmovl_u8(input_u8);
6765 const int16x8_t input_s16 = vreinterpretq_s16_u16(input_u16);
6766 const int16x4_t input_s16_low = vget_low_s16(input_s16);
6767 const int16x4_t input_s16_high = vget_high_s16(input_s16);
6768 const int32x4_t val_low = vmovl_s16(input_s16_low);
6769 const int32x4_t val_high = vmovl_s16(input_s16_high);
6770
6771 float32x4_t result_low, result_high;
6772 ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6773 &result_low);
6774 ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6775 &result_high);
6776
6777 vst1q_f32(output_data + i, result_low);
6778 vst1q_f32(output_data + i + 4, result_high);
6779 }
6780 #endif // NEON
6781 for (; i < flat_size; ++i) {
6782 const int32 val = input_data[i];
6783 const float result = static_cast<float>(scale * (val - zero_point));
6784 output_data[i] = result;
6785 }
6786 }
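// Worked example for the scalar tail above, with arbitrary parameters: for
// zero_point = 128 and scale = 0.5, an input byte of 200 dequantizes to
// (200 - 128) * 0.5 = 36.0f, and an input of 128 maps exactly to 0.0f.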
6787
6788 inline void Dequantize(const tflite::DequantizationParams& op_params,
6789 const RuntimeShape& input_shape,
6790 const int8_t* input_data,
6791 const RuntimeShape& output_shape, float* output_data) {
6792 ruy::profiler::ScopeLabel label("Dequantize/Int8");
6793 const int32 zero_point = op_params.zero_point;
6794 const double scale = op_params.scale;
6795 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6796
6797 int i = 0;
6798 #ifdef USE_NEON
6799 const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6800 const float32x4_t zero_times_scale_dup =
6801 vdupq_n_f32(static_cast<float>(-zero_point * scale));
6802 for (; i <= flat_size - 8; i += 8) {
6803 const int8x8_t input_s8 = vld1_s8(input_data + i);
6804 const int16x8_t input_s16 = vmovl_s8(input_s8);
6805 const int16x4_t input_s16_low = vget_low_s16(input_s16);
6806 const int16x4_t input_s16_high = vget_high_s16(input_s16);
6807 const int32x4_t val_low = vmovl_s16(input_s16_low);
6808 const int32x4_t val_high = vmovl_s16(input_s16_high);
6809
6810 float32x4_t result_low, result_high;
6811 ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6812 &result_low);
6813 ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6814 &result_high);
6815
6816 vst1q_f32(output_data + i, result_low);
6817 vst1q_f32(output_data + i + 4, result_high);
6818 }
6819 #endif // NEON
6820 for (; i < flat_size; ++i) {
6821 const int32 val = input_data[i];
6822 const float result = static_cast<float>(scale * (val - zero_point));
6823 output_data[i] = result;
6824 }
6825 }
6826
6827 inline void Dequantize(const tflite::DequantizationParams& op_params,
6828 const RuntimeShape& input_shape,
6829 const int16_t* input_data,
6830 const RuntimeShape& output_shape, float* output_data) {
6831 ruy::profiler::ScopeLabel label("Dequantize/Int16");
6832 const int32 zero_point = op_params.zero_point;
6833 const double scale = op_params.scale;
6834 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6835
6836 int i = 0;
6837 #ifdef USE_NEON
6838 const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6839 const float32x4_t zero_times_scale_dup =
6840 vdupq_n_f32(static_cast<float>(-zero_point * scale));
6841 for (; i <= flat_size - 8; i += 8) {
6842 const int16x4_t input_s16_low = vld1_s16(input_data + i);
6843 const int16x4_t input_s16_high = vld1_s16(input_data + i + 4);
6844 const int32x4_t val_low = vmovl_s16(input_s16_low);
6845 const int32x4_t val_high = vmovl_s16(input_s16_high);
6846
6847 float32x4_t result_low, result_high;
6848 ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6849 &result_low);
6850 ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6851 &result_high);
6852
6853 vst1q_f32(output_data + i, result_low);
6854 vst1q_f32(output_data + i + 4, result_high);
6855 }
6856 #endif // NEON
6857 for (; i < flat_size; ++i) {
6858 const int32 val = input_data[i];
6859 const float result = static_cast<float>(scale * (val - zero_point));
6860 output_data[i] = result;
6861 }
6862 }
6863
6864 inline void Dequantize(const RuntimeShape& input_shape,
6865 const Eigen::half* input_data,
6866 const RuntimeShape& output_shape, float* output_data) {
6867 reference_ops::Dequantize(input_shape, input_data, output_shape, output_data);
6868 }
6869
6870 template <typename T>
6871 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6872 const RuntimeShape& input_shape,
6873 const float* input_data,
6874 const RuntimeShape& output_shape, T* output_data) {
6875 reference_ops::AffineQuantize(op_params, input_shape, input_data,
6876 output_shape, output_data);
6877 }
6878
6879 template <>
6880 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6881 const RuntimeShape& input_shape,
6882 const float* input_data,
6883 const RuntimeShape& output_shape,
6884 int8_t* output_data) {
6885 ruy::profiler::ScopeLabel label("Quantize/Int8");
6886 const int32 zero_point = op_params.zero_point;
6887 const double scale = static_cast<double>(op_params.scale);
6888 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6889 static constexpr int32 min_val = std::numeric_limits<int8_t>::min();
6890 static constexpr int32 max_val = std::numeric_limits<int8_t>::max();
6891
6892 int i = 0;
6893 #ifdef USE_NEON
6894 const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6895 const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6896 const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6897 const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6898
6899 for (; i <= flat_size - 8; i += 8) {
6900 const float* src_data_ptr = input_data + i;
6901 float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6902 float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6903
6904 input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6905 input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6906
6907 int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6908 int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6909
6910 casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6911 casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6912
6913 // Clamp the values to fit the target type's range.
6914 casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6915 casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6916 casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6917 casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6918
6919 const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
6920 const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
6921 const int16x8_t combined_val = vcombine_s16(narrowed_val_0, narrowed_val_1);
6922 const int8x8_t combined_val_narrowed = vmovn_s16(combined_val);
6923 vst1_s8(output_data + i, combined_val_narrowed);
6924 }
6925 #endif // NEON
6926
6927 for (; i < flat_size; ++i) {
6928 const float val = input_data[i];
6929 const int32 unclamped =
6930 static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6931 const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6932 output_data[i] = clamped;
6933 }
6934 }
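// Worked example for the scalar tail above, with arbitrary parameters: for
// scale = 0.1 and zero_point = -5, the value 0.43f quantizes to
// round(0.43 / 0.1) + (-5) = 4 - 5 = -1, which already lies within
// [-128, 127], so the clamp leaves it unchanged.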
6935
6936 template <>
6937 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6938 const RuntimeShape& input_shape,
6939 const float* input_data,
6940 const RuntimeShape& output_shape,
6941 uint8_t* output_data) {
6942 ruy::profiler::ScopeLabel label("Quantize/Uint8");
6943 const int32 zero_point = op_params.zero_point;
6944 const double scale = static_cast<double>(op_params.scale);
6945 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6946 static constexpr int32 min_val = std::numeric_limits<uint8_t>::min();
6947 static constexpr int32 max_val = std::numeric_limits<uint8_t>::max();
6948
6949 int i = 0;
6950 #ifdef USE_NEON
6951 const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6952 const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6953 const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6954 const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6955
6956 for (; i <= flat_size - 8; i += 8) {
6957 const float* src_data_ptr = input_data + i;
6958 float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6959 float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6960
6961 input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6962 input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6963
6964 int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6965 int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6966
6967 casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6968 casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6969
6970 // Clamp the values to fit the target type's range.
6971 casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6972 casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6973 casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6974 casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6975
6976 const uint16x4_t narrowed_val_0 = vqmovun_s32(casted_val_0);
6977 const uint16x4_t narrowed_val_1 = vqmovun_s32(casted_val_1);
6978 const uint16x8_t combined_val =
6979 vcombine_u16(narrowed_val_0, narrowed_val_1);
6980 const uint8x8_t combined_val_narrowed = vmovn_u16(combined_val);
6981 vst1_u8(output_data + i, combined_val_narrowed);
6982 }
6983 #endif // NEON
6984
6985 for (; i < flat_size; ++i) {
6986 const float val = input_data[i];
6987 const int32 unclamped =
6988 static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6989 const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6990 output_data[i] = clamped;
6991 }
6992 }
6993
6994 template <>
6995 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6996 const RuntimeShape& input_shape,
6997 const float* input_data,
6998 const RuntimeShape& output_shape,
6999 int16_t* output_data) {
7000 ruy::profiler::ScopeLabel label("Quantize/Int16");
7001 const int32 zero_point = op_params.zero_point;
7002 const double scale = static_cast<double>(op_params.scale);
7003 const int flat_size = MatchingFlatSize(input_shape, output_shape);
7004 static constexpr int32 min_val = std::numeric_limits<int16_t>::min();
7005 static constexpr int32 max_val = std::numeric_limits<int16_t>::max();
7006
7007 int i = 0;
7008 #ifdef USE_NEON
7009 const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
7010 const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
7011 const int32x4_t min_val_dup = vdupq_n_s32(min_val);
7012 const int32x4_t max_val_dup = vdupq_n_s32(max_val);
7013
7014 for (; i <= flat_size - 8; i += 8) {
7015 const float* src_data_ptr = input_data + i;
7016 float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
7017 float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
7018
7019 input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
7020 input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
7021
7022 int32x4_t casted_val_0 = RoundToNearest(input_val_0);
7023 int32x4_t casted_val_1 = RoundToNearest(input_val_1);
7024
7025 casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
7026 casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
7027
7028 // Clamp the values to fit the target type's range.
7029 casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
7030 casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
7031 casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
7032 casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
7033
7034 const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
7035 const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
7036 vst1_s16(output_data + i, narrowed_val_0);
7037 vst1_s16(output_data + i + 4, narrowed_val_1);
7038 }
7039 #endif // NEON
7040
7041 for (; i < flat_size; ++i) {
7042 const float val = input_data[i];
7043 const int32 unclamped =
7044 static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
7045 const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
7046 output_data[i] = clamped;
7047 }
7048 }
7049
7050 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7051 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7052 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7053 #ifdef GEMMLOWP_NEON
7054
7055 inline int16x8x4_t SaturatingRounding(
7056 int16x8_t input_val_0, int16x8_t input_val_1, int16x8_t input_val_2,
7057 int16x8_t input_val_3, int input_left_shift, int input_multiplier) {
7058 // This performs what is expressed in the scalar code as
7059 // const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7060 // static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7061 // static_cast<int16>(input_multiplier));
7062 const int16x8_t left_shift_dup = vdupq_n_s16(input_left_shift);
7063 const int16x8_t input_val_shifted_0 = vshlq_s16(input_val_0, left_shift_dup);
7064 const int16x8_t input_val_shifted_1 = vshlq_s16(input_val_1, left_shift_dup);
7065 const int16x8_t input_val_shifted_2 = vshlq_s16(input_val_2, left_shift_dup);
7066 const int16x8_t input_val_shifted_3 = vshlq_s16(input_val_3, left_shift_dup);
7067 int16x8x4_t result;
7068 result.val[0] = vqrdmulhq_n_s16(input_val_shifted_0, input_multiplier);
7069 result.val[1] = vqrdmulhq_n_s16(input_val_shifted_1, input_multiplier);
7070 result.val[2] = vqrdmulhq_n_s16(input_val_shifted_2, input_multiplier);
7071 result.val[3] = vqrdmulhq_n_s16(input_val_shifted_3, input_multiplier);
7072 return result;
7073 }
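// In the helper above, vqrdmulhq_n_s16(a, b) computes the saturating,
// rounding, doubling high half of the product, roughly
// (2 * a * b + (1 << 15)) >> 16 per lane, which is the vector counterpart of
// SaturatingRoundingDoublingHighMul in the scalar code quoted in the comment.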
7074
7075 // 4-bit fixed point is enough for logistic since logistic(16) is essentially
7076 // equal to one, to about 7 decimal places.
7077 inline int16x8x4_t FixedPoint4Logistic(int16x8x4_t input_val) {
7078 // Invoke gemmlowp::logistic on FixedPoint wrapping int16x8_t
7079 using FixedPoint4 = gemmlowp::FixedPoint<int16x8_t, 4>;
7080 using FixedPoint0 = gemmlowp::FixedPoint<int16x8_t, 0>;
7081 const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val.val[0]);
7082 const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val.val[1]);
7083 const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val.val[2]);
7084 const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val.val[3]);
7085
7086   // TODO(b/134622898) Implement a low accuracy version of logistic. In this
7087   // method, gemmlowp::logistic accounts for about 80% of the execution time.
7088   // The current implementation is roughly 12-bit accurate in the 16-bit
7089   // fixed-point case, so there is still room for improvement before the
7090   // error bounds are reached.
7091 const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
7092 const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
7093 const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
7094 const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
7095
7096 // Divide by 2^7 as in the scalar code
7097 int16x8x4_t result;
7098 result.val[0] = vrshrq_n_s16(output_val_f0_0.raw(), 7);
7099 result.val[1] = vrshrq_n_s16(output_val_f0_1.raw(), 7);
7100 result.val[2] = vrshrq_n_s16(output_val_f0_2.raw(), 7);
7101 result.val[3] = vrshrq_n_s16(output_val_f0_3.raw(), 7);
7102 return result;
7103 }
7104
7105 // 4-bit fixed point is enough for tanh since tanh(16) is essentially equal to
7106 // one, to at least 11 decimal places.
7107 inline int16x8x4_t FixedPoint4Tanh(int16x8x4_t input_val) {
7108 // Invoke gemmlowp::logistic on FixedPoint wrapping int16x8_t
7109 using FixedPoint4 = gemmlowp::FixedPoint<int16x8_t, 4>;
7110 using FixedPoint0 = gemmlowp::FixedPoint<int16x8_t, 0>;
7111 const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val.val[0]);
7112 const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val.val[1]);
7113 const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val.val[2]);
7114 const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val.val[3]);
7115
7116   // TODO(b/134622898) Implement a low accuracy version of tanh. In this
7117   // method, gemmlowp::tanh accounts for about 80% of the execution time.
7118   // The current implementation is roughly 12-bit accurate in the 16-bit
7119   // fixed-point case, so there is still room for improvement before the
7120   // error bounds are reached.
7121 const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
7122 const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
7123 const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
7124 const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
7125
7126   // Divide by 2^8 as in the scalar code
7127 int16x8x4_t result;
7128 result.val[0] = vrshrq_n_s16(output_val_f0_0.raw(), 8);
7129 result.val[1] = vrshrq_n_s16(output_val_f0_1.raw(), 8);
7130 result.val[2] = vrshrq_n_s16(output_val_f0_2.raw(), 8);
7131 result.val[3] = vrshrq_n_s16(output_val_f0_3.raw(), 8);
7132 return result;
7133 }
7134
7135 inline uint8x16x2_t CalculateUnsignedClampingWithRangeBitMasks(
7136 int16x8x2_t input_val, int16x8_t range_radius_dup,
7137 int16x8_t neg_range_radius_dup) {
7138 const uint16x8_t mask_rightclamp_0 =
7139 vcgtq_s16(input_val.val[0], range_radius_dup);
7140 const uint16x8_t mask_rightclamp_1 =
7141 vcgtq_s16(input_val.val[1], range_radius_dup);
7142
7143 const uint16x8_t mask_leftclamp_0 =
7144 vcgeq_s16(input_val.val[0], neg_range_radius_dup);
7145 const uint16x8_t mask_leftclamp_1 =
7146 vcgeq_s16(input_val.val[1], neg_range_radius_dup);
7147
7148 uint8x16x2_t result;
7149 result.val[0] = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
7150 vshrn_n_u16(mask_leftclamp_1, 8));
7151 result.val[1] = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
7152 vshrn_n_u16(mask_rightclamp_1, 8));
7153 return result;
7154 }
7155
7156 inline uint8x16x2_t CalculateSignedClampingWithRangeBitMasks(
7157 int16x8x2_t input_val, int16x8_t range_radius_dup,
7158 int16x8_t neg_range_radius_dup) {
7159 const uint16x8_t mask_rightclamp_0 =
7160 vcgtq_s16(input_val.val[0], range_radius_dup);
7161 const uint16x8_t mask_rightclamp_1 =
7162 vcgtq_s16(input_val.val[1], range_radius_dup);
7163
7164 const uint16x8_t mask_leftclamp_0 =
7165 vcltq_s16(input_val.val[0], neg_range_radius_dup);
7166 const uint16x8_t mask_leftclamp_1 =
7167 vcltq_s16(input_val.val[1], neg_range_radius_dup);
7168
7169 uint8x16x2_t result;
7170 result.val[0] = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
7171 vshrn_n_u16(mask_leftclamp_1, 8));
7172 result.val[1] = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
7173 vshrn_n_u16(mask_rightclamp_1, 8));
7174 return result;
7175 }
7176
7177 inline void ClampWithRangeAndStore(uint8_t* output_dst, uint8x16_t input_val,
7178 uint8x16x2_t masks_clamp) {
7179 // Store back to memory
7180 vst1q_u8(output_dst, vandq_u8(vorrq_u8(input_val, masks_clamp.val[1]),
7181 masks_clamp.val[0]));
7182 }
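// In the unsigned store above, masks_clamp.val[1] is all-ones for lanes that
// must clamp to 255 and masks_clamp.val[0] is all-ones for lanes that must
// NOT clamp to 0, so (input | right_mask) & left_mask forces both saturated
// cases while leaving in-range lanes untouched.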
7183
7184 inline void ClampWithRangeAndStore(int8_t* output_dst, int8x16_t input_val,
7185 uint8x16x2_t masks_clamp) {
7186 static const int8x16_t max_dup = vdupq_n_s8(127);
7187 static const int8x16_t min_dup = vdupq_n_s8(-128);
7188 // Store back to memory
7189 vst1q_s8(output_dst,
7190 vbslq_s8(masks_clamp.val[1], max_dup,
7191 vbslq_s8(masks_clamp.val[0], min_dup, input_val)));
7192 }
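// In the signed store above, the masks come from
// CalculateSignedClampingWithRangeBitMasks: val[0] is 0xFF where the input was
// below -input_range_radius (select -128) and val[1] is 0xFF where it exceeded
// +input_range_radius (select 127). The two masks never overlap, so the order
// of the nested vbslq_s8 selects does not matter.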
7193
7194 #endif // GEMMLOWP_NEON
7195
7196 inline void Tanh16bitPrecision(const TanhParams& params,
7197 const RuntimeShape& input_shape,
7198 const uint8* input_data,
7199 const RuntimeShape& output_shape,
7200 uint8* output_data) {
7201 // Note that this is almost the exact same code as in Logistic().
7202 ruy::profiler::ScopeLabel label("Tanh/Uint8");
7203 const int32 input_zero_point = params.input_zero_point;
7204 const int32 input_range_radius = params.input_range_radius;
7205 const int16 input_multiplier = static_cast<int16>(params.input_multiplier);
7206 const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
7207 const int size = MatchingFlatSize(input_shape, output_shape);
7208
7209 int c = 0;
7210 int16_t output_zero_point = 128;
7211
7212 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7213 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7214 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7215 #ifdef GEMMLOWP_NEON
7216 const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
7217 const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
7218 const int16x8_t output_zero_point_s16 = vdupq_n_s16(output_zero_point);
7219
7220 // Handle 32 values at a time
7221 for (; c <= size - 32; c += 32) {
7222 // Read input uint8 values, cast to int16 and subtract input_zero_point
7223 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
7224 const int16x8x2_t input_val_centered_0_1 =
7225 Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
7226 const int16x8x2_t input_val_centered_2_3 =
7227 Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
7228
7229 // Prepare the bit masks that we will use at the end to implement the logic
7230 // that was expressed in the scalar code with branching:
7231 // if (input_val_centered < -input_range_radius) {
7232 // output_val = 0;
7233 // } else if (input_val_centered > input_range_radius) {
7234 // output_val = 255;
7235 // } else {
7236 // ...
7237 uint8x16x2_t masks_clamp_0_1 = CalculateUnsignedClampingWithRangeBitMasks(
7238 input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
7239 uint8x16x2_t masks_clamp_2_3 = CalculateUnsignedClampingWithRangeBitMasks(
7240 input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
7241
7242 int16x8x4_t input_val_rescaled = SaturatingRounding(
7243 input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7244 input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7245 input_left_shift, input_multiplier);
7246
7247 int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled);
7248
7249 // Add the output zero point
7250 output_val_s16.val[0] =
7251 vaddq_s16(output_val_s16.val[0], output_zero_point_s16);
7252 output_val_s16.val[1] =
7253 vaddq_s16(output_val_s16.val[1], output_zero_point_s16);
7254 output_val_s16.val[2] =
7255 vaddq_s16(output_val_s16.val[2], output_zero_point_s16);
7256 output_val_s16.val[3] =
7257 vaddq_s16(output_val_s16.val[3], output_zero_point_s16);
7258
7259 // Cast output values to uint8, saturating
7260 uint8x16_t output_val_u8_0_1 = vcombine_u8(
7261 vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1]));
7262 uint8x16_t output_val_u8_2_3 = vcombine_u8(
7263 vqmovun_s16(output_val_s16.val[2]), vqmovun_s16(output_val_s16.val[3]));
7264
7265 ClampWithRangeAndStore(output_data + c, output_val_u8_0_1, masks_clamp_0_1);
7266 ClampWithRangeAndStore(output_data + c + 16, output_val_u8_2_3,
7267 masks_clamp_2_3);
7268 }
7269 #endif // GEMMLOWP_NEON
7270 // Leftover loop: handle one value at a time with scalar code.
7271 for (; c < size; ++c) {
7272 const uint8 input_val_u8 = input_data[c];
7273 const int16 input_val_centered =
7274 static_cast<int16>(input_val_u8) - input_zero_point;
7275 uint8 output_val;
7276 if (input_val_centered < -input_range_radius) {
7277 output_val = 0;
7278 } else if (input_val_centered > input_range_radius) {
7279 output_val = 255;
7280 } else {
7281 using gemmlowp::SaturatingRoundingDoublingHighMul;
7282 const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7283 static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7284 static_cast<int16>(input_multiplier));
7285 using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7286 using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7287 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7288 const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
7289 using gemmlowp::RoundingDivideByPOT;
7290 int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8);
7291 output_val_s16 += output_zero_point;
7292 if (output_val_s16 == 256) {
7293 output_val_s16 = 255;
7294 }
7295 TFLITE_DCHECK_GE(output_val_s16, 0);
7296 TFLITE_DCHECK_LE(output_val_s16, 255);
7297 output_val = static_cast<uint8>(output_val_s16);
7298 }
7299 output_data[c] = output_val;
7300 }
7301 }
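// A minimal, illustrative call of the uint8 kernel above. The parameter
// values are hypothetical (real values come from the op's prepared
// quantization parameters), and input_uint8 / output_uint8 stand in for
// 64-element uint8 buffers:
//
//   TanhParams params;
//   params.input_zero_point = 128;
//   params.input_range_radius = 900;
//   params.input_multiplier = 16384;
//   params.input_left_shift = 3;
//   RuntimeShape shape({1, 64});
//   Tanh16bitPrecision(params, shape, input_uint8, shape, output_uint8);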
7302
7303 inline void Tanh16bitPrecision(const TanhParams& params,
7304 const RuntimeShape& input_shape,
7305 const int8* input_data,
7306 const RuntimeShape& output_shape,
7307 int8* output_data) {
7308 // Note that this is almost the exact same code as in Logistic().
7309 ruy::profiler::ScopeLabel label("Tanh/Int8");
7310 const int32 input_zero_point = params.input_zero_point;
7311 const int32 input_range_radius = params.input_range_radius;
7312 const int16 input_multiplier = static_cast<int16>(params.input_multiplier);
7313 const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
7314 const int size = MatchingFlatSize(input_shape, output_shape);
7315
7316 int c = 0;
7317 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7318 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7319 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7320 #ifdef GEMMLOWP_NEON
7321 const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
7322 const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
7323
7324 // Handle 32 values at a time
7325 for (; c <= size - 32; c += 32) {
7326 // Read input int8 values, cast to int16 and subtract input_zero_point
7327 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
7328 const int16x8x2_t input_val_centered_0_1 =
7329 Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
7330 const int16x8x2_t input_val_centered_2_3 =
7331 Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
7332
7333 // Prepare the bit masks that we will use at the end to implement the logic
7334 // that was expressed in the scalar code with branching:
7335 // if (input_val_centered < -input_range_radius) {
7336 // output_val = -128;
7337 // } else if (input_val_centered > input_range_radius) {
7338 // output_val = 127;
7339 // } else {
7340 // ...
7341 uint8x16x2_t masks_clamp_0_1 = CalculateSignedClampingWithRangeBitMasks(
7342 input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
7343 uint8x16x2_t masks_clamp_2_3 = CalculateSignedClampingWithRangeBitMasks(
7344 input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
7345
7346 int16x8x4_t input_val_rescaled = SaturatingRounding(
7347 input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7348 input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7349 input_left_shift, input_multiplier);
7350
7351 int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled);
7352
7353 // Cast output values to int8, saturating
7354 int8x16_t output_val_s8_0_1 = vcombine_s8(
7355 vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1]));
7356 int8x16_t output_val_s8_2_3 = vcombine_s8(
7357 vqmovn_s16(output_val_s16.val[2]), vqmovn_s16(output_val_s16.val[3]));
7358
7359 ClampWithRangeAndStore(output_data + c, output_val_s8_0_1, masks_clamp_0_1);
7360 ClampWithRangeAndStore(output_data + c + 16, output_val_s8_2_3,
7361 masks_clamp_2_3);
7362 }
7363 #endif // GEMMLOWP_NEON
7364 // Leftover loop: handle one value at a time with scalar code.
7365 for (; c < size; ++c) {
7366 const int8 input_val_s8 = input_data[c];
7367 const int16 input_val_centered =
7368 static_cast<int16>(input_val_s8) - input_zero_point;
7369 int8 output_val;
7370 if (input_val_centered <= -input_range_radius) {
7371 output_val = -128;
7372 } else if (input_val_centered >= input_range_radius) {
7373 output_val = 127;
7374 } else {
7375 using gemmlowp::SaturatingRoundingDoublingHighMul;
7376 const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7377 static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7378 static_cast<int16>(input_multiplier));
7379 using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7380 using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7381 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7382 const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
7383 using gemmlowp::RoundingDivideByPOT;
7384 int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8);
7385 if (output_val_s16 == 128) {
7386 output_val_s16 = 127;
7387 }
7388 TFLITE_DCHECK_GE(output_val_s16, -128);
7389 TFLITE_DCHECK_LE(output_val_s16, 127);
7390 output_val = static_cast<int8>(output_val_s16);
7391 }
7392 output_data[c] = output_val;
7393 }
7394 }
7395
7396 inline void Logistic16bitPrecision(const LogisticParams& params,
7397 const RuntimeShape& input_shape,
7398 const uint8* input_data,
7399 const RuntimeShape& output_shape,
7400 uint8* output_data) {
7401 ruy::profiler::ScopeLabel label("Logistic/Uint8");
7402 const int32 input_zero_point = params.input_zero_point;
7403 const int32 input_range_radius = params.input_range_radius;
7404 const int32 input_multiplier = params.input_multiplier;
7405 const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
7406 const int size = MatchingFlatSize(input_shape, output_shape);
7407
7408 int c = 0;
7409 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7410 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7411 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7412 #ifdef GEMMLOWP_NEON
7413 const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
7414 const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
7415
7416 // Handle 32 values at a time
7417 for (; c <= size - 32; c += 32) {
7418 // Read input uint8 values, cast to int16 and subtract input_zero_point
7419 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
7420 const int16x8x2_t input_val_centered_0_1 =
7421 Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
7422 const int16x8x2_t input_val_centered_2_3 =
7423 Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
7424
7425 // Prepare the bit masks that we will use at the end to implement the logic
7426 // that was expressed in the scalar code with branching:
7427 // if (input_val_centered < -input_range_radius) {
7428 // output_val = 0;
7429 // } else if (input_val_centered > input_range_radius) {
7430 // output_val = 255;
7431 // } else {
7432 // ...
7433 uint8x16x2_t masks_clamp_0_1 = CalculateUnsignedClampingWithRangeBitMasks(
7434 input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
7435 uint8x16x2_t masks_clamp_2_3 = CalculateUnsignedClampingWithRangeBitMasks(
7436 input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
7437
7438 int16x8x4_t input_val_rescaled = SaturatingRounding(
7439 input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7440 input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7441 input_left_shift, input_multiplier);
7442
7443 int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled);
7444
7445 // Cast output values to uint8, saturating
7446 uint8x16_t output_val_u8_0_1 = vcombine_u8(
7447 vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1]));
7448 uint8x16_t output_val_u8_2_3 = vcombine_u8(
7449 vqmovun_s16(output_val_s16.val[2]), vqmovun_s16(output_val_s16.val[3]));
7450
7451 ClampWithRangeAndStore(output_data + c, output_val_u8_0_1, masks_clamp_0_1);
7452 ClampWithRangeAndStore(output_data + c + 16, output_val_u8_2_3,
7453 masks_clamp_2_3);
7454 }
7455 #endif // GEMMLOWP_NEON
7456 // Leftover loop: handle one value at a time with scalar code.
7457 for (; c < size; ++c) {
7458 const uint8 input_val_u8 = input_data[c];
7459 const int16 input_val_centered =
7460 static_cast<int16>(input_val_u8) - input_zero_point;
7461 uint8 output_val;
7462 if (input_val_centered < -input_range_radius) {
7463 output_val = 0;
7464 } else if (input_val_centered > input_range_radius) {
7465 output_val = 255;
7466 } else {
7467 using gemmlowp::SaturatingRoundingDoublingHighMul;
7468 const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7469 static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7470 static_cast<int16>(input_multiplier));
7471 using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7472 using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7473 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7474 const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
7475 using gemmlowp::RoundingDivideByPOT;
7476 int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7);
7477 if (output_val_s16 == 256) {
7478 output_val_s16 = 255;
7479 }
7480 TFLITE_DCHECK_GE(output_val_s16, 0);
7481 TFLITE_DCHECK_LE(output_val_s16, 255);
7482 output_val = static_cast<uint8>(output_val_s16);
7483 }
7484 output_data[c] = output_val;
7485 }
7486 }
7487
7488 inline void Logistic16bitPrecision(const LogisticParams& params,
7489 const RuntimeShape& input_shape,
7490 const int8* input_data,
7491 const RuntimeShape& output_shape,
7492 int8* output_data) {
7493 ruy::profiler::ScopeLabel label("Logistic/Int8");
7494 const int32 input_zero_point = params.input_zero_point;
7495 const int32 input_range_radius = params.input_range_radius;
7496 const int32 input_multiplier = params.input_multiplier;
7497 const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
7498 const int size = MatchingFlatSize(input_shape, output_shape);
7499
7500 int c = 0;
7501 const int16 output_zero_point = 128;
7502 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7503 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7504 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7505 #ifdef GEMMLOWP_NEON
7506 const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
7507 const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
7508 const int16x8_t output_zero_point_dup = vdupq_n_s16(output_zero_point);
7509
7510 // Handle 32 values at a time
7511 for (; c <= size - 32; c += 32) {
7512 // Read input int8 values, cast to int16 and subtract input_zero_point
7513 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
7514 const int16x8x2_t input_val_centered_0_1 =
7515 Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
7516 const int16x8x2_t input_val_centered_2_3 =
7517 Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
7518
7519 // Prepare the bit masks that we will use at the end to implement the logic
7520 // that was expressed in the scalar code with branching:
7521 // if (input_val_centered < -input_range_radius) {
7522 // output_val = -128;
7523 // } else if (input_val_centered > input_range_radius) {
7524 // output_val = 127;
7525 // } else {
7526 // ...
7527 uint8x16x2_t masks_clamp_0_1 = CalculateSignedClampingWithRangeBitMasks(
7528 input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
7529 uint8x16x2_t masks_clamp_2_3 = CalculateSignedClampingWithRangeBitMasks(
7530 input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
7531
7532 int16x8x4_t input_val_rescaled = SaturatingRounding(
7533 input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7534 input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7535 input_left_shift, input_multiplier);
7536
7537 int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled);
7538
7539 // Subtract the output zero point.
7540 output_val_s16.val[0] =
7541 vsubq_s16(output_val_s16.val[0], output_zero_point_dup);
7542 output_val_s16.val[1] =
7543 vsubq_s16(output_val_s16.val[1], output_zero_point_dup);
7544 output_val_s16.val[2] =
7545 vsubq_s16(output_val_s16.val[2], output_zero_point_dup);
7546 output_val_s16.val[3] =
7547 vsubq_s16(output_val_s16.val[3], output_zero_point_dup);
7548
7549 // Cast output values to int8, saturating
7550 int8x16_t output_val_s8_0_1 = vcombine_s8(
7551 vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1]));
7552 int8x16_t output_val_s8_2_3 = vcombine_s8(
7553 vqmovn_s16(output_val_s16.val[2]), vqmovn_s16(output_val_s16.val[3]));
7554
7555 ClampWithRangeAndStore(output_data + c, output_val_s8_0_1, masks_clamp_0_1);
7556 ClampWithRangeAndStore(output_data + c + 16, output_val_s8_2_3,
7557 masks_clamp_2_3);
7558 }
7559 #endif // GEMMLOWP_NEON
7560 // Leftover loop: handle one value at a time with scalar code.
7561 for (; c < size; ++c) {
7562 const int8 input_val_s8 = input_data[c];
7563 const int16 input_val_centered =
7564 static_cast<int16>(input_val_s8) - input_zero_point;
7565 int8 output_val;
7566 if (input_val_centered < -input_range_radius) {
7567 output_val = -128;
7568 } else if (input_val_centered > input_range_radius) {
7569 output_val = 127;
7570 } else {
7571 using gemmlowp::SaturatingRoundingDoublingHighMul;
7572 const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7573 static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7574 static_cast<int16>(input_multiplier));
7575 using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7576 using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7577 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7578 const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
7579 using gemmlowp::RoundingDivideByPOT;
7580 int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7);
7581 output_val_s16 -= output_zero_point;
7582 if (output_val_s16 == 128) {
7583 output_val_s16 = 127;
7584 }
7585 TFLITE_DCHECK_GE(output_val_s16, -128);
7586 TFLITE_DCHECK_LE(output_val_s16, 127);
7587 output_val = static_cast<int8>(output_val_s16);
7588 }
7589 output_data[c] = output_val;
7590 }
7591 }
7592
7593 // Transpose2D only deals with typical 2D matrix transpose ops.
7594 // It transposes 4x4 blocks of the input, proceeding left to right along the
7595 // rows of the input, and then from top to bottom.
7596 template <typename T>
7597 inline void Transpose2D(const RuntimeShape& input_shape, const T* input_data,
7598 const RuntimeShape& output_shape, T* output_data) {
7599 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
7600 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
7601
7602 const int d0 = input_shape.DimsData()[0];
7603 const int d1 = input_shape.DimsData()[1];
7604 const int kLines = 4;
7605 const int kSkipSize = (kLines - 1) * d1;
7606
7607 const T* input = input_data;
7608
7609 int i = 0;
7610 for (; i <= d0 - kLines; i += kLines) {
7611 T* output = output_data + i;
7612
7613 const T* input_ptr = input;
7614 optimized_ops_preload_l1_keep(input_ptr);
7615 input_ptr += d1;
7616 optimized_ops_preload_l1_keep(input_ptr);
7617 input_ptr += d1;
7618 optimized_ops_preload_l1_keep(input_ptr);
7619 input_ptr += d1;
7620 optimized_ops_preload_l1_keep(input_ptr);
7621
7622 int j = 0;
7623 for (; j <= d1 - kLines; j += kLines) {
7624 input_ptr = input;
7625 const T a00 = input_ptr[0];
7626 const T a01 = input_ptr[1];
7627 const T a02 = input_ptr[2];
7628 const T a03 = input_ptr[3];
7629 input_ptr += d1;
7630 const T a10 = input_ptr[0];
7631 const T a11 = input_ptr[1];
7632 const T a12 = input_ptr[2];
7633 const T a13 = input_ptr[3];
7634 input_ptr += d1;
7635 const T a20 = input_ptr[0];
7636 const T a21 = input_ptr[1];
7637 const T a22 = input_ptr[2];
7638 const T a23 = input_ptr[3];
7639 input_ptr += d1;
7640 const T a30 = input_ptr[0];
7641 const T a31 = input_ptr[1];
7642 const T a32 = input_ptr[2];
7643 const T a33 = input_ptr[3];
7644
7645 output[0] = a00;
7646 output[1] = a10;
7647 output[2] = a20;
7648 output[3] = a30;
7649 output += d0;
7650
7651 output[0] = a01;
7652 output[1] = a11;
7653 output[2] = a21;
7654 output[3] = a31;
7655 output += d0;
7656
7657 output[0] = a02;
7658 output[1] = a12;
7659 output[2] = a22;
7660 output[3] = a32;
7661 output += d0;
7662
7663 output[0] = a03;
7664 output[1] = a13;
7665 output[2] = a23;
7666 output[3] = a33;
7667 output += d0;
7668
7669 input += kLines;
7670 }
7671 if (j == d1) {
7672 input += kSkipSize;
7673 } else {
7674 for (int p = 0; p < kLines; ++p) {
7675 for (int q = 0; q < d1 - j; ++q) {
7676 *(output + q * d0 + p) = *(input + p * d1 + q);
7677 }
7678 }
7679 input += (d1 - j) + kSkipSize;
7680 }
7681 }
7682 for (; i < d0; ++i) {
7683 T* output = output_data + i;
7684 for (int j = 0; j < d1; ++j) {
7685 *output = *input;
7686 output += d0;
7687 ++input;
7688 }
7689 }
7690 }
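// Worked example (illustrative) for the scalar path above, with a 2x3
// row-major input:
//
//   input  (shape {2, 3}):  1 2 3      output (shape {3, 2}):  1 4
//                           4 5 6                              2 5
//                                                              3 6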
7691
7692 template <>
7693 inline void Transpose2D(const RuntimeShape& input_shape,
7694 const int32_t* input_data,
7695 const RuntimeShape& output_shape,
7696 int32_t* output_data) {
7697 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
7698 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
7699
7700 const int d0 = input_shape.DimsData()[0];
7701 const int d1 = input_shape.DimsData()[1];
7702 #ifdef USE_NEON
7703 const int kLines = 4;
7704 const int kSkipSize = (kLines - 1) * d1;
7705 #endif
7706
7707 const int32_t* input = input_data;
7708
7709 int i = 0;
7710 #ifdef USE_NEON
7711 for (; i <= d0 - kLines; i += kLines) {
7712 int32_t* output = output_data + i;
7713
7714 const int32_t* input_ptr = input;
7715 optimized_ops_preload_l1_keep(input_ptr);
7716 input_ptr += d1;
7717 optimized_ops_preload_l1_keep(input_ptr);
7718 input_ptr += d1;
7719 optimized_ops_preload_l1_keep(input_ptr);
7720 input_ptr += d1;
7721 optimized_ops_preload_l1_keep(input_ptr);
7722
7723 int j = 0;
7724 for (; j <= d1 - kLines; j += kLines) {
7725 input_ptr = input;
7726 int32x4_t a0 = vld1q_s32(input);
7727 input_ptr += d1;
7728 int32x4_t a1 = vld1q_s32(input_ptr);
7729 input_ptr += d1;
7730 int32x4_t a2 = vld1q_s32(input_ptr);
7731 input_ptr += d1;
7732 int32x4_t a3 = vld1q_s32(input_ptr);
7733
7734 int32x4x2_t tmp1 = vuzpq_s32(a0, a2);
7735 int32x4x2_t tmp2 = vuzpq_s32(a1, a3);
7736 int32x4x2_t tmp3 = vtrnq_s32(tmp1.val[0], tmp2.val[0]);
7737 int32x4x2_t tmp4 = vtrnq_s32(tmp1.val[1], tmp2.val[1]);
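// After the uzp/trn pair above, each vector holds one column of the original
// 4x4 block, i.e. one row of the transposed block (aRC = row R, column C of
// the block loaded into a0..a3):
//   tmp3.val[0] = {a00, a10, a20, a30}    tmp4.val[0] = {a01, a11, a21, a31}
//   tmp3.val[1] = {a02, a12, a22, a32}    tmp4.val[1] = {a03, a13, a23, a33}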
7738
7739 vst1q_s32(output, tmp3.val[0]);
7740 output += d0;
7741 vst1q_s32(output, tmp4.val[0]);
7742 output += d0;
7743 vst1q_s32(output, tmp3.val[1]);
7744 output += d0;
7745 vst1q_s32(output, tmp4.val[1]);
7746 output += d0;
7747 input += kLines;
7748 }
7749 if (j == d1) {
7750 input += kSkipSize;
7751 } else {
7752 for (int p = 0; p < kLines; ++p) {
7753 for (int q = 0; q < d1 - j; ++q) {
7754 *(output + q * d0 + p) = *(input + p * d1 + q);
7755 }
7756 }
7757 input += (d1 - j) + kSkipSize;
7758 }
7759 }
7760 #endif
7761 for (; i < d0; ++i) {
7762 int32_t* output = output_data + i;
7763 for (int j = 0; j < d1; ++j) {
7764 *output = *input;
7765 output += d0;
7766 ++input;
7767 }
7768 }
7769 }
7770
7771 // TODO(b/173718660): see if we can reduce the number
7772 // of lines of code in branching without affecting latency.
7773 template <typename T>
7774 inline void Transpose3D(const TransposeParams& params,
7775 const RuntimeShape& input_shape, const T* input_data,
7776 const RuntimeShape& output_shape, T* output_data) {
7777 int s1, s2, s3;
7778 s1 = input_shape.Dims(0);
7779 s2 = input_shape.Dims(1);
7780 s3 = input_shape.Dims(2);
7781
7782 int p1, p2, p3;
7783 if (params.perm[0] == 2) {
7784 p1 = 1;
7785 } else if (params.perm[1] == 2) {
7786 p2 = 1;
7787 } else {
7788 p3 = 1;
7789 }
7790
7791 if (params.perm[0] == 1) {
7792 p1 = s3;
7793 } else if (params.perm[1] == 1) {
7794 p2 = s3;
7795 } else {
7796 p3 = s3;
7797 }
7798
7799 if (params.perm[0] == 0) {
7800 p1 = s2 * s3;
7801 } else if (params.perm[1] == 0) {
7802 p2 = s2 * s3;
7803 } else {
7804 p3 = s2 * s3;
7805 }
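// p1/p2/p3 are the input strides corresponding to the three output loop
// indices below; the input strides for dims {0, 1, 2} are {s2 * s3, s3, 1}.
// Example (illustrative): perm = {2, 0, 1} yields p1 = 1, p2 = s2 * s3,
// p3 = s3.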
7806
7807 int o_s[3];
7808 o_s[0] = input_shape.Dims(params.perm[0]);
7809 o_s[1] = input_shape.Dims(params.perm[1]);
7810 o_s[2] = input_shape.Dims(params.perm[2]);
7811
7812 for (int i1 = 0; i1 < o_s[0]; ++i1) {
7813 for (int i2 = 0; i2 < o_s[1]; ++i2) {
7814 for (int i3 = 0; i3 < o_s[2]; ++i3) {
7815 const int i = i1 * p1 + i2 * p2 + i3 * p3;
7816 const int o = i1 * o_s[1] * o_s[2] + i2 * o_s[2] + i3;
7817 output_data[o] = input_data[i];
7818 }
7819 }
7820 }
7821 }
7822
7823 template <typename T, int N>
7824 void TransposeImpl(const TransposeParams& params,
7825 const RuntimeShape& input_shape, const T* input_data,
7826 const RuntimeShape& output_shape, T* output_data) {
7827 const int dims_cnt = input_shape.DimensionsCount();
7828
7829 int dim0, dim1;
7830 if (transpose_utils::IsTranspose2DApplicable(params, input_shape, &dim0,
7831 &dim1)) {
7832 Transpose2D(RuntimeShape({dim0, dim1}), input_data,
7833 RuntimeShape({dim1, dim0}), output_data);
7834 return;
7835 }
7836
7837 // TODO(b/141217325): notably Eigen is better suited for
7838 // larger inputs whereas Transpose3D is generally
7839 // better for smaller ones.
7840 //
7841 // E.g. on Nexus 5, Eigen is better for size 96^3 and up
7842 // and Transpose3D is better for 72^3 and down.
7843 //
7844 // 96^3 is not mobile-friendly for certain use cases
7845 // (e.g. the model used in beam search for seq2seq) but is in others.
7846 // Consider tradeoffs.
7847 if (dims_cnt == 3) {
7848 Transpose3D(params, input_shape, input_data, output_shape, output_data);
7849 return;
7850 }
7851
7852 // Reroute to the reference version if an optimized method for the given data
7853 // is not available.
7854 reference_ops::Transpose<T, N>(params, input_shape, input_data, output_shape,
7855 output_data);
7856 }
7857
7858 template <typename T, int N = 5>
7859 void Transpose(const TransposeParams& unshrinked_params,
7860 const RuntimeShape& unshrinked_input_shape, const T* input_data,
7861 const RuntimeShape& unshrinked_output_shape, T* output_data) {
7862 ruy::profiler::ScopeLabel label("Transpose");
7863
7864 const int output_size = unshrinked_output_shape.DimensionsCount();
7865 TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), N);
7866 TFLITE_DCHECK_LE(output_size, N);
7867 TFLITE_DCHECK_EQ(output_size, unshrinked_params.perm_count);
7868
7869 RuntimeShape shrinked_input_shape = RuntimeShape(unshrinked_input_shape);
7870 RuntimeShape shrinked_output_shape = RuntimeShape(unshrinked_output_shape);
7871 TransposeParams shrinked_params = unshrinked_params;
7872
7873 // Remove any dimensions of size one. A lower-rank transpose usually
7874 // performs better because the memory access pattern improves.
7875 transpose_utils::RemoveOneSizeDimensions(
7876 &shrinked_input_shape, &shrinked_output_shape, &shrinked_params);
7877
7878 // Handle identity cases.
7879 // TODO(b/140779653): Add an optimization pass in the conversion process to
7880 // remove no-op transpose nodes such as the one handled below.
7881 bool identical = true;
7882 for (int i = 0; i < shrinked_params.perm_count; ++i) {
7883 if (shrinked_params.perm[i] != i) {
7884 identical = false;
7885 break;
7886 }
7887 }
7888 if (identical) {
7889 memcpy(output_data, input_data,
7890 unshrinked_input_shape.FlatSize() * sizeof(T));
7891 return;
7892 }
7893
7894 // Reduce dimensions by flattening.
7895 if (shrinked_params.perm[0] == 0 && output_size >= 3) {
7896 RuntimeShape non_flatten_input_shape;
7897 RuntimeShape non_flatten_output_shape;
7898 TransposeParams non_flatten_params;
7899 const int total_size = shrinked_input_shape.FlatSize();
7900 const int non_flatten_size = transpose_utils::Flatten(
7901 shrinked_input_shape, shrinked_output_shape, shrinked_params,
7902 &non_flatten_input_shape, &non_flatten_output_shape,
7903 &non_flatten_params);
7904 TFLITE_DCHECK_NE(non_flatten_params.perm[0], 0);
7905
7906 for (int i = 0; i < total_size; i += non_flatten_size) {
7907 TransposeImpl<T, N>(non_flatten_params, non_flatten_input_shape,
7908 input_data + i, non_flatten_output_shape,
7909 output_data + i);
7910 }
7911 return;
7912 }
7913
7914 // Call non-flattened case.
7915 TransposeImpl<T, N>(shrinked_params, shrinked_input_shape, input_data,
7916 shrinked_output_shape, output_data);
7917 }
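// Illustrative use of Transpose (values hypothetical): swap the two axes of a
// {2, 3} float tensor, where input / output stand in for float buffers of six
// elements:
//
//   TransposeParams params;
//   params.perm_count = 2;
//   params.perm[0] = 1;
//   params.perm[1] = 0;
//   Transpose<float>(params, RuntimeShape({2, 3}), input, RuntimeShape({3, 2}),
//                    output);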
7918
7919 // Assume input1 & input2 have the same scale & zero point.
7920 inline void MaximumElementwise(int size, const ArithmeticParams& params,
7921 const int8* input1_data, const int8* input2_data,
7922 int8* output_data) {
7923 ruy::profiler::ScopeLabel label("MaximumElementwiseInt8/8bit");
7924 int i = 0;
7925 #ifdef USE_NEON
7926 for (; i <= size - 16; i += 16) {
7927 const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
7928 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7929 const int8x16_t max_data =
7930 vmaxq_s8(input1_val_original, input2_val_original);
7931 vst1q_s8(output_data + i, max_data);
7932 }
7933 #endif // USE_NEON
7934 for (; i < size; ++i) {
7935 const int8 input1_val = input1_data[i];
7936 const int8 input2_val = input2_data[i];
7937 output_data[i] = std::max(input1_val, input2_val);
7938 }
7939 }
7940
7941 inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params,
7942 int8 input1_data, const int8* input2_data,
7943 int8* output_data) {
7944 ruy::profiler::ScopeLabel label("MaximumScalarBroadcastInt8/8bit");
7945 int i = 0;
7946
7947 #ifdef USE_NEON
7948 const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
7949 for (; i <= size - 16; i += 16) {
7950 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7951 const int8x16_t max_data =
7952 vmaxq_s8(input1_val_original, input2_val_original);
7953 vst1q_s8(output_data + i, max_data);
7954 }
7955 #endif // USE_NEON
7956 for (; i < size; ++i) {
7957 const int8 input2_val = input2_data[i];
7958 output_data[i] = std::max(input1_data, input2_val);
7959 }
7960 }
7961
7962 // Assume input1 & input2 have the same scale & zero point.
7963 inline void MinimumElementwise(int size, const ArithmeticParams& params,
7964 const int8* input1_data, const int8* input2_data,
7965 int8* output_data) {
7966 ruy::profiler::ScopeLabel label("MinimumElementwiseInt8/8bit");
7967 int i = 0;
7968 #ifdef USE_NEON
7969 for (; i <= size - 16; i += 16) {
7970 const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
7971 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7972 const int8x16_t min_data =
7973 vminq_s8(input1_val_original, input2_val_original);
7974 vst1q_s8(output_data + i, min_data);
7975 }
7976 #endif // USE_NEON
7977 for (; i < size; ++i) {
7978 const int8 input1_val = input1_data[i];
7979 const int8 input2_val = input2_data[i];
7980 output_data[i] = std::min(input1_val, input2_val);
7981 }
7982 }
7983
7984 inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params,
7985 int8 input1_data, const int8* input2_data,
7986 int8* output_data) {
7987 ruy::profiler::ScopeLabel label("MinimumScalarBroadcastInt8/8bit");
7988 int i = 0;
7989
7990 #ifdef USE_NEON
7991 const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
7992 for (; i <= size - 16; i += 16) {
7993 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7994 const int8x16_t min_data =
7995 vminq_s8(input1_val_original, input2_val_original);
7996 vst1q_s8(output_data + i, min_data);
7997 }
7998 #endif // USE_NEON
7999 for (; i < size; ++i) {
8000 const int8 input2_val = input2_data[i];
8001 output_data[i] = std::min(input1_data, input2_val);
8002 }
8003 }
8004
8005 template <typename Op>
8006 inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
8007 const RuntimeShape& input1_shape,
8008 const int8* input1_data,
8009 const RuntimeShape& input2_shape,
8010 const int8* input2_data,
8011 const RuntimeShape& output_shape,
8012 int8* output_data, Op op) {
8013 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
8014 return reference_ops::MaximumMinimumBroadcastSlow(
8015 input1_shape, input1_data, input2_shape, input2_data, output_shape,
8016 output_data, op);
8017 }
8018
8019 BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
8020 input2_data, output_shape, output_data,
8021 MaximumElementwise, MaximumScalarBroadcast);
8022 }
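// Note: BinaryBroadcastFiveFold (defined elsewhere in these optimized kernels)
// walks the five-fold broadcast decomposition, and is expected to call the
// element-wise functor on runs where both operands advance and the
// scalar-broadcast functor on runs where one operand is held constant.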
8023
8024 template <typename Op>
8025 inline void BroadcastMinimumDispatch(const ArithmeticParams& params,
8026 const RuntimeShape& input1_shape,
8027 const int8* input1_data,
8028 const RuntimeShape& input2_shape,
8029 const int8* input2_data,
8030 const RuntimeShape& output_shape,
8031 int8* output_data, Op op) {
8032 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
8033 return reference_ops::MaximumMinimumBroadcastSlow(
8034 input1_shape, input1_data, input2_shape, input2_data, output_shape,
8035 output_data, op);
8036 }
8037
8038 BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
8039 input2_data, output_shape, output_data,
8040 MinimumElementwise, MinimumScalarBroadcast);
8041 }
8042
8043 template <typename T>
8044 void CumsumImpl(const T* input_data, const RuntimeShape& shape, int axis,
8045 bool exclusive, bool reverse, T* output_data) {
8046 Eigen::array<Eigen::DenseIndex, 3> dims = {1, 1, 1};
8047
8048 for (int i = 0; i < axis; ++i) {
8049 dims[0] *= shape.Dims(i);
8050 }
8051 dims[1] = shape.Dims(axis);
8052 for (int i = axis + 1; i < shape.DimensionsCount(); ++i) {
8053 dims[2] *= shape.Dims(i);
8054 }
8055
8056 typedef Eigen::TensorMap<
8057 Eigen::Tensor<const T, 3, Eigen::RowMajor, Eigen::DenseIndex>,
8058 Eigen::Aligned>
8059 ConstTensor;
8060 typedef Eigen::TensorMap<
8061 Eigen::Tensor<T, 3, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>
8062 Tensor;
8063 ConstTensor input(input_data, dims);
8064 Tensor output(output_data, dims);
8065
8066 if (reverse) {
8067 Eigen::array<bool, 3> reverse_idx = {false, true, false};
8068 output =
8069 input.reverse(reverse_idx).cumsum(1, exclusive).reverse(reverse_idx);
8070 } else {
8071 output = input.cumsum(1, exclusive);
8072 }
8073 }
8074
8075 template <typename T>
8076 void CumSum(const T* input_data, const RuntimeShape& shape, int axis,
8077 bool exclusive, bool reverse, T* output_data) {
8078 const int dim = shape.DimensionsCount();
8079 TFLITE_DCHECK_GE(dim, 1);
8080 CumsumImpl<T>(input_data, shape, axis, exclusive, reverse, output_data);
8081 }
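// Illustrative use of CumSum along axis 1 of a {2, 3} tensor (values chosen
// for the example):
//
//   float input[] = {1, 2, 3, 4, 5, 6};
//   float output[6];
//   CumSum(input, RuntimeShape({2, 3}), /*axis=*/1, /*exclusive=*/false,
//          /*reverse=*/false, output);
//   // output == {1, 3, 6, 4, 9, 15}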
8082
8083 inline void PReluScalarBroadcast(int size, const ArithmeticParams& params,
8084 float alpha, const float* input_data,
8085 float* output_data) {
8086 ruy::profiler::ScopeLabel label("PreluScalarBroadcast/float");
8087 int i = 0;
8088
8089 #ifdef USE_NEON
8090 const float32x4_t zero_dup = vdupq_n_f32(0.0f);
8091 const float32x4_t alpha_dup = vdupq_n_f32(alpha);
8092 for (; i <= size - 16; i += 16) {
8093 const float32x4_t input1 = vld1q_f32(input_data + i);
8094 const float32x4_t input2 = vld1q_f32(input_data + i + 4);
8095 const float32x4_t input3 = vld1q_f32(input_data + i + 8);
8096 const float32x4_t input4 = vld1q_f32(input_data + i + 12);
8097
8098 const float32x4_t temp1 = vmulq_f32(input1, alpha_dup);
8099 const float32x4_t temp2 = vmulq_f32(input2, alpha_dup);
8100 const float32x4_t temp3 = vmulq_f32(input3, alpha_dup);
8101 const float32x4_t temp4 = vmulq_f32(input4, alpha_dup);
8102
8103 const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
8104 const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
8105 const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
8106 const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);
8107
8108 const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
8109 vst1q_f32(output_data + i, result1);
8110 const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
8111 vst1q_f32(output_data + i + 4, result2);
8112 const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
8113 vst1q_f32(output_data + i + 8, result3);
8114 const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
8115 vst1q_f32(output_data + i + 12, result4);
8116 }
8117
8118 for (; i <= size - 4; i += 4) {
8119 const float32x4_t input = vld1q_f32(input_data + i);
8120 const float32x4_t temp = vmulq_f32(input, alpha_dup);
8121 const uint32x4_t mask = vcgeq_f32(input, zero_dup);
8122 const float32x4_t result = vbslq_f32(mask, input, temp);
8123 vst1q_f32(output_data + i, result);
8124 }
8125 #endif // USE_NEON
8126 for (; i < size; ++i) {
8127 const float input = input_data[i];
8128 output_data[i] = input >= 0.f ? input : input * alpha;
8129 }
8130 }
8131
8132 inline void PReluElementWise(int flat_size, const ArithmeticParams& params,
8133 const float* alpha_data, const float* input_data,
8134 float* output_data) {
8135 ruy::profiler::ScopeLabel label("PreluElementWise/float");
8136
8137 int i = 0;
8138 #ifdef USE_NEON
8139 const float32x4_t zero_dup = vdupq_n_f32(0.0f);
8140 for (; i <= flat_size - 16; i += 16) {
8141 const float32x4_t input1 = vld1q_f32(input_data + i);
8142 const float32x4_t alpha1 = vld1q_f32(alpha_data + i);
8143 const float32x4_t input2 = vld1q_f32(input_data + i + 4);
8144 const float32x4_t alpha2 = vld1q_f32(alpha_data + i + 4);
8145 const float32x4_t input3 = vld1q_f32(input_data + i + 8);
8146 const float32x4_t alpha3 = vld1q_f32(alpha_data + i + 8);
8147 const float32x4_t input4 = vld1q_f32(input_data + i + 12);
8148 const float32x4_t alpha4 = vld1q_f32(alpha_data + i + 12);
8149
8150 const float32x4_t temp1 = vmulq_f32(input1, alpha1);
8151 const float32x4_t temp2 = vmulq_f32(input2, alpha2);
8152 const float32x4_t temp3 = vmulq_f32(input3, alpha3);
8153 const float32x4_t temp4 = vmulq_f32(input4, alpha4);
8154
8155 const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
8156 const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
8157 const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
8158 const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);
8159
8160 const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
8161 vst1q_f32(output_data + i, result1);
8162 const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
8163 vst1q_f32(output_data + i + 4, result2);
8164 const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
8165 vst1q_f32(output_data + i + 8, result3);
8166 const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
8167 vst1q_f32(output_data + i + 12, result4);
8168 }
8169
8170 for (; i <= flat_size - 4; i += 4) {
8171 const float32x4_t input = vld1q_f32(input_data + i);
8172 const float32x4_t alpha = vld1q_f32(alpha_data + i);
8173
8174 const float32x4_t temp = vmulq_f32(input, alpha);
8175 const uint32x4_t mask = vcgeq_f32(input, zero_dup);
8176 const float32x4_t result = vbslq_f32(mask, input, temp);
8177 vst1q_f32(output_data + i, result);
8178 }
8179 #endif // USE_NEON
8180 for (; i < flat_size; ++i) {
8181 const float input = input_data[i];
8182 const float alpha = alpha_data[i];
8183 output_data[i] = input >= 0.f ? input : input * alpha;
8184 }
8185 }
8186
8187 inline void BroadcastPReluDispatch(
8188 const ArithmeticParams& params, const RuntimeShape& input_shape,
8189 const float* input_data, const RuntimeShape& alpha_shape,
8190 const float* alpha_data, const RuntimeShape& output_shape,
8191 float* output_data, float (*func)(float, float)) {
8192 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
8193 return reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
8194 input_shape, input_data, alpha_shape, alpha_data, output_shape,
8195 output_data, func);
8196 }
8197
8198 BinaryBroadcastFiveFold(params, input_shape, input_data, alpha_shape,
8199 alpha_data, output_shape, output_data,
8200 PReluElementWise, PReluScalarBroadcast);
8201 }
8202
8203 } // namespace optimized_ops
8204 } // namespace tflite
8205
8206 #if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
8207 #undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
8208 #pragma GCC diagnostic pop
8209 #endif
8210
8211 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
8212