1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
17
18 #include <assert.h>
19 #include <stdint.h>
20 #include <sys/types.h>
21
22 #include <algorithm>
23 #include <cmath>
24 #include <cstdint>
25 #include <limits>
26 #include <memory>
27 #include <tuple>
28 #include <type_traits>
29
30 #include "tensorflow/lite/kernels/internal/common.h"
31 #include "tensorflow/lite/kernels/internal/compatibility.h"
32 #include "tensorflow/lite/kernels/internal/reference/add.h"
33 #include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
34
35 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
36 #include <Accelerate/Accelerate.h>
37 #endif
38
39 #include "Eigen/Core"
40 #include "fixedpoint/fixedpoint.h"
41 #include "ruy/profiler/instrumentation.h" // from @ruy
42 #include "tensorflow/lite/c/common.h"
43 #include "tensorflow/lite/kernels/cpu_backend_context.h"
44 #include "tensorflow/lite/kernels/cpu_backend_gemm.h"
45 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
46 #include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
47 #include "tensorflow/lite/kernels/internal/cppmath.h"
48 #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
49 #include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
50 #include "tensorflow/lite/kernels/internal/quantization_util.h"
51 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
52 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
53 #include "tensorflow/lite/kernels/internal/tensor.h"
54 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
55 #include "tensorflow/lite/kernels/internal/transpose_utils.h"
56 #include "tensorflow/lite/kernels/internal/types.h"
57 #include "unsupported/Eigen/CXX11/Tensor"
58
59 #if __aarch64__ && __clang__
60 #define TFLITE_SOFTMAX_USE_UINT16_LUT
61 #endif
62
63 namespace tflite {
64 namespace optimized_ops {
65
66 // Unoptimized reference ops:
67 using reference_ops::Broadcast4DSlowGreater;
68 using reference_ops::Broadcast4DSlowGreaterEqual;
69 using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
70 using reference_ops::Broadcast4DSlowGreaterWithScaling;
71 using reference_ops::Broadcast4DSlowLess;
72 using reference_ops::Broadcast4DSlowLessEqual;
73 using reference_ops::Broadcast4DSlowLessEqualWithScaling;
74 using reference_ops::Broadcast4DSlowLessWithScaling;
75 using reference_ops::BroadcastAdd4DSlow;
76 using reference_ops::BroadcastMul4DSlow;
77 using reference_ops::BroadcastSub16POTSlow;
78 using reference_ops::BroadcastSubSlow;
79 using reference_ops::Concatenation;
80 using reference_ops::ConcatenationWithScaling;
81 using reference_ops::DepthConcatenation;
82 using reference_ops::Div;
83 using reference_ops::Elu;
84 using reference_ops::FakeQuant;
85 using reference_ops::Fill;
86 using reference_ops::Gather;
87 using reference_ops::Greater;
88 using reference_ops::GreaterEqual;
89 using reference_ops::GreaterEqualWithScaling;
90 using reference_ops::GreaterWithScaling;
91 using reference_ops::LeakyRelu;
92 using reference_ops::Less;
93 using reference_ops::LessEqual;
94 using reference_ops::LessEqualWithScaling;
95 using reference_ops::LessWithScaling;
96 using reference_ops::Mean;
97 using reference_ops::ProcessBroadcastShapes;
98 using reference_ops::RankOneSelect;
99 using reference_ops::Relu1;
100 using reference_ops::Relu6;
101 using reference_ops::ReluX;
102 using reference_ops::Round;
103 using reference_ops::Select;
104 using reference_ops::SpaceToBatchND;
105 using reference_ops::Split;
106 using reference_ops::Sub16;
107
108 // TODO(b/80247582) Remove this constant.
109 // This will be phased out as the shifts are revised with more thought. Use of a
110 // constant enables us to track progress on this work.
111 //
112 // Used to convert from old-style shifts (right) to new-style (left).
113 static constexpr int kReverseShift = -1;
114
115 // Make a local VectorMap typedef allowing one to map a float array
116 // as an Eigen vector expression. The std::conditional here is to
117 // construct the suitable Eigen type for the constness of the
118 // data. Indeed, for const data, we need to produce
119 // Eigen::Map<const Eigen::Matrix<float, ...>>
120 // and not the more straightforward
121 // Eigen::Map<Eigen::Matrix<const float, ...>>
122 template <typename Scalar>
123 using VectorMap = typename std::conditional<
124 std::is_const<Scalar>::value,
125 Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
126 Eigen::Dynamic, 1>>,
127 Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, 1>>>::type;
128
129 template <typename Scalar>
130 VectorMap<Scalar> MapAsVector(Scalar* data, const RuntimeShape& shape) {
131 const int size = shape.FlatSize();
132 return VectorMap<Scalar>(data, size, 1);
133 }
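// Example (illustrative only, not part of the original header): a buffer of
// six floats with RuntimeShape {2, 3} maps to a 6-element Eigen column
// vector, so reductions and elementwise expressions can be written directly:
//   float data[6] = {0, 1, 2, 3, 4, 5};
//   auto v = MapAsVector(data, RuntimeShape({2, 3}));  // v.size() == 6
//   float total = v.sum();                             // total == 15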
134
135 // Make a local MatrixMap typedef allowing one to map a float array
136 // as an Eigen matrix expression. The same explanation as for VectorMap
137 // above also applies here.
138 template <typename Scalar>
139 using MatrixMap = typename std::conditional<
140 std::is_const<Scalar>::value,
141 Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
142 Eigen::Dynamic, Eigen::Dynamic>>,
143 Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
144
145 template <typename Scalar>
146 MatrixMap<Scalar> MapAsMatrixWithLastDimAsRows(Scalar* data,
147 const RuntimeShape& shape) {
148 const int dims_count = shape.DimensionsCount();
149 const int rows = shape.Dims(dims_count - 1);
150 const int cols = FlatSizeSkipDim(shape, dims_count - 1);
151 return MatrixMap<Scalar>(data, rows, cols);
152 }
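// Example (illustrative only): a tensor of shape {2, 3, 4} maps to a 4x6
// matrix. The underlying data is row-major with the last dimension
// contiguous, and the Eigen Map defaults to column-major, so each of the 6
// columns is one contiguous 4-element slice along the last dimension.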
153
154 template <typename Scalar>
155 MatrixMap<Scalar> MapAsMatrixWithFirstDimAsCols(Scalar* data,
156 const RuntimeShape& shape) {
157 const int cols = shape.Dims(0);
158 const int rows = FlatSizeSkipDim(shape, 0);
159 return MatrixMap<Scalar>(data, rows, cols);
160 }
161
162 template <typename Scalar>
163 using ArrayMap = typename std::conditional<
164 std::is_const<Scalar>::value,
165 Eigen::Map<const Eigen::Array<typename std::remove_const<Scalar>::type,
166 Eigen::Dynamic, Eigen::Dynamic>>,
167 Eigen::Map<Eigen::Array<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
168
169 template <typename Scalar>
170 ArrayMap<Scalar> MapAsArrayWithLastDimAsRows(Scalar* data,
171 const RuntimeShape& shape) {
172 const int dims_count = shape.DimensionsCount();
173 const int rows = shape.Dims(dims_count - 1);
174 const int cols = FlatSizeSkipDim(shape, dims_count - 1);
175 return ArrayMap<Scalar>(data, rows, cols);
176 }
177
178 // Copied from tensorflow/core/framework/tensor_types.h
179 template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
180 struct TTypes {
181 // Rank-1 tensor (vector) of scalar type T.
182 typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
183 Eigen::Aligned>
184 Flat;
185 typedef Eigen::TensorMap<
186 Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>>
187 UnalignedConstMatrix;
188 };
189
190 // TODO(b/62193649): this function is only needed as long
191 // as we have the --variable_batch hack.
192 template <typename Scalar>
193 MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
194 const RuntimeShape& shape,
195 int rows) {
196 const int flatsize = shape.FlatSize();
197 TFLITE_DCHECK_EQ(flatsize % rows, 0);
198 const int cols = flatsize / rows;
199 return MatrixMap<Scalar>(data, rows, cols);
200 }
201
202 template <typename ElementwiseF, typename ScalarBroadcastF, typename T>
203 inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params,
204 const RuntimeShape& unswitched_input1_shape,
205 const T* unswitched_input1_data,
206 const RuntimeShape& unswitched_input2_shape,
207 const T* unswitched_input2_data,
208 const RuntimeShape& output_shape,
209 T* output_data, ElementwiseF elementwise_f,
210 ScalarBroadcastF scalar_broadcast_f) {
211 ArithmeticParams switched_params = unswitched_params;
212 switched_params.input1_offset = unswitched_params.input2_offset;
213 switched_params.input1_multiplier = unswitched_params.input2_multiplier;
214 switched_params.input1_shift = unswitched_params.input2_shift;
215 switched_params.input2_offset = unswitched_params.input1_offset;
216 switched_params.input2_multiplier = unswitched_params.input1_multiplier;
217 switched_params.input2_shift = unswitched_params.input1_shift;
218
219 const bool use_unswitched =
220 unswitched_params.broadcast_category ==
221 tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
222
223 const ArithmeticParams& params =
224 use_unswitched ? unswitched_params : switched_params;
225 const T* input1_data =
226 use_unswitched ? unswitched_input1_data : unswitched_input2_data;
227 const T* input2_data =
228 use_unswitched ? unswitched_input2_data : unswitched_input1_data;
229
230 // Fivefold nested loops. The second input resets its position for each
231 // iteration of the second loop. The first input resets its position at the
232 // beginning of the fourth loop. The innermost loop applies the elementwise
233 // operation to sections of the arrays.
234 T* output_data_ptr = output_data;
235 const T* input1_data_ptr = input1_data;
236 const T* input2_data_reset = input2_data;
237 // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared
238 // between input shapes. y3 for input 1 is always broadcast, and so the
239 // dimension there is 1, whereas optionally y1 might be broadcast for
240 // input 2. Put another way, input1.shape.FlatSize = y0 * y1 * y2 * y4,
241 // input2.shape.FlatSize = y0 * y2 * y3 * y4.
242 int y0 = params.broadcast_shape[0];
243 int y1 = params.broadcast_shape[1];
244 int y2 = params.broadcast_shape[2];
245 int y3 = params.broadcast_shape[3];
246 int y4 = params.broadcast_shape[4];
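  // Example (illustrative decomposition, assuming it satisfies the flat-size
  // relations above): y0 = 1, y1 = 6, y2 = 1, y3 = 1, y4 = 4 gives
  // input1.FlatSize = 1*6*1*4 = 24 and input2.FlatSize = 1*1*1*4 = 4, so the
  // loops below reuse input2's single 4-element block against each of
  // input1's six 4-element blocks.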
247 if (y4 > 1) {
248 // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
249 // dimension.
250 for (int i0 = 0; i0 < y0; ++i0) {
251 const T* input2_data_ptr = nullptr;
252 for (int i1 = 0; i1 < y1; ++i1) {
253 input2_data_ptr = input2_data_reset;
254 for (int i2 = 0; i2 < y2; ++i2) {
255 for (int i3 = 0; i3 < y3; ++i3) {
256 elementwise_f(y4, params, input1_data_ptr, input2_data_ptr,
257 output_data_ptr);
258 input2_data_ptr += y4;
259 output_data_ptr += y4;
260 }
261 // We have broadcast y4 of input1 data y3 times, and now move on.
262 input1_data_ptr += y4;
263 }
264 }
265 // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
266 input2_data_reset = input2_data_ptr;
267 }
268 } else if (input1_data_ptr != nullptr) {
269 // Special case of y4 == 1, in which the innermost loop is a single
270 // element and can be combined with the next (y3) as an inner broadcast.
271 //
272 // Note that this handles the case of pure scalar broadcast when
273 // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
274 // broadcast with batch (as y2 > 1).
275 //
276 // NOTE The process is the same as the above general case except
277 // simplified for y4 == 1 and the loop over y3 is contained within the
278 // scalar_broadcast_f function.
279 for (int i0 = 0; i0 < y0; ++i0) {
280 const T* input2_data_ptr = nullptr;
281 for (int i1 = 0; i1 < y1; ++i1) {
282 input2_data_ptr = input2_data_reset;
283 for (int i2 = 0; i2 < y2; ++i2) {
284 scalar_broadcast_f(y3, params, *input1_data_ptr, input2_data_ptr,
285 output_data_ptr);
286 input2_data_ptr += y3;
287 output_data_ptr += y3;
288 input1_data_ptr += 1;
289 }
290 }
291 input2_data_reset = input2_data_ptr;
292 }
293 }
294 }
295
296 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
297
298 // Looks up each element of <indices> in <table> and returns them in a vector.
299 inline uint8x16_t aarch64_lookup_vector(const uint8x16x4_t table[4],
300 uint8x16_t indices) {
301 // Look up in 1st quarter of the table: top 2 bits of indices == 00
302 uint8x16_t output1 = vqtbl4q_u8(table[0], indices);
303 // Look up in 2nd quarter of the table: top 2 bits of indices == 01
304 uint8x16_t output2 =
305 vqtbl4q_u8(table[1], veorq_u8(indices, vdupq_n_u8(0x40)));
306 // Look up in 3rd quarter of the table: top 2 bits of indices == 10
307 uint8x16_t output3 =
308 vqtbl4q_u8(table[2], veorq_u8(indices, vdupq_n_u8(0x80)));
309 // Look up in 4th quarter of the table: top 2 bits of indices == 11
310 uint8x16_t output4 =
311 vqtbl4q_u8(table[3], veorq_u8(indices, vdupq_n_u8(0xc0)));
312
313 // Combine result of the 4 lookups.
314 return vorrq_u8(vorrq_u8(output1, output2), vorrq_u8(output3, output4));
315 }
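// Illustrative note (editorial): vqtbl4q_u8 returns 0 for indices outside
// [0, 63], so the XORs with 0x40/0x80/0xc0 make exactly one of the four
// lookups hit per index. For example, index 0x95 (top bits 10) gives
// 0x95 ^ 0x80 = 0x15, i.e. entry 21 of table[2], while the other three
// lookups see out-of-range indices and contribute 0 to the final OR.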
316
317 #endif
318
319 inline void AddBiasAndEvalActivationFunction(float output_activation_min,
320 float output_activation_max,
321 const RuntimeShape& bias_shape,
322 const float* bias_data,
323 const RuntimeShape& array_shape,
324 float* array_data) {
325 BiasAndClamp(output_activation_min, output_activation_max,
326 bias_shape.FlatSize(), bias_data, array_shape.FlatSize(),
327 array_data);
328 }
329
330 inline void FullyConnected(
331 const FullyConnectedParams& params, const RuntimeShape& input_shape,
332 const float* input_data, const RuntimeShape& weights_shape,
333 const float* weights_data, const RuntimeShape& bias_shape,
334 const float* optional_bias_data, const RuntimeShape& output_shape,
335 float* output_data, CpuBackendContext* cpu_backend_context) {
336 ruy::profiler::ScopeLabel label("FullyConnected");
337 const int dims_count = weights_shape.DimensionsCount();
338 const int input_rows = weights_shape.Dims(dims_count - 1);
339 cpu_backend_gemm::MatrixParams<float> rhs_params;
340 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
341 rhs_params.rows = input_rows;
342 rhs_params.cols = input_shape.FlatSize() / input_rows;
343 rhs_params.cache_policy =
344 cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
345 TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
346 cpu_backend_gemm::MatrixParams<float> lhs_params;
347 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
348 lhs_params.cols = weights_shape.Dims(dims_count - 1);
349 lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
350 lhs_params.cache_policy =
351 cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
352 cpu_backend_gemm::MatrixParams<float> dst_params;
353 dst_params.order = cpu_backend_gemm::Order::kColMajor;
354 dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
355 dst_params.cols =
356 FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
357 cpu_backend_gemm::GemmParams<float, float> gemm_params;
358 gemm_params.bias = optional_bias_data;
359 gemm_params.clamp_min = params.float_activation_min;
360 gemm_params.clamp_max = params.float_activation_max;
361 cpu_backend_gemm::Gemm(lhs_params, weights_data, rhs_params, input_data,
362 dst_params, output_data, gemm_params,
363 cpu_backend_context);
364 }
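// Illustrative note (editorial): with weights of shape {output_depth,
// accum_depth} and a flattened input of batches x accum_depth, the Gemm call
// above computes, for each batch b and output channel o,
//   output[b][o] = clamp(sum_d weights[o][d] * input[b][d] + bias[o],
//                        float_activation_min, float_activation_max).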
365
366 inline void FullyConnected(
367 const FullyConnectedParams& params, const RuntimeShape& input_shape,
368 const uint8* input_data, const RuntimeShape& filter_shape,
369 const uint8* filter_data, const RuntimeShape& bias_shape,
370 const int32* bias_data, const RuntimeShape& output_shape,
371 uint8* output_data, CpuBackendContext* cpu_backend_context) {
372 ruy::profiler::ScopeLabel label("FullyConnected/8bit");
373 const int32 input_offset = params.input_offset;
374 const int32 filter_offset = params.weights_offset;
375 const int32 output_offset = params.output_offset;
376 const int32 output_multiplier = params.output_multiplier;
377 const int output_shift = params.output_shift;
378 const int32 output_activation_min = params.quantized_activation_min;
379 const int32 output_activation_max = params.quantized_activation_max;
380 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
381 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
382 // TODO(b/62193649): This really should be:
383 // const int batches = ArraySize(output_dims, 1);
384 // but the current --variable_batch hack consists in overwriting the 3rd
385 // dimension with the runtime batch size, as we don't keep track for each
386 // array of which dimension is the batch dimension in it.
387 const int output_dim_count = output_shape.DimensionsCount();
388 const int filter_dim_count = filter_shape.DimensionsCount();
389 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
390 const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
391 const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
392 TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
393 const int output_rows = output_shape.Dims(output_dim_count - 1);
394 TFLITE_DCHECK_EQ(output_rows, filter_rows);
395 if (bias_data) {
396 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
397 }
398
399 cpu_backend_gemm::MatrixParams<uint8> lhs_params;
400 lhs_params.rows = filter_rows;
401 lhs_params.cols = filter_cols;
402 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
403 lhs_params.zero_point = -filter_offset;
404 lhs_params.cache_policy =
405 cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
406 cpu_backend_gemm::MatrixParams<uint8> rhs_params;
407 rhs_params.rows = filter_cols;
408 rhs_params.cols = batches;
409 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
410 rhs_params.zero_point = -input_offset;
411 rhs_params.cache_policy =
412 cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
413 cpu_backend_gemm::MatrixParams<uint8> dst_params;
414 dst_params.rows = filter_rows;
415 dst_params.cols = batches;
416 dst_params.order = cpu_backend_gemm::Order::kColMajor;
417 dst_params.zero_point = output_offset;
418 cpu_backend_gemm::GemmParams<int32, uint8> gemm_params;
419 gemm_params.bias = bias_data;
420 gemm_params.clamp_min = output_activation_min;
421 gemm_params.clamp_max = output_activation_max;
422 gemm_params.multiplier_fixedpoint = output_multiplier;
423 gemm_params.multiplier_exponent = output_shift;
424 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
425 dst_params, output_data, gemm_params,
426 cpu_backend_context);
427 }
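// Illustrative note (editorial, describing the standard quantized output
// pipeline rather than anything specific to this file): the Gemm call above
// adds the bias to each int32 accumulator, rescales it with the fixed-point
// multiplier/exponent pair (as in MultiplyByQuantizedMultiplier), adds the
// destination zero_point and clamps to [clamp_min, clamp_max].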
428
429 inline void FullyConnected(
430 const FullyConnectedParams& params, const RuntimeShape& input_shape,
431 const uint8* input_data, const RuntimeShape& filter_shape,
432 const uint8* filter_data, const RuntimeShape& bias_shape,
433 const int32* bias_data_int32, const RuntimeShape& output_shape,
434 int16* output_data, CpuBackendContext* cpu_backend_context) {
435 ruy::profiler::ScopeLabel label("FullyConnected/Uint8Int16");
436 const int32 input_offset = params.input_offset;
437 const int32 filter_offset = params.weights_offset;
438 const int32 output_offset = params.output_offset;
439 const int32 output_multiplier = params.output_multiplier;
440 const int output_shift = params.output_shift;
441 const int32 output_activation_min = params.quantized_activation_min;
442 const int32 output_activation_max = params.quantized_activation_max;
443 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
444 TFLITE_DCHECK_EQ(output_offset, 0);
445 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
446 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
447
448 // TODO(b/62193649): This really should be:
449 // const int batches = ArraySize(output_dims, 1);
450 // but the current --variable_batch hack consists in overwriting the 3rd
451 // dimension with the runtime batch size, as we don't keep track for each
452 // array of which dimension is the batch dimension in it.
453 const int output_dim_count = output_shape.DimensionsCount();
454 const int filter_dim_count = filter_shape.DimensionsCount();
455 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
456 const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
457 output_shape, output_dim_count - 1);
458 const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
459
460 cpu_backend_gemm::MatrixParams<uint8> lhs_params;
461 lhs_params.rows = output_depth;
462 lhs_params.cols = accum_depth;
463 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
464 lhs_params.zero_point = -filter_offset;
465 lhs_params.cache_policy =
466 cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
467 cpu_backend_gemm::MatrixParams<uint8> rhs_params;
468 rhs_params.rows = accum_depth;
469 rhs_params.cols = batches;
470 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
471 rhs_params.zero_point = -input_offset;
472 rhs_params.cache_policy =
473 cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
474 cpu_backend_gemm::MatrixParams<int16> dst_params;
475 dst_params.rows = output_depth;
476 dst_params.cols = batches;
477 dst_params.order = cpu_backend_gemm::Order::kColMajor;
478 dst_params.zero_point = 0;
479 cpu_backend_gemm::GemmParams<int32, int16> gemm_params;
480 gemm_params.bias = bias_data_int32;
481 gemm_params.clamp_min = output_activation_min;
482 gemm_params.clamp_max = output_activation_max;
483 gemm_params.multiplier_fixedpoint = output_multiplier;
484 gemm_params.multiplier_exponent = output_shift;
485 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
486 dst_params, output_data, gemm_params,
487 cpu_backend_context);
488 }
489
490 // Internal function doing the actual arithmetic work for
491 // ShuffledFullyConnected.
492 // May be called either directly by it (single-threaded case) or may be used
493 // as the 'task' for worker threads to run (multi-threaded case, see
494 // ShuffledFullyConnectedWorkerTask below).
495 inline void ShuffledFullyConnectedWorkerImpl(
496 const uint8* shuffled_input_workspace_data,
497 const int8* shuffled_weights_data, int batches, int output_depth,
498 int output_stride, int accum_depth, const int32* bias_data,
499 int32 output_multiplier, int output_shift, int16* output_data) {
500 #if defined USE_NEON
501 const int8* shuffled_weights_ptr = shuffled_weights_data;
502 if (batches == 1) {
503 const int right_shift = output_shift > 0 ? 0 : -output_shift;
504 const int left_shift = output_shift > 0 ? output_shift : 0;
505 for (int c = 0; c < output_depth; c += 4) {
506 // Accumulation loop.
507 int32x4_t row_accum0 = vdupq_n_s32(0);
508 int32x4_t row_accum1 = vdupq_n_s32(0);
509 int32x4_t row_accum2 = vdupq_n_s32(0);
510 int32x4_t row_accum3 = vdupq_n_s32(0);
511 for (int d = 0; d < accum_depth; d += 16) {
512 int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
513 int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
514 int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
515 int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
516 shuffled_weights_ptr += 64;
517 int8x16_t input =
518 vreinterpretq_s8_u8(vld1q_u8(shuffled_input_workspace_data + d));
519 int16x8_t local_accum0 =
520 vmull_s8(vget_low_s8(weights0), vget_low_s8(input));
521 int16x8_t local_accum1 =
522 vmull_s8(vget_low_s8(weights1), vget_low_s8(input));
523 int16x8_t local_accum2 =
524 vmull_s8(vget_low_s8(weights2), vget_low_s8(input));
525 int16x8_t local_accum3 =
526 vmull_s8(vget_low_s8(weights3), vget_low_s8(input));
527 local_accum0 =
528 vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input));
529 local_accum1 =
530 vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input));
531 local_accum2 =
532 vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input));
533 local_accum3 =
534 vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input));
535 row_accum0 = vpadalq_s16(row_accum0, local_accum0);
536 row_accum1 = vpadalq_s16(row_accum1, local_accum1);
537 row_accum2 = vpadalq_s16(row_accum2, local_accum2);
538 row_accum3 = vpadalq_s16(row_accum3, local_accum3);
539 }
540 // Horizontally reduce accumulators
541 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
542 pairwise_reduced_acc_2, pairwise_reduced_acc_3;
543 pairwise_reduced_acc_0 =
544 vpadd_s32(vget_low_s32(row_accum0), vget_high_s32(row_accum0));
545 pairwise_reduced_acc_1 =
546 vpadd_s32(vget_low_s32(row_accum1), vget_high_s32(row_accum1));
547 pairwise_reduced_acc_2 =
548 vpadd_s32(vget_low_s32(row_accum2), vget_high_s32(row_accum2));
549 pairwise_reduced_acc_3 =
550 vpadd_s32(vget_low_s32(row_accum3), vget_high_s32(row_accum3));
551 const int32x2_t reduced_lo =
552 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
553 const int32x2_t reduced_hi =
554 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
555 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
556 // Add bias values.
557 int32x4_t bias_vec = vld1q_s32(bias_data + c);
558 reduced = vaddq_s32(reduced, bias_vec);
559 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
560 // Multiply by the fixed-point multiplier.
561 reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
562 // Rounding-shift-right.
563 using gemmlowp::RoundingDivideByPOT;
564 reduced = RoundingDivideByPOT(reduced, right_shift);
565 // Narrow values down to 16 bit signed.
566 const int16x4_t res16 = vqmovn_s32(reduced);
567 vst1_s16(output_data + c, res16);
568 }
569 } else if (batches == 4) {
570 const int right_shift = output_shift > 0 ? 0 : -output_shift;
571 const int left_shift = output_shift > 0 ? output_shift : 0;
572 for (int c = 0; c < output_depth; c += 4) {
573 const int8* shuffled_input_ptr =
574 reinterpret_cast<const int8*>(shuffled_input_workspace_data);
575 // Accumulation loop.
576 int32x4_t row_accum00 = vdupq_n_s32(0);
577 int32x4_t row_accum10 = vdupq_n_s32(0);
578 int32x4_t row_accum20 = vdupq_n_s32(0);
579 int32x4_t row_accum30 = vdupq_n_s32(0);
580 int32x4_t row_accum01 = vdupq_n_s32(0);
581 int32x4_t row_accum11 = vdupq_n_s32(0);
582 int32x4_t row_accum21 = vdupq_n_s32(0);
583 int32x4_t row_accum31 = vdupq_n_s32(0);
584 int32x4_t row_accum02 = vdupq_n_s32(0);
585 int32x4_t row_accum12 = vdupq_n_s32(0);
586 int32x4_t row_accum22 = vdupq_n_s32(0);
587 int32x4_t row_accum32 = vdupq_n_s32(0);
588 int32x4_t row_accum03 = vdupq_n_s32(0);
589 int32x4_t row_accum13 = vdupq_n_s32(0);
590 int32x4_t row_accum23 = vdupq_n_s32(0);
591 int32x4_t row_accum33 = vdupq_n_s32(0);
592 for (int d = 0; d < accum_depth; d += 16) {
593 int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
594 int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
595 int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
596 int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
597 shuffled_weights_ptr += 64;
598 int8x16_t input0 = vld1q_s8(shuffled_input_ptr + 0);
599 int8x16_t input1 = vld1q_s8(shuffled_input_ptr + 16);
600 int8x16_t input2 = vld1q_s8(shuffled_input_ptr + 32);
601 int8x16_t input3 = vld1q_s8(shuffled_input_ptr + 48);
602 shuffled_input_ptr += 64;
603 int16x8_t local_accum0, local_accum1, local_accum2, local_accum3;
604 #define TFLITE_SHUFFLED_FC_ACCUM(B) \
605 local_accum0 = vmull_s8(vget_low_s8(weights0), vget_low_s8(input##B)); \
606 local_accum1 = vmull_s8(vget_low_s8(weights1), vget_low_s8(input##B)); \
607 local_accum2 = vmull_s8(vget_low_s8(weights2), vget_low_s8(input##B)); \
608 local_accum3 = vmull_s8(vget_low_s8(weights3), vget_low_s8(input##B)); \
609 local_accum0 = \
610 vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input##B)); \
611 local_accum1 = \
612 vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input##B)); \
613 local_accum2 = \
614 vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input##B)); \
615 local_accum3 = \
616 vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input##B)); \
617 row_accum0##B = vpadalq_s16(row_accum0##B, local_accum0); \
618 row_accum1##B = vpadalq_s16(row_accum1##B, local_accum1); \
619 row_accum2##B = vpadalq_s16(row_accum2##B, local_accum2); \
620 row_accum3##B = vpadalq_s16(row_accum3##B, local_accum3);
621
622 TFLITE_SHUFFLED_FC_ACCUM(0)
623 TFLITE_SHUFFLED_FC_ACCUM(1)
624 TFLITE_SHUFFLED_FC_ACCUM(2)
625 TFLITE_SHUFFLED_FC_ACCUM(3)
626
627 #undef TFLITE_SHUFFLED_FC_ACCUM
628 }
629 // Horizontally reduce accumulators
630
631 #define TFLITE_SHUFFLED_FC_STORE(B) \
632 { \
633 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1, \
634 pairwise_reduced_acc_2, pairwise_reduced_acc_3; \
635 pairwise_reduced_acc_0 = \
636 vpadd_s32(vget_low_s32(row_accum0##B), vget_high_s32(row_accum0##B)); \
637 pairwise_reduced_acc_1 = \
638 vpadd_s32(vget_low_s32(row_accum1##B), vget_high_s32(row_accum1##B)); \
639 pairwise_reduced_acc_2 = \
640 vpadd_s32(vget_low_s32(row_accum2##B), vget_high_s32(row_accum2##B)); \
641 pairwise_reduced_acc_3 = \
642 vpadd_s32(vget_low_s32(row_accum3##B), vget_high_s32(row_accum3##B)); \
643 const int32x2_t reduced_lo = \
644 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1); \
645 const int32x2_t reduced_hi = \
646 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3); \
647 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); \
648 int32x4_t bias_vec = vld1q_s32(bias_data + c); \
649 reduced = vaddq_s32(reduced, bias_vec); \
650 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift)); \
651 reduced = vqrdmulhq_n_s32(reduced, output_multiplier); \
652 using gemmlowp::RoundingDivideByPOT; \
653 reduced = RoundingDivideByPOT(reduced, right_shift); \
654 const int16x4_t res16 = vqmovn_s32(reduced); \
655 vst1_s16(output_data + c + B * output_stride, res16); \
656 }
657
658 TFLITE_SHUFFLED_FC_STORE(0);
659 TFLITE_SHUFFLED_FC_STORE(1);
660 TFLITE_SHUFFLED_FC_STORE(2);
661 TFLITE_SHUFFLED_FC_STORE(3);
662
663 #undef TFLITE_SHUFFLED_FC_STORE
664 }
665 } else {
666 TFLITE_DCHECK(false);
667 return;
668 }
669 #else
670 if (batches == 1) {
671 int16* output_ptr = output_data;
672 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
673 // so that just reinterpreting them as int8 values is equivalent to
674 // subtracting 128 from them, thus implementing for free the subtraction of
675 // the zero_point value 128.
676 const int8* shuffled_weights_ptr =
677 reinterpret_cast<const int8*>(shuffled_weights_data);
678 // Likewise, we preshuffled and pre-xored the input data above.
679 const int8* shuffled_input_data =
680 reinterpret_cast<const int8*>(shuffled_input_workspace_data);
681 for (int c = 0; c < output_depth; c += 4) {
682 // Internal accumulation.
683 // Accumulators start at zero; the bias value is added after the loop.
684 int32 accum[4] = {0};
685 // Accumulation loop.
686 for (int d = 0; d < accum_depth; d += 16) {
687 for (int i = 0; i < 4; i++) {
688 for (int j = 0; j < 16; j++) {
689 int8 input_val = shuffled_input_data[d + j];
690 int8 weights_val = *shuffled_weights_ptr++;
691 accum[i] += weights_val * input_val;
692 }
693 }
694 }
695 for (int i = 0; i < 4; i++) {
696 // Add bias value
697 int acc = accum[i] + bias_data[c + i];
698 // Down-scale the final int32 accumulator to the scale used by our
699 // (16-bit, typically 3 integer bits) fixed-point format. The quantized
700 // multiplier and shift here have been pre-computed offline
701 // (e.g. by toco).
702 acc =
703 MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
704 // Saturate, cast to int16, and store to output array.
705 acc = std::max(acc, -32768);
706 acc = std::min(acc, 32767);
707 output_ptr[c + i] = acc;
708 }
709 }
710 } else if (batches == 4) {
711 int16* output_ptr = output_data;
712 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
713 // so that just reinterpreting them as int8 values is equivalent to
714 // subtracting 128 from them, thus implementing for free the subtraction of
715 // the zero_point value 128.
716 const int8* shuffled_weights_ptr =
717 reinterpret_cast<const int8*>(shuffled_weights_data);
718 // Likewise, we preshuffled and pre-xored the input data above.
719 const int8* shuffled_input_data =
720 reinterpret_cast<const int8*>(shuffled_input_workspace_data);
721 for (int c = 0; c < output_depth; c += 4) {
722 const int8* shuffled_input_ptr = shuffled_input_data;
723 // Accumulation loop.
724 // Internal accumulation.
725 // Accumulators are zeroed just below; the bias value is added after the loop.
726 int32 accum[4][4];
727 for (int i = 0; i < 4; i++) {
728 for (int b = 0; b < 4; b++) {
729 accum[i][b] = 0;
730 }
731 }
732 for (int d = 0; d < accum_depth; d += 16) {
733 for (int i = 0; i < 4; i++) {
734 for (int b = 0; b < 4; b++) {
735 for (int j = 0; j < 16; j++) {
736 int8 input_val = shuffled_input_ptr[16 * b + j];
737 int8 weights_val = shuffled_weights_ptr[16 * i + j];
738 accum[i][b] += weights_val * input_val;
739 }
740 }
741 }
742 shuffled_input_ptr += 64;
743 shuffled_weights_ptr += 64;
744 }
745 for (int i = 0; i < 4; i++) {
746 for (int b = 0; b < 4; b++) {
747 // Add bias value
748 int acc = accum[i][b] + bias_data[c + i];
749 // Down-scale the final int32 accumulator to the scale used by our
750 // (16-bit, typically 3 integer bits) fixed-point format. The
751 // quantized multiplier and shift here have been pre-computed offline
752 // (e.g. by toco).
753 acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
754 output_shift);
755 // Saturate, cast to int16, and store to output array.
756 acc = std::max(acc, -32768);
757 acc = std::min(acc, 32767);
758 output_ptr[b * output_stride + c + i] = acc;
759 }
760 }
761 }
762 } else {
763 TFLITE_DCHECK(false);
764 return;
765 }
766 #endif
767 }
768
769 // Wraps ShuffledFullyConnectedWorkerImpl into a Task class
770 // to allow using gemmlowp's threadpool.
771 struct ShuffledFullyConnectedWorkerTask : cpu_backend_threadpool::Task {
772 ShuffledFullyConnectedWorkerTask(const uint8* input_data,
773 const int8* shuffled_weights_data,
774 int batches, int output_depth,
775 int output_stride, int accum_depth,
776 const int32* bias_data,
777 int32 output_multiplier, int output_shift,
778 int16* output_data)
779 : input_data_(input_data),
780 shuffled_weights_data_(shuffled_weights_data),
781 batches_(batches),
782 output_depth_(output_depth),
783 output_stride_(output_stride),
784 accum_depth_(accum_depth),
785 bias_data_(bias_data),
786 output_multiplier_(output_multiplier),
787 output_shift_(output_shift),
788 output_data_(output_data) {}
789
790 void Run() override {
791 ShuffledFullyConnectedWorkerImpl(
792 input_data_, shuffled_weights_data_, batches_, output_depth_,
793 output_stride_, accum_depth_, bias_data_, output_multiplier_,
794 output_shift_, output_data_);
795 }
796
797 const uint8* input_data_;
798 const int8* shuffled_weights_data_;
799 int batches_;
800 int output_depth_;
801 int output_stride_;
802 int accum_depth_;
803 const int32* bias_data_;
804 int32 output_multiplier_;
805 int output_shift_;
806 int16* output_data_;
807 };
808
809 inline void ShuffledFullyConnected(
810 const FullyConnectedParams& params, const RuntimeShape& input_shape,
811 const uint8* input_data, const RuntimeShape& weights_shape,
812 const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
813 const int32* bias_data, const RuntimeShape& output_shape,
814 int16* output_data, uint8* shuffled_input_workspace_data,
815 CpuBackendContext* cpu_backend_context) {
816 ruy::profiler::ScopeLabel label("ShuffledFullyConnected/8bit");
817 const int32 output_multiplier = params.output_multiplier;
818 const int output_shift = params.output_shift;
819 const int32 output_activation_min = params.quantized_activation_min;
820 const int32 output_activation_max = params.quantized_activation_max;
821 TFLITE_DCHECK_EQ(output_activation_min, -32768);
822 TFLITE_DCHECK_EQ(output_activation_max, 32767);
823 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
824 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
825 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
826 // TODO(b/62193649): This really should be:
827 // const int batches = ArraySize(output_dims, 1);
828 // but the current --variable_batch hack consists in overwriting the 3rd
829 // dimension with the runtime batch size, as we don't keep track for each
830 // array of which dimension is the batch dimension in it.
831 const int output_dim_count = output_shape.DimensionsCount();
832 const int weights_dim_count = weights_shape.DimensionsCount();
833 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
834 const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
835 output_shape, output_dim_count - 1);
836 const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
837 TFLITE_DCHECK((accum_depth % 16) == 0);
838 TFLITE_DCHECK((output_depth % 4) == 0);
839 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
840 // so that just reinterpreting them as int8 values is equivalent to
841 // subtracting 128 from them, thus implementing for free the subtraction of
842 // the zero_point value 128.
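  // For example (illustrative): the uint8 value 200 (0xc8) becomes 0x48 after
  // the XOR; reinterpreted as int8 that is 72, which is exactly 200 - 128.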
843 const int8* int8_shuffled_weights_data =
844 reinterpret_cast<const int8*>(shuffled_weights_data);
845
846 // Shuffling and xoring of input activations into the workspace buffer
847 if (batches == 1) {
848 #ifdef USE_NEON
849 const uint8x16_t signbit = vdupq_n_u8(0x80);
850 for (int i = 0; i < accum_depth; i += 16) {
851 uint8x16_t val = vld1q_u8(input_data + i);
852 val = veorq_u8(val, signbit);
853 vst1q_u8(shuffled_input_workspace_data + i, val);
854 }
855 #else
856 for (int i = 0; i < accum_depth; i++) {
857 shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
858 }
859 #endif
860 } else if (batches == 4) {
861 uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
862 int c = 0;
863 #ifdef USE_NEON
864 const uint8x16_t signbit = vdupq_n_u8(0x80);
865 for (c = 0; c < accum_depth; c += 16) {
866 const uint8* src_data_ptr = input_data + c;
867 uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
868 uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
869 uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
870 uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
871 val0 = veorq_u8(val0, signbit);
872 val1 = veorq_u8(val1, signbit);
873 val2 = veorq_u8(val2, signbit);
874 val3 = veorq_u8(val3, signbit);
875 vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
876 vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
877 vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
878 vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
879 shuffled_input_workspace_ptr += 64;
880 }
881 #else
882 for (c = 0; c < accum_depth; c += 16) {
883 for (int b = 0; b < 4; b++) {
884 const uint8* src_data_ptr = input_data + b * accum_depth + c;
885 for (int j = 0; j < 16; j++) {
886 uint8 src_val = *src_data_ptr++;
887 // Flip the sign bit, so that the kernel will only need to
888 // reinterpret these uint8 values as int8, getting for free the
889 // subtraction of the zero_point value 128.
890 uint8 dst_val = src_val ^ 0x80;
891 *shuffled_input_workspace_ptr++ = dst_val;
892 }
893 }
894 }
895 #endif
896 } else {
897 TFLITE_DCHECK(false);
898 return;
899 }
900
901 static constexpr int kKernelRows = 4;
902 const int thread_count =
903 LegacyHowManyThreads<kKernelRows>(cpu_backend_context->max_num_threads(),
904 output_depth, batches, accum_depth);
905 if (thread_count == 1) {
906 // Single-thread case: do the computation on the current thread, don't
907 // use a threadpool
908 ShuffledFullyConnectedWorkerImpl(
909 shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
910 output_depth, output_depth, accum_depth, bias_data, output_multiplier,
911 output_shift, output_data);
912 return;
913 }
914
915 // Multi-threaded case: use the gemmlowp context's threadpool.
916 TFLITE_DCHECK_GT(thread_count, 1);
917 std::vector<ShuffledFullyConnectedWorkerTask> tasks;
918 // TODO(b/131746020) don't create new heap allocations every time.
919 // At least we make it a single heap allocation by using reserve().
920 tasks.reserve(thread_count);
921 const int kRowsPerWorker =
922 RoundUp<kKernelRows>(CeilQuotient(output_depth, thread_count));
923 int row_start = 0;
924 for (int i = 0; i < thread_count; i++) {
925 int row_end = std::min(output_depth, row_start + kRowsPerWorker);
926 tasks.emplace_back(shuffled_input_workspace_data,
927 int8_shuffled_weights_data + row_start * accum_depth,
928 batches, row_end - row_start, output_depth, accum_depth,
929 bias_data + row_start, output_multiplier, output_shift,
930 output_data + row_start);
931 row_start = row_end;
932 }
933 TFLITE_DCHECK_EQ(row_start, output_depth);
934 cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
935 cpu_backend_context);
936 }
937
938 #ifdef USE_NEON
939
940 inline int32x4_t RoundToNearest(const float32x4_t input) {
941 #if defined(__aarch64__) || defined(__SSSE3__)
942 // Note: vcvtnq_s32_f32 is not available in ARMv7
943 return vcvtnq_s32_f32(input);
944 #else
945 static const float32x4_t zero_val_dup = vdupq_n_f32(0.0f);
946 static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
947 static const float32x4_t minus_point5_val_dup = vdupq_n_f32(-0.5f);
948
949 const uint32x4_t mask = vcltq_f32(input, zero_val_dup);
950 const float32x4_t round =
951 vbslq_f32(mask, minus_point5_val_dup, point5_val_dup);
952 return vcvtq_s32_f32(vaddq_f32(input, round));
953 #endif // defined(__aarch64__) || defined(__SSSE3__)
954 }
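// Note (editorial): the non-aarch64 fallback above rounds halfway cases away
// from zero (e.g. 2.5f -> 3), whereas vcvtnq_s32_f32 rounds ties to even
// (2.5f -> 2), so the two paths can differ by 1 on exact .5 inputs.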
955
956 inline uint32x4_t RoundToNearestUnsigned(const float32x4_t input) {
957 #if defined(__aarch64__)
958 // Note that vcvtnq_u32_f32 is not available in ARMv7 or in arm_neon_sse.h.
959 return vcvtnq_u32_f32(input);
960 #else
961 static const float32x4_t point5_val_dup = vdupq_n_f32(0.5f);
962
963 return vcvtq_u32_f32(vaddq_f32(input, point5_val_dup));
964 #endif // defined(__aarch64__)
965 }
966
967 #endif // USE_NEON
968
969 inline void MeanImpl(const tflite::MeanParams& op_params,
970 const RuntimeShape& input_shape, const uint8_t* input_data,
971 int32 multiplier, int32 shift, int32 bias,
972 const RuntimeShape& output_shape, uint8_t* output_data,
973 int start_depth, int end_depth) {
974 ruy::profiler::ScopeLabel label("Mean4D/Uint8/MeanImpl");
975
976 // The current implementation only supports rank-4 tensors and simultaneous
977 // reduction over width and height.
978 const int output_batch = output_shape.Dims(0);
979 const int output_height = output_shape.Dims(1);
980 const int output_width = output_shape.Dims(2);
981 const int input_height = input_shape.Dims(1);
982 const int input_width = input_shape.Dims(2);
983
984 TFLITE_CHECK_EQ(op_params.axis_count, 2);
985 TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
986 (op_params.axis[0] == 2 && op_params.axis[1] == 1));
987 TFLITE_CHECK_EQ(output_height, 1);
988 TFLITE_CHECK_EQ(output_width, 1);
989
990 constexpr int32_t kMinValue = std::numeric_limits<uint8_t>::min();
991 constexpr int32_t kMaxValue = std::numeric_limits<uint8_t>::max();
992
993 #ifdef USE_NEON
994 const int32x4_t bias_dup = vdupq_n_s32(bias);
995 const int32x4_t min_dup = vdupq_n_s32(kMinValue);
996 const int32x4_t max_dup = vdupq_n_s32(kMaxValue);
997 #endif // USE_NEON
998
999 for (int out_b = 0; out_b < output_batch; ++out_b) {
1000 int out_d = start_depth;
1001 #ifdef USE_NEON
1002
1003 for (; out_d <= end_depth - 16; out_d += 16) {
1004 int32x4x4_t temp_sum;
1005 temp_sum.val[0] = vdupq_n_s32(0);
1006 temp_sum.val[1] = vdupq_n_s32(0);
1007 temp_sum.val[2] = vdupq_n_s32(0);
1008 temp_sum.val[3] = vdupq_n_s32(0);
1009 for (int in_h = 0; in_h < input_height; ++in_h) {
1010 for (int in_w = 0; in_w < input_width; ++in_w) {
1011 const uint8_t* input_data_ptr =
1012 input_data + Offset(input_shape, out_b, in_h, in_w, out_d);
1013 uint8x16_t input_data_val = vld1q_u8(input_data_ptr);
1014
1015 int16x8_t input_data_low_shift =
1016 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_data_val)));
1017 int16x8_t input_data_high_shift =
1018 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_data_val)));
1019
1020 int32x4_t input_low_low =
1021 vmovl_s16(vget_low_s16(input_data_low_shift));
1022 int32x4_t input_high_low =
1023 vmovl_s16(vget_high_s16(input_data_low_shift));
1024 int32x4_t input_low_high =
1025 vmovl_s16(vget_low_s16(input_data_high_shift));
1026 int32x4_t input_high_high =
1027 vmovl_s16(vget_high_s16(input_data_high_shift));
1028
1029 temp_sum.val[0] = vaddq_s32(temp_sum.val[0], input_low_low);
1030 temp_sum.val[1] = vaddq_s32(temp_sum.val[1], input_high_low);
1031 temp_sum.val[2] = vaddq_s32(temp_sum.val[2], input_low_high);
1032 temp_sum.val[3] = vaddq_s32(temp_sum.val[3], input_high_high);
1033 }
1034 }
1035
1036 temp_sum =
1037 MultiplyByQuantizedMultiplier4Rows(temp_sum, multiplier, shift);
1038
1039 temp_sum.val[0] = vaddq_s32(temp_sum.val[0], bias_dup);
1040 temp_sum.val[1] = vaddq_s32(temp_sum.val[1], bias_dup);
1041 temp_sum.val[2] = vaddq_s32(temp_sum.val[2], bias_dup);
1042 temp_sum.val[3] = vaddq_s32(temp_sum.val[3], bias_dup);
1043
1044 temp_sum.val[0] = vminq_s32(vmaxq_s32(temp_sum.val[0], min_dup), max_dup);
1045 temp_sum.val[1] = vminq_s32(vmaxq_s32(temp_sum.val[1], min_dup), max_dup);
1046 temp_sum.val[2] = vminq_s32(vmaxq_s32(temp_sum.val[2], min_dup), max_dup);
1047 temp_sum.val[3] = vminq_s32(vmaxq_s32(temp_sum.val[3], min_dup), max_dup);
1048
1049 uint16x4_t narrowed_low_low =
1050 vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[0]));
1051 uint16x4_t narrowed_high_low =
1052 vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[1]));
1053 uint16x4_t narrowed_low_high =
1054 vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[2]));
1055 uint16x4_t narrowed_high_high =
1056 vmovn_u32(vreinterpretq_u32_s32(temp_sum.val[3]));
1057
1058 uint16x8_t combined_low =
1059 vcombine_u16(narrowed_low_low, narrowed_high_low);
1060 uint16x8_t combined_high =
1061 vcombine_u16(narrowed_low_high, narrowed_high_high);
1062
1063 uint8x8_t narrowed_low = vmovn_u16(combined_low);
1064 uint8x8_t narrowed_high = vmovn_u16(combined_high);
1065
1066 uint8x16_t combined_output = vcombine_u8(narrowed_low, narrowed_high);
1067
1068 uint8_t* output_data_ptr =
1069 output_data + Offset(output_shape, out_b, 0, 0, out_d);
1070 vst1q_u8(output_data_ptr, combined_output);
1071 }
1072 #endif // USE_NEON
1073
1074 for (; out_d < end_depth; ++out_d) {
1075 int acc = 0;
1076 for (int in_h = 0; in_h < input_height; ++in_h) {
1077 for (int in_w = 0; in_w < input_width; ++in_w) {
1078 acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
1079 }
1080 }
1081
1082 acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
1083 acc += bias;
1084 acc = std::min(std::max(acc, kMinValue), kMaxValue);
1085 output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
1086 static_cast<uint8_t>(acc);
1087 }
1088 }
1089 }
1090
1091 struct MeanWorkerTask : cpu_backend_threadpool::Task {
1092 MeanWorkerTask(const tflite::MeanParams& op_params,
1093 const RuntimeShape& input_shape, const uint8_t* input_data,
1094 int32 multiplier, int32 shift, int32 bias,
1095 const RuntimeShape& output_shape, uint8_t* output_data,
1096 int start_height, int end_height)
1097 : op_params(op_params),
1098 input_shape(input_shape),
1099 input_data(input_data),
1100 multiplier(multiplier),
1101 shift(shift),
1102 bias(bias),
1103 output_shape(output_shape),
1104 output_data(output_data),
1105 start_height(start_height),
1106 end_height(end_height) {}
1107
1108 void Run() override {
1109 MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
1110 output_shape, output_data, start_height, end_height);
1111 }
1112
1113 private:
1114 const tflite::MeanParams& op_params;
1115 const RuntimeShape& input_shape;
1116 const uint8_t* input_data;
1117 int32 multiplier;
1118 int32 shift;
1119 int32 bias;
1120 const RuntimeShape& output_shape;
1121 uint8_t* output_data;
1122 int start_height;
1123 int end_height;
1124 };
1125
1126 inline void Mean(const tflite::MeanParams& op_params,
1127 const RuntimeShape& unextended_input_shape,
1128 const uint8_t* input_data, int32 input_zero_point,
1129 float input_scale, const RuntimeShape& unextended_output_shape,
1130 uint8_t* output_data, int32 output_zero_point,
1131 float output_scale, CpuBackendContext* cpu_backend_context) {
1132 ruy::profiler::ScopeLabel label("Mean4D/Uint8");
1133 // The current implementation only supports rank-4 tensors and simultaneous
1134 // reduction over width and height.
1135 TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
1136 TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1137 const RuntimeShape input_shape =
1138 RuntimeShape::ExtendedShape(4, unextended_input_shape);
1139 const RuntimeShape output_shape =
1140 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1141 const int output_height = output_shape.Dims(1);
1142 const int output_width = output_shape.Dims(2);
1143 const int output_depth = output_shape.Dims(3);
1144
1145 TFLITE_CHECK_EQ(op_params.axis_count, 2);
1146 TFLITE_CHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
1147 (op_params.axis[0] == 2 && op_params.axis[1] == 1));
1148 TFLITE_CHECK_EQ(output_height, 1);
1149 TFLITE_CHECK_EQ(output_width, 1);
1150
1151 const int input_height = input_shape.Dims(1);
1152 const int input_width = input_shape.Dims(2);
1153 const float num_elements_in_axis = input_width * input_height;
1154
1155 float temp = input_zero_point * input_scale / output_scale;
1156 temp = temp > 0 ? temp + 0.5f : temp - 0.5f;
1157 int32_t bias = output_zero_point - static_cast<int32_t>(temp);
1158 float real_scale = input_scale / (num_elements_in_axis * output_scale);
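  // Worked example (illustrative values): with input_scale = 0.5,
  // output_scale = 0.25, input_zero_point = 128, output_zero_point = 128 and
  // a 2x2 spatial area (num_elements_in_axis = 4):
  //   temp = 128 * 0.5 / 0.25 = 256   ->  bias = 128 - 256 = -128
  //   real_scale = 0.5 / (4 * 0.25) = 0.5
  // Four quantized inputs of 130 (real value 1.0) sum to 520, and
  // 520 * 0.5 - 128 = 132, i.e. the correct quantized mean 1.0 / 0.25 + 128.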
1159
1160 int32 multiplier, shift;
1161 QuantizeMultiplier(real_scale, &multiplier, &shift);
1162
1163 constexpr int kMinDepthPerThread = 8;
1164 int thread_count = output_depth / kMinDepthPerThread;
1165 thread_count = thread_count > 0 ? thread_count : 1;
1166 const int capped_thread_count =
1167 std::min(thread_count, cpu_backend_context->max_num_threads());
1168
1169 if (capped_thread_count == 1) {
1170 MeanImpl(op_params, input_shape, input_data, multiplier, shift, bias,
1171 output_shape, output_data, 0, output_depth);
1172 } else {
1173 // Instead of parallelizing over the batch, we split the work along
1174 // output_depth, since the batch size is typically 1.
1175 std::vector<MeanWorkerTask> tasks;
1176 // TODO(b/131746020) don't create new heap allocations every time.
1177 // At least we make it a single heap allocation by using reserve().
1178 tasks.reserve(capped_thread_count);
1179 int depth_start = 0;
1180 for (int i = 0; i < capped_thread_count; ++i) {
1181 // Try to distribute the tasks as evenly as possible.
1182 int depth_end = depth_start +
1183 (output_depth - depth_start) / (capped_thread_count - i);
1184 tasks.emplace_back(op_params, input_shape, input_data, multiplier, shift,
1185 bias, output_shape, output_data, depth_start,
1186 depth_end);
1187 depth_start = depth_end;
1188 }
1189 cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
1190 cpu_backend_context);
1191 }
1192 }
1193
1194 template <typename T, typename U>
1195 inline bool MeanGeneral(const T* input_data, const int* input_dims,
1196 const int input_num_dims, T* output_data,
1197 const int* output_dims, const int output_num_dims,
1198 const int* axis, const int num_axis_dimensions,
1199 bool keep_dims, int* temp_index, int* resolved_axis,
1200 U* temp_sum) {
1201 return reference_ops::Mean(input_data, input_dims, input_num_dims,
1202 output_data, output_dims, output_num_dims, axis,
1203 num_axis_dimensions, keep_dims, temp_index,
1204 resolved_axis, temp_sum);
1205 }
1206
1207 template <>
1208 inline bool MeanGeneral<float, float>(
1209 const float* input_data, const int* input_dims, const int input_num_dims,
1210 float* output_data, const int* output_dims, const int output_num_dims,
1211 const int* axis, const int num_axis_dimensions, bool keep_dims,
1212 int* temp_index, int* resolved_axis, float* temp_sum) {
1213 // Handle reduce_mean for the last dimensions.
1214 if (num_axis_dimensions == 1 && axis[0] == (input_num_dims - 1)) {
1215 ruy::profiler::ScopeLabel label("MeanLastDim/Float");
1216 int output_size = 1;
1217 for (int i = 0; i < input_num_dims - 1; ++i) {
1218 output_size *= input_dims[i];
1219 }
1220 const int last_input_dim = input_dims[axis[0]];
1221
1222 // TODO(b/152563685): Consider use eigen to cover more general cases.
1223 const MatrixMap<const float> in_mat(input_data, last_input_dim,
1224 output_size);
1225 VectorMap<float> out(output_data, output_size, 1);
1226 out = (in_mat.array().colwise().sum()) / static_cast<float>(last_input_dim);
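    // Note (editorial): the Eigen Map above is column-major, so each column
    // of in_mat is one contiguous run of last_input_dim input elements;
    // colwise().sum() therefore reduces over the last dimension as intended.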
1227 return true;
1228 }
1229
1230 return reference_ops::Mean(input_data, input_dims, input_num_dims,
1231 output_data, output_dims, output_num_dims, axis,
1232 num_axis_dimensions, keep_dims, temp_index,
1233 resolved_axis, temp_sum);
1234 }
1235
1236 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
1237 const float* input_data, const RuntimeShape& filter_shape,
1238 const float* filter_data, const RuntimeShape& bias_shape,
1239 const float* bias_data, const RuntimeShape& output_shape,
1240 float* output_data, const RuntimeShape& im2col_shape,
1241 float* im2col_data, CpuBackendContext* cpu_backend_context) {
1242 const int stride_width = params.stride_width;
1243 const int stride_height = params.stride_height;
1244 const int dilation_width_factor = params.dilation_width_factor;
1245 const int dilation_height_factor = params.dilation_height_factor;
1246 const float output_activation_min = params.float_activation_min;
1247 const float output_activation_max = params.float_activation_max;
1248 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1249 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1250 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1251
1252 ruy::profiler::ScopeLabel label("Conv");
1253
1254 // NB: the float 0.0f value is represented by all zero bytes.
1255 const uint8 float_zero_byte = 0x00;
1256 const float* gemm_input_data = nullptr;
1257 const RuntimeShape* gemm_input_shape = nullptr;
1258 const int filter_width = filter_shape.Dims(2);
1259 const int filter_height = filter_shape.Dims(1);
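  // im2col rewrites each convolution patch as a contiguous run of
  // filter_height * filter_width * input_depth values so the convolution
  // reduces to a single GEMM against the flattened filters; it is only skipped
  // for non-dilated 1x1, stride-1 convolutions, where the input already has
  // that layout.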
1260 const bool need_dilated_im2col =
1261 dilation_width_factor != 1 || dilation_height_factor != 1;
1262 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1263 filter_width != 1 || filter_height != 1;
1264 if (need_dilated_im2col) {
1265 DilatedIm2col(params, float_zero_byte, input_shape, input_data,
1266 filter_shape, output_shape, im2col_data);
1267 gemm_input_data = im2col_data;
1268 gemm_input_shape = &im2col_shape;
1269 } else if (need_im2col) {
1270 TFLITE_DCHECK(im2col_data);
1271 Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
1272 input_data, im2col_shape, im2col_data);
1273 gemm_input_data = im2col_data;
1274 gemm_input_shape = &im2col_shape;
1275 } else {
1276 TFLITE_DCHECK(!im2col_data);
1277 gemm_input_data = input_data;
1278 gemm_input_shape = &input_shape;
1279 }
1280
1281 const int gemm_input_dims = gemm_input_shape->DimensionsCount();
1282 int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
1283 int n = output_shape.Dims(3);
1284 int k = gemm_input_shape->Dims(gemm_input_dims - 1);
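  // Here m is the number of output pixels (batch * output_height *
  // output_width), n is the number of output channels, and k is the length of
  // one flattened input patch (input_depth in the 1x1 case, otherwise
  // filter_height * filter_width * input_depth).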
1285
1286 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
1287   // The following code computes matrix multiplication c = a * transpose(b)
1288 // with CBLAS, where:
1289 // * `a` is a matrix with dimensions (m, k).
1290 // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
1291 // * `c` is a matrix with dimensions (m, n).
1292   // The naming of the variables is aligned with the CBLAS specification here.
1293 const float* a = gemm_input_data;
1294 const float* b = filter_data;
1295 float* c = output_data;
1296   // The strides of matrices a, b, and c, respectively.
1297 int stride_a = k;
1298 int stride_b = k;
1299 int stride_c = n;
1300
1301 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
1302 stride_a, b, stride_b, 0.0f, c, stride_c);
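  // cblas_sgemm cannot fuse the bias addition or the activation clamp, so they
  // are applied in a second pass over the output.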
1303 optimized_ops::AddBiasAndEvalActivationFunction(
1304 output_activation_min, output_activation_max, bias_shape, bias_data,
1305 output_shape, output_data);
1306 #else
1307 // When an optimized CBLAS implementation is not available, fall back
1308 // to using cpu_backend_gemm.
1309 cpu_backend_gemm::MatrixParams<float> lhs_params;
1310 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1311 lhs_params.rows = n;
1312 lhs_params.cols = k;
1313 cpu_backend_gemm::MatrixParams<float> rhs_params;
1314 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1315 rhs_params.rows = k;
1316 rhs_params.cols = m;
1317 cpu_backend_gemm::MatrixParams<float> dst_params;
1318 dst_params.order = cpu_backend_gemm::Order::kColMajor;
1319 dst_params.rows = n;
1320 dst_params.cols = m;
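  // The flattened filters form the (n x k) row-major LHS and the (possibly
  // im2col'd) input the (k x m) column-major RHS, so the destination is the
  // (n x m) output in column-major order, i.e. channels are the
  // fastest-varying dimension, matching the NHWC output layout.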
1321 cpu_backend_gemm::GemmParams<float, float> gemm_params;
1322 gemm_params.bias = bias_data;
1323 gemm_params.clamp_min = output_activation_min;
1324 gemm_params.clamp_max = output_activation_max;
1325 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1326 dst_params, output_data, gemm_params,
1327 cpu_backend_context);
1328 #endif // defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
1329 }
1330
1331 inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
1332 const RuntimeShape& input_shape,
1333 const int8_t* input_data,
1334 const RuntimeShape& filter_shape,
1335 const int8_t* filter_data,
1336 const RuntimeShape& bias_shape, const float* bias_data,
1337 const RuntimeShape& accum_scratch_shape,
1338 int32_t* accum_scratch, const RuntimeShape& output_shape,
1339 float* output_data, const RuntimeShape& im2col_shape,
1340 int8_t* im2col_data, CpuBackendContext* context) {
1341 const int stride_width = params.stride_width;
1342 const int stride_height = params.stride_height;
1343 const int dilation_width_factor = params.dilation_width_factor;
1344 const int dilation_height_factor = params.dilation_height_factor;
1345 const float output_activation_min = params.float_activation_min;
1346 const float output_activation_max = params.float_activation_max;
1347 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1348 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1349 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1350
1351 const int batch_size = input_shape.Dims(0);
1352 const int filter_width = filter_shape.Dims(2);
1353 const int filter_height = filter_shape.Dims(1);
1354
1355 const int input_zero_point = 0;
1356 const int8_t* gemm_input_data = nullptr;
1357 int num_input;
1358 const bool need_dilated_im2col =
1359 dilation_width_factor != 1 || dilation_height_factor != 1;
1360 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1361 filter_width != 1 || filter_height != 1;
1362
1363 if (need_dilated_im2col) {
1364 DilatedIm2col(params, input_zero_point, input_shape, input_data,
1365 filter_shape, output_shape, im2col_data);
1366 gemm_input_data = im2col_data;
1367 num_input = im2col_shape.FlatSize();
1368 } else if (need_im2col) {
1369 TFLITE_DCHECK(im2col_data);
1370     // Symmetric quantization assumes a zero point of 0.
1371
1372 Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
1373 input_data, im2col_shape, im2col_data);
1374 gemm_input_data = im2col_data;
1375 num_input = im2col_shape.FlatSize();
1376 } else {
1377 TFLITE_DCHECK(!im2col_data);
1378 gemm_input_data = input_data;
1379 num_input = input_shape.FlatSize();
1380 }
1381
1382 // Flatten 4D matrices into 2D matrices for matrix multiplication.
1383
1384 // Flatten so that each filter has its own row.
1385 const int filter_rows = filter_shape.Dims(0);
1386 const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1387
1388 // In MatrixBatchVectorMultiplyAccumulate, each output value is the
1389 // dot product of one row of the first matrix with one row of the second
1390   // matrix. Therefore, the number of cols in the two matrices must match.
1391 //
1392 // After Im2Col, each input patch becomes a row.
1393 const int gemm_input_cols = filter_cols;
1394 const int gemm_input_rows = num_input / gemm_input_cols;
1395
1396 const int output_cols = output_shape.Dims(3);
1397 const int output_rows = FlatSizeSkipDim(output_shape, 3);
1398 TFLITE_DCHECK_EQ(output_cols, filter_rows);
1399 TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
1400 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_cols);
1401
1402 // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
1403 // input matrix has its own scale factor. This code duplicates the scale
1404 // factors for each row in the same batch.
1405 const int rows_per_batch = gemm_input_rows / batch_size;
1406 for (int i = gemm_input_rows - 1; i >= 0; --i) {
1407 scaling_factors_ptr[i] = scaling_factors_ptr[i / rows_per_batch];
1408 }
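  // The loop above runs backwards so the duplication can happen in place: the
  // original per-batch value at index i / rows_per_batch has not been
  // overwritten yet when row i is written.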
1409
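  // MatrixBatchVectorMultiplyAccumulate accumulates into output_data, so the
  // output must start out zeroed.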
1410 std::fill_n(output_data, output_rows * output_cols, 0.0f);
1411
1412 // The scratch buffer must have the same size as the output.
1413 TFLITE_DCHECK_EQ(accum_scratch_shape.FlatSize(), output_shape.FlatSize());
1414 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
1415 filter_data, filter_rows, filter_cols, gemm_input_data,
1416 scaling_factors_ptr, /*n_batch=*/gemm_input_rows, accum_scratch,
1417 output_data, context);
1418 AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
1419 bias_shape, bias_data, output_shape,
1420 output_data);
1421 }
1422
1423 inline void HybridConvPerChannel(
1424 const ConvParams& params, float* scaling_factors_ptr,
1425 const RuntimeShape& input_shape, const int8_t* input_data,
1426 const RuntimeShape& filter_shape, const int8_t* filter_data,
1427 const RuntimeShape& bias_shape, const float* bias_data,
1428 const RuntimeShape& output_shape, float* output_data,
1429 const RuntimeShape& im2col_shape, int8_t* im2col_data,
1430 const float* per_channel_scale, int32_t* input_offset,
1431 const RuntimeShape& scratch_shape, int32_t* scratch, int32_t* row_sums,
1432 bool* compute_row_sums, CpuBackendContext* cpu_backend_context) {
1433 ruy::profiler::ScopeLabel label("ConvHybridPerChannel");
1434 const int stride_width = params.stride_width;
1435 const int stride_height = params.stride_height;
1436 const int dilation_width_factor = params.dilation_width_factor;
1437 const int dilation_height_factor = params.dilation_height_factor;
1438 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1439 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1440 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1441
1442 const int8* gemm_input_data = nullptr;
1443 const RuntimeShape* gemm_input_shape = nullptr;
1444 const int filter_width = filter_shape.Dims(2);
1445 const int filter_height = filter_shape.Dims(1);
1446 const bool need_dilated_im2col =
1447 dilation_width_factor != 1 || dilation_height_factor != 1;
1448 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1449 filter_width != 1 || filter_height != 1;
1450
1451 const int batch_size = input_shape.Dims(0);
1452
1453 if (need_dilated_im2col) {
1454 TFLITE_DCHECK(im2col_data);
1455 optimized_ops::DilatedIm2col(params, input_shape, input_data, filter_shape,
1456 output_shape, im2col_data, input_offset,
1457 batch_size);
1458 gemm_input_data = im2col_data;
1459 gemm_input_shape = &im2col_shape;
1460 } else if (need_im2col) {
1461 Im2col(params, filter_height, filter_width, input_offset, batch_size,
1462 input_shape, input_data, im2col_shape, im2col_data);
1463 gemm_input_data = im2col_data;
1464 gemm_input_shape = &im2col_shape;
1465 } else {
1466 TFLITE_DCHECK(!im2col_data);
1467 gemm_input_data = input_data;
1468 gemm_input_shape = &input_shape;
1469 }
1470
1471 const int filter_rows = filter_shape.Dims(0);
1472 const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1473
1474 const int gemm_input_rows = gemm_input_shape->Dims(3);
1475 const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
1476 const int output_rows = output_shape.Dims(3);
1477 const int output_cols =
1478 output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
1479
1480 TFLITE_DCHECK_EQ(output_rows, filter_rows);
1481 TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
1482 TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
1483 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1484 TFLITE_DCHECK_EQ(scratch_shape.FlatSize(), output_shape.FlatSize());
1485 if (!compute_row_sums || *compute_row_sums) {
1486 tensor_utils::ReductionSumVector(filter_data, row_sums, filter_rows,
1487 filter_cols);
1488 if (compute_row_sums) {
1489 *compute_row_sums = false;
1490 }
1491 }
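  // The integer GEMM below accumulates sum_j(w_ij * x_j) over the raw int8
  // inputs; subtracting input_zero_point * row_sum_i afterwards turns that into
  // sum_j(w_ij * (x_j - input_zero_point)), which is the quantity that gets
  // scaled back to float.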
1492
1493 cpu_backend_gemm::MatrixParams<int8> lhs_params;
1494 lhs_params.rows = filter_rows;
1495 lhs_params.cols = filter_cols;
1496 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1497
1498 cpu_backend_gemm::MatrixParams<int8> rhs_params;
1499 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1500 rhs_params.rows = gemm_input_rows;
1501 rhs_params.cols = gemm_input_cols;
1502
1503 cpu_backend_gemm::MatrixParams<int32> dst_params;
1504 dst_params.order = cpu_backend_gemm::Order::kColMajor;
1505 dst_params.rows = output_rows;
1506 dst_params.cols = output_cols;
1507
1508 // TODO(b/149003801): Use hybrid gemm once supported in Ruy.
1509 cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
1510 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1511 dst_params, scratch, gemm_params, cpu_backend_context);
1512
1513 MatrixMap<float> out_mat(output_data, filter_rows, output_cols);
1514 MatrixMap<int32_t> in_mat(scratch, filter_rows, output_cols);
1515 VectorMap<const float> bias_data_vec(bias_data, filter_rows, 1);
1516 VectorMap<int32_t> row_sums_vec(row_sums, filter_rows, 1);
1517 VectorMap<const float> per_channel_scale_vec(per_channel_scale, filter_rows,
1518 1);
1519 const int cols_per_batch = output_cols / batch_size;
1520 for (int c = 0; c < output_cols; c++) {
1521 const int b = c / cols_per_batch;
1522 const float input_scale = scaling_factors_ptr[b];
1523 const int32_t zero_point = input_offset[b];
1524 out_mat.col(c) =
1525 (((in_mat.col(c) - (row_sums_vec * zero_point))
1526 .cast<float>()
1527 .cwiseProduct((per_channel_scale_vec * input_scale))) +
1528 bias_data_vec)
1529 .cwiseMin(params.float_activation_max)
1530 .cwiseMax(params.float_activation_min);
1531 }
1532 }
1533
1534 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
1535 const uint8* input_data, const RuntimeShape& filter_shape,
1536 const uint8* filter_data, const RuntimeShape& bias_shape,
1537 const int32* bias_data, const RuntimeShape& output_shape,
1538 uint8* output_data, const RuntimeShape& im2col_shape,
1539 uint8* im2col_data, CpuBackendContext* cpu_backend_context) {
1540 ruy::profiler::ScopeLabel label("Conv/8bit");
1541
1542 const int stride_width = params.stride_width;
1543 const int stride_height = params.stride_height;
1544 const int dilation_width_factor = params.dilation_width_factor;
1545 const int dilation_height_factor = params.dilation_height_factor;
1546 const int32 input_offset = params.input_offset;
1547 const int32 filter_offset = params.weights_offset;
1548 const int32 output_offset = params.output_offset;
1549 const int32 output_multiplier = params.output_multiplier;
1550 const int output_shift = params.output_shift;
1551 const int32 output_activation_min = params.quantized_activation_min;
1552 const int32 output_activation_max = params.quantized_activation_max;
1553 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1554 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1555 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1556
1557 const uint8* gemm_input_data = nullptr;
1558 const RuntimeShape* gemm_input_shape = nullptr;
1559 const int filter_width = filter_shape.Dims(2);
1560 const int filter_height = filter_shape.Dims(1);
1561 const bool need_dilated_im2col =
1562 dilation_width_factor != 1 || dilation_height_factor != 1;
1563 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
1564 filter_width != 1 || filter_height != 1;
1565 if (need_dilated_im2col) {
1566 TFLITE_DCHECK(im2col_data);
1567 const int input_zero_point = -input_offset;
1568 TFLITE_DCHECK_GE(input_zero_point, 0);
1569 TFLITE_DCHECK_LE(input_zero_point, 255);
1570 DilatedIm2col(params, input_zero_point, input_shape, input_data,
1571 filter_shape, output_shape, im2col_data);
1572 gemm_input_data = im2col_data;
1573 gemm_input_shape = &im2col_shape;
1574 } else if (need_im2col) {
1575 TFLITE_DCHECK(im2col_data);
1576 const int input_zero_point = -input_offset;
1577 TFLITE_DCHECK_GE(input_zero_point, 0);
1578 TFLITE_DCHECK_LE(input_zero_point, 255);
1579 Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
1580 input_data, im2col_shape, im2col_data);
1581 gemm_input_data = im2col_data;
1582 gemm_input_shape = &im2col_shape;
1583 } else {
1584 TFLITE_DCHECK(!im2col_data);
1585 gemm_input_data = input_data;
1586 gemm_input_shape = &input_shape;
1587 }
1588
1589 const int gemm_input_rows = gemm_input_shape->Dims(3);
1590 // Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784).
1591 // The root cause has not yet been identified though. Same applies below for
1592 // the other calls commented out. This is a partial rollback of cl/196819423.
1593 // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
1594 const int gemm_input_cols = gemm_input_shape->Dims(0) *
1595 gemm_input_shape->Dims(1) *
1596 gemm_input_shape->Dims(2);
1597 const int filter_rows = filter_shape.Dims(0);
1598 // See b/79927784.
1599 // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
1600 const int filter_cols =
1601 filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
1602 const int output_rows = output_shape.Dims(3);
1603 // See b/79927784.
1604 // const int output_cols = FlatSizeSkipDim(output_shape, 3);
1605 const int output_cols =
1606 output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
1607 TFLITE_DCHECK_EQ(output_rows, filter_rows);
1608 TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
1609 TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
1610 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1611
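  // params.weights_offset and params.input_offset store the negated zero
  // points (see the DCHECKs on -input_offset above), so negating them here
  // recovers the zero points expected by cpu_backend_gemm.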
1612 cpu_backend_gemm::MatrixParams<uint8> lhs_params;
1613 lhs_params.rows = filter_rows;
1614 lhs_params.cols = filter_cols;
1615 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
1616 lhs_params.zero_point = -filter_offset;
1617 cpu_backend_gemm::MatrixParams<uint8> rhs_params;
1618 rhs_params.rows = gemm_input_rows;
1619 rhs_params.cols = gemm_input_cols;
1620 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
1621 rhs_params.zero_point = -input_offset;
1622 cpu_backend_gemm::MatrixParams<uint8> dst_params;
1623 dst_params.rows = output_rows;
1624 dst_params.cols = output_cols;
1625 dst_params.order = cpu_backend_gemm::Order::kColMajor;
1626 dst_params.zero_point = output_offset;
1627 cpu_backend_gemm::GemmParams<int32, uint8> gemm_params;
1628 gemm_params.bias = bias_data;
1629 gemm_params.clamp_min = output_activation_min;
1630 gemm_params.clamp_max = output_activation_max;
1631 gemm_params.multiplier_fixedpoint = output_multiplier;
1632 gemm_params.multiplier_exponent = output_shift;
1633 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
1634 dst_params, output_data, gemm_params,
1635 cpu_backend_context);
1636 }
1637
1638 template <typename T>
1639 inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
1640 const RuntimeShape& unextended_input_shape,
1641 const T* input_data,
1642 const RuntimeShape& unextended_output_shape,
1643 T* output_data) {
1644 ruy::profiler::ScopeLabel label("DepthToSpace");
1645
1646 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1647 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1648 const RuntimeShape input_shape =
1649 RuntimeShape::ExtendedShape(4, unextended_input_shape);
1650 const RuntimeShape output_shape =
1651 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1652
1653 const int input_depth = input_shape.Dims(3);
1654 const int input_width = input_shape.Dims(2);
1655 const int input_height = input_shape.Dims(1);
1656
1657 const int output_depth = output_shape.Dims(3);
1658 const int batch_size = output_shape.Dims(0);
1659
1660   // Number of contiguous values that we can copy in one iteration.
1661 const int stride = op_params.block_size * output_depth;
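  // For a given output row, every input pixel in the corresponding input row
  // contributes exactly one contiguous run of 'stride' values, so the inner
  // loop is a single memcpy per input pixel.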
1662
1663 for (int batch = 0; batch < batch_size; ++batch) {
1664 for (int in_h = 0; in_h < input_height; ++in_h) {
1665 const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0);
1666 for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
1667 const T* src = input_ptr;
1668 for (int in_w = 0; in_w < input_width; ++in_w) {
1669 memcpy(output_data, src, stride * sizeof(T));
1670 output_data += stride;
1671 src += input_depth;
1672 }
1673 input_ptr += stride;
1674 }
1675 }
1676 }
1677 }
1678
1679 template <typename T>
1680 inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
1681 const RuntimeShape& unextended_input_shape,
1682 const T* input_data,
1683 const RuntimeShape& unextended_output_shape,
1684 T* output_data) {
1685 ruy::profiler::ScopeLabel label("SpaceToDepth");
1686
1687 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1688 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1689 const RuntimeShape input_shape =
1690 RuntimeShape::ExtendedShape(4, unextended_input_shape);
1691 const RuntimeShape output_shape =
1692 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1693
1694 const int output_depth = output_shape.Dims(3);
1695 const int output_width = output_shape.Dims(2);
1696 const int output_height = output_shape.Dims(1);
1697
1698 const int input_depth = input_shape.Dims(3);
1699 const int batch_size = input_shape.Dims(0);
1700
1701   // Number of contiguous values that we can copy in one iteration.
1702 const int stride = op_params.block_size * input_depth;
1703
1704 for (int batch = 0; batch < batch_size; ++batch) {
1705 for (int out_h = 0; out_h < output_height; ++out_h) {
1706 T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
1707 for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
1708 T* dst = output_ptr;
1709 for (int out_w = 0; out_w < output_width; ++out_w) {
1710 memcpy(dst, input_data, stride * sizeof(T));
1711 input_data += stride;
1712 dst += output_depth;
1713 }
1714 output_ptr += stride;
1715 }
1716 }
1717 }
1718 }
1719
1720 inline void Relu(const RuntimeShape& input_shape, const float* input_data,
1721 const RuntimeShape& output_shape, float* output_data) {
1722 ruy::profiler::ScopeLabel label("Relu (not fused)");
1723
1724 const auto input = MapAsVector(input_data, input_shape);
1725 auto output = MapAsVector(output_data, output_shape);
1726 output = input.cwiseMax(0.0f);
1727 }
1728
1729 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
1730 const RuntimeShape& input_shape,
1731 const float* input_data,
1732 const RuntimeShape& output_shape,
1733 float* output_data, float epsilon = 1e-6) {
1734 ruy::profiler::ScopeLabel label("L2Normalization");
1735 const int trailing_dim = input_shape.DimensionsCount() - 1;
1736 const int outer_size =
1737 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1738 const int depth =
1739 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
1740 for (int i = 0; i < outer_size; ++i) {
1741 float squared_l2_norm = 0;
1742 for (int c = 0; c < depth; ++c) {
1743 const float val = input_data[c];
1744 squared_l2_norm += val * val;
1745 }
1746 float l2_norm = std::sqrt(squared_l2_norm);
1747 l2_norm = std::max(l2_norm, epsilon);
1748 for (int c = 0; c < depth; ++c) {
1749 *output_data = *input_data / l2_norm;
1750 ++output_data;
1751 ++input_data;
1752 }
1753 }
1754 }
1755
1756 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
1757 const RuntimeShape& input_shape,
1758 const uint8* input_data,
1759 const RuntimeShape& output_shape,
1760 uint8* output_data) {
1761 ruy::profiler::ScopeLabel label("L2Normalization/8bit");
1762 const int trailing_dim = input_shape.DimensionsCount() - 1;
1763 const int depth =
1764 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
1765 const int outer_size =
1766 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1767 const int32 input_zero_point = op_params.input_zero_point;
1768 for (int i = 0; i < outer_size; ++i) {
1769 int32 square_l2_norm = 0;
1770 for (int c = 0; c < depth; c++) {
1771 // Note that input_data advances by depth in the second pass below.
1772 int32 diff = input_data[c] - input_zero_point;
1773 square_l2_norm += diff * diff;
1774 }
1775 // TODO(b/29395854): add clamping to TOCO and TF Lite kernel
1776     // for all-zero tensors in the input_data.
1777 int32 inv_l2norm_multiplier;
1778 int inv_l2norm_shift;
1779 GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift,
1780 &inv_l2norm_multiplier, &inv_l2norm_shift);
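    // inv_l2norm_multiplier/inv_l2norm_shift approximate 1 / sqrt(square_l2_norm)
    // in fixed point; the factor of 128 below maps the normalized value, which
    // lies in [-1, 1], onto the uint8 grid with zero point 128.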
1781
1782 for (int c = 0; c < depth; c++) {
1783 int32 diff = *input_data - input_zero_point;
1784 int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
1785 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
1786 int32 unclamped_output_val = 128 + rescaled_diff;
1787 int32 output_val = std::min(255, std::max(0, unclamped_output_val));
1788 *output_data = static_cast<uint8>(output_val);
1789 ++input_data;
1790 ++output_data;
1791 }
1792 }
1793 }
1794
1795 inline void AddElementwise(int size, const ArithmeticParams& params,
1796 const float* input1_data, const float* input2_data,
1797 float* output_data) {
1798 int i = 0;
1799
1800 #ifdef USE_NEON
1801 const auto activation_min = vdupq_n_f32(params.float_activation_min);
1802 const auto activation_max = vdupq_n_f32(params.float_activation_max);
1803 for (; i <= size - 16; i += 16) {
1804 auto a10 = vld1q_f32(input1_data + i);
1805 auto a11 = vld1q_f32(input1_data + i + 4);
1806 auto a12 = vld1q_f32(input1_data + i + 8);
1807 auto a13 = vld1q_f32(input1_data + i + 12);
1808 auto a20 = vld1q_f32(input2_data + i);
1809 auto a21 = vld1q_f32(input2_data + i + 4);
1810 auto a22 = vld1q_f32(input2_data + i + 8);
1811 auto a23 = vld1q_f32(input2_data + i + 12);
1812 auto x0 = vaddq_f32(a10, a20);
1813 auto x1 = vaddq_f32(a11, a21);
1814 auto x2 = vaddq_f32(a12, a22);
1815 auto x3 = vaddq_f32(a13, a23);
1816 x0 = vmaxq_f32(activation_min, x0);
1817 x1 = vmaxq_f32(activation_min, x1);
1818 x2 = vmaxq_f32(activation_min, x2);
1819 x3 = vmaxq_f32(activation_min, x3);
1820 x0 = vminq_f32(activation_max, x0);
1821 x1 = vminq_f32(activation_max, x1);
1822 x2 = vminq_f32(activation_max, x2);
1823 x3 = vminq_f32(activation_max, x3);
1824 vst1q_f32(output_data + i, x0);
1825 vst1q_f32(output_data + i + 4, x1);
1826 vst1q_f32(output_data + i + 8, x2);
1827 vst1q_f32(output_data + i + 12, x3);
1828 }
1829 for (; i <= size - 4; i += 4) {
1830 auto a1 = vld1q_f32(input1_data + i);
1831 auto a2 = vld1q_f32(input2_data + i);
1832 auto x = vaddq_f32(a1, a2);
1833 x = vmaxq_f32(activation_min, x);
1834 x = vminq_f32(activation_max, x);
1835 vst1q_f32(output_data + i, x);
1836 }
1837 #endif // NEON
1838
1839 for (; i < size; i++) {
1840 auto x = input1_data[i] + input2_data[i];
1841 output_data[i] = ActivationFunctionWithMinMax(
1842 x, params.float_activation_min, params.float_activation_max);
1843 }
1844 }
1845
1846 inline void Add(const ArithmeticParams& params,
1847 const RuntimeShape& input1_shape, const float* input1_data,
1848 const RuntimeShape& input2_shape, const float* input2_data,
1849 const RuntimeShape& output_shape, float* output_data) {
1850 ruy::profiler::ScopeLabel label("Add");
1851 const int flat_size =
1852 MatchingElementsSize(input1_shape, input2_shape, output_shape);
1853 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
1854 }
1855
1856 // Element-wise add that can often be used for inner loop of broadcast add as
1857 // well as the non-broadcast add.
1858 inline void AddElementwise(int size, const ArithmeticParams& params,
1859 const uint8* input1_data, const uint8* input2_data,
1860 uint8* output_data) {
1861 ruy::profiler::ScopeLabel label("AddElementwise/8bit");
1862 int i = 0;
1863 TFLITE_DCHECK_GT(params.input1_offset, -256);
1864 TFLITE_DCHECK_GT(params.input2_offset, -256);
1865 TFLITE_DCHECK_LT(params.input1_offset, 256);
1866 TFLITE_DCHECK_LT(params.input2_offset, 256);
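  // Quantized add: both inputs are promoted by params.left_shift and rescaled
  // to a common intermediate scale with their per-input multiplier/shift
  // pairs; the sum is then rescaled to the output scale, offset and clamped.
  // The scalar loop at the end is the reference formulation that the NEON path
  // vectorizes.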
1867 #ifdef USE_NEON
1868 const uint8x8_t output_activation_min_vector =
1869 vdup_n_u8(params.quantized_activation_min);
1870 const uint8x8_t output_activation_max_vector =
1871 vdup_n_u8(params.quantized_activation_max);
1872 for (; i <= size - 8; i += 8) {
1873 const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
1874 const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
1875 const int16x8_t input1_val_s16 =
1876 vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
1877 const int16x8_t input2_val_s16 =
1878 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
1879 const int16x8_t input1_val =
1880 vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
1881 const int16x8_t input2_val =
1882 vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
1883 const int16x4_t input1_val_high = vget_high_s16(input1_val);
1884 const int16x4_t input1_val_low = vget_low_s16(input1_val);
1885 const int16x4_t input2_val_high = vget_high_s16(input2_val);
1886 const int16x4_t input2_val_low = vget_low_s16(input2_val);
1887 int32x4_t x11 = vmovl_s16(input1_val_low);
1888 int32x4_t x12 = vmovl_s16(input1_val_high);
1889 int32x4_t x21 = vmovl_s16(input2_val_low);
1890 int32x4_t x22 = vmovl_s16(input2_val_high);
1891 const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
1892 x11 = vshlq_s32(x11, left_shift_dup);
1893 x12 = vshlq_s32(x12, left_shift_dup);
1894 x21 = vshlq_s32(x21, left_shift_dup);
1895 x22 = vshlq_s32(x22, left_shift_dup);
1896 x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
1897 x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
1898 x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
1899 x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
1900 const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
1901 const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
1902 x11 = vshlq_s32(x11, input1_shift_dup);
1903 x12 = vshlq_s32(x12, input1_shift_dup);
1904 x21 = vshlq_s32(x21, input2_shift_dup);
1905 x22 = vshlq_s32(x22, input2_shift_dup);
1906 int32x4_t s1 = vaddq_s32(x11, x21);
1907 int32x4_t s2 = vaddq_s32(x12, x22);
1908 s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
1909 s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
1910 using gemmlowp::RoundingDivideByPOT;
1911 s1 = RoundingDivideByPOT(s1, -params.output_shift);
1912 s2 = RoundingDivideByPOT(s2, -params.output_shift);
1913 const int16x4_t s1_narrowed = vmovn_s32(s1);
1914 const int16x4_t s2_narrowed = vmovn_s32(s2);
1915 const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
1916 vdupq_n_s16(params.output_offset));
1917 const uint8x8_t clamped =
1918 vmax_u8(output_activation_min_vector,
1919 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
1920 vst1_u8(output_data + i, clamped);
1921 }
1922 #endif // NEON
1923
1924 for (; i < size; ++i) {
1925 const int32 input1_val = params.input1_offset + input1_data[i];
1926 const int32 input2_val = params.input2_offset + input2_data[i];
1927 const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
1928 const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
1929 const int32 scaled_input1_val =
1930 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1931 shifted_input1_val, params.input1_multiplier, params.input1_shift);
1932 const int32 scaled_input2_val =
1933 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1934 shifted_input2_val, params.input2_multiplier, params.input2_shift);
1935 const int32 raw_sum = scaled_input1_val + scaled_input2_val;
1936 const int32 raw_output =
1937 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1938 raw_sum, params.output_multiplier, params.output_shift) +
1939 params.output_offset;
1940 const int32 clamped_output =
1941 std::min(params.quantized_activation_max,
1942 std::max(params.quantized_activation_min, raw_output));
1943 output_data[i] = static_cast<uint8>(clamped_output);
1944 }
1945 }
1946
1947 // Scalar-broadcast add that can be used for inner loop of more general
1948 // broadcast add, so that, for example, scalar-broadcast with batch will still
1949 // be fast.
1950 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
1951 uint8 input1_data, const uint8* input2_data,
1952 uint8* output_data) {
1953 using gemmlowp::RoundingDivideByPOT;
1954
1955 ruy::profiler::ScopeLabel label("AddScalarBroadcast/8bit");
1956 TFLITE_DCHECK_GT(params.input1_offset, -256);
1957 TFLITE_DCHECK_GT(params.input2_offset, -256);
1958 TFLITE_DCHECK_LT(params.input1_offset, 256);
1959 TFLITE_DCHECK_LT(params.input2_offset, 256);
1960
1961 int i = 0;
1962
1963 #ifdef USE_NEON
1964 const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
1965 const uint8x8_t output_activation_min_vector =
1966 vdup_n_u8(params.quantized_activation_min);
1967 const uint8x8_t output_activation_max_vector =
1968 vdup_n_u8(params.quantized_activation_max);
1969
1970 // Process broadcast scalar.
1971 const uint8x8_t input1_val_original = vdup_n_u8(input1_data);
1972 const int16x8_t input1_val_s16 =
1973 vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
1974 const int16x8_t input1_val =
1975 vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
1976 const int16x4_t input1_val_high = vget_high_s16(input1_val);
1977 const int16x4_t input1_val_low = vget_low_s16(input1_val);
1978 int32x4_t x11 = vmovl_s16(input1_val_low);
1979 int32x4_t x12 = vmovl_s16(input1_val_high);
1980 x11 = vshlq_s32(x11, left_shift_dup);
1981 x12 = vshlq_s32(x12, left_shift_dup);
1982 x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
1983 x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
1984 const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
1985 x11 = vshlq_s32(x11, input1_shift_dup);
1986 x12 = vshlq_s32(x12, input1_shift_dup);
1987
1988 for (; i <= size - 8; i += 8) {
1989 const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
1990 const int16x8_t input2_val_s16 =
1991 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
1992 const int16x8_t input2_val =
1993 vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
1994 const int16x4_t input2_val_high = vget_high_s16(input2_val);
1995 const int16x4_t input2_val_low = vget_low_s16(input2_val);
1996 int32x4_t x21 = vmovl_s16(input2_val_low);
1997 int32x4_t x22 = vmovl_s16(input2_val_high);
1998 x21 = vshlq_s32(x21, left_shift_dup);
1999 x22 = vshlq_s32(x22, left_shift_dup);
2000 x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
2001 x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
2002 const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
2003 x21 = vshlq_s32(x21, input2_shift_dup);
2004 x22 = vshlq_s32(x22, input2_shift_dup);
2005 int32x4_t s1 = vaddq_s32(x11, x21);
2006 int32x4_t s2 = vaddq_s32(x12, x22);
2007 s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
2008 s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
2009 s1 = RoundingDivideByPOT(s1, -params.output_shift);
2010 s2 = RoundingDivideByPOT(s2, -params.output_shift);
2011 const int16x4_t s1_narrowed = vmovn_s32(s1);
2012 const int16x4_t s2_narrowed = vmovn_s32(s2);
2013 const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
2014 vdupq_n_s16(params.output_offset));
2015 const uint8x8_t clamped =
2016 vmax_u8(output_activation_min_vector,
2017 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
2018 vst1_u8(output_data + i, clamped);
2019 }
2020 #endif // NEON
2021
2022 if (i < size) {
2023 // Process broadcast scalar.
2024 const int32 input1_val = params.input1_offset + input1_data;
2025 const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
2026 const int32 scaled_input1_val =
2027 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2028 shifted_input1_val, params.input1_multiplier, params.input1_shift);
2029
2030 for (; i < size; ++i) {
2031 const int32 input2_val = params.input2_offset + input2_data[i];
2032 const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
2033 const int32 scaled_input2_val =
2034 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2035 shifted_input2_val, params.input2_multiplier,
2036 params.input2_shift);
2037 const int32 raw_sum = scaled_input1_val + scaled_input2_val;
2038 const int32 raw_output =
2039 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2040 raw_sum, params.output_multiplier, params.output_shift) +
2041 params.output_offset;
2042 const int32 clamped_output =
2043 std::min(params.quantized_activation_max,
2044 std::max(params.quantized_activation_min, raw_output));
2045 output_data[i] = static_cast<uint8>(clamped_output);
2046 }
2047 }
2048 }
2049
2050 // Scalar-broadcast add that can be used for inner loop of more general
2051 // broadcast add, so that, for example, scalar-broadcast with batch will still
2052 // be fast.
2053 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
2054 float broadcast_value, const float* input2_data,
2055 float* output_data) {
2056 int i = 0;
2057 #ifdef USE_NEON
2058 const float32x4_t output_activation_min_vector =
2059 vdupq_n_f32(params.float_activation_min);
2060 const float32x4_t output_activation_max_vector =
2061 vdupq_n_f32(params.float_activation_max);
2062 const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
2063 for (; i <= size - 4; i += 4) {
2064 const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
2065
2066 const float32x4_t output =
2067 vaddq_f32(input2_val_original, broadcast_value_dup);
2068
2069 const float32x4_t clamped =
2070 vmaxq_f32(output_activation_min_vector,
2071 vminq_f32(output_activation_max_vector, output));
2072 vst1q_f32(output_data + i, clamped);
2073 }
2074 #endif // NEON
2075
2076 for (; i < size; ++i) {
2077 auto x = broadcast_value + input2_data[i];
2078 output_data[i] = ActivationFunctionWithMinMax(
2079 x, params.float_activation_min, params.float_activation_max);
2080 }
2081 }
2082
2083 inline void Add(const ArithmeticParams& params,
2084 const RuntimeShape& input1_shape, const uint8* input1_data,
2085 const RuntimeShape& input2_shape, const uint8* input2_data,
2086 const RuntimeShape& output_shape, uint8* output_data) {
2087 TFLITE_DCHECK_LE(params.quantized_activation_min,
2088 params.quantized_activation_max);
2089 ruy::profiler::ScopeLabel label("Add/8bit");
2090 const int flat_size =
2091 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2092
2093 TFLITE_DCHECK_GT(params.input1_offset, -256);
2094 TFLITE_DCHECK_GT(params.input2_offset, -256);
2095 TFLITE_DCHECK_LT(params.input1_offset, 256);
2096 TFLITE_DCHECK_LT(params.input2_offset, 256);
2097 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
2098 }
2099
2100 inline void Add(const ArithmeticParams& params,
2101 const RuntimeShape& input1_shape, const int16* input1_data,
2102 const RuntimeShape& input2_shape, const int16* input2_data,
2103 const RuntimeShape& output_shape, int16* output_data) {
2104 ruy::profiler::ScopeLabel label("Add/Int16");
2105 TFLITE_DCHECK_LE(params.quantized_activation_min,
2106 params.quantized_activation_max);
2107
2108 const int input1_shift = params.input1_shift;
2109 const int flat_size =
2110 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2111 const int16 output_activation_min = params.quantized_activation_min;
2112 const int16 output_activation_max = params.quantized_activation_max;
2113
2114 TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
2115 TFLITE_DCHECK_LE(input1_shift, 0);
2116 TFLITE_DCHECK_LE(params.input2_shift, 0);
2117 const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
2118 const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
2119 const int input_right_shift =
2120 input1_shift == 0 ? -params.input2_shift : -input1_shift;
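  // Both operands are Q15 fixed point. At most one of them needs to be scaled
  // down (by input_right_shift) so that the two share the same scale before
  // the saturating add.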
2121
2122 for (int i = 0; i < flat_size; i++) {
2123 // F0 uses 0 integer bits, range [-1, 1].
2124 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2125
2126 F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
2127 F0 scaled_input = F0::FromRaw(
2128 gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
2129 F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
2130 const int16 raw_output = result.raw();
2131 const int16 clamped_output = std::min(
2132 output_activation_max, std::max(output_activation_min, raw_output));
2133 output_data[i] = clamped_output;
2134 }
2135 }
2136
2137 template <typename T>
2138 inline typename std::enable_if<is_int32_or_int64<T>::value, void>::type Add(
2139 const ArithmeticParams& params, const RuntimeShape& input1_shape,
2140 const T* input1_data, const RuntimeShape& input2_shape,
2141 const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2142 ruy::profiler::ScopeLabel label("Add/int32or64");
2143
2144 T activation_min, activation_max;
2145 GetActivationParams(params, &activation_min, &activation_max);
2146
2147 auto input1_map = MapAsVector(input1_data, input1_shape);
2148 auto input2_map = MapAsVector(input2_data, input2_shape);
2149 auto output_map = MapAsVector(output_data, output_shape);
2150 if (input1_shape == input2_shape) {
2151 output_map.array() = (input1_map.array() + input2_map.array())
2152 .cwiseMax(activation_min)
2153 .cwiseMin(activation_max);
2154 } else if (input2_shape.FlatSize() == 1) {
2155 auto scalar = input2_data[0];
2156 output_map.array() = (input1_map.array() + scalar)
2157 .cwiseMax(activation_min)
2158 .cwiseMin(activation_max);
2159 } else if (input1_shape.FlatSize() == 1) {
2160 auto scalar = input1_data[0];
2161 output_map.array() = (scalar + input2_map.array())
2162 .cwiseMax(activation_min)
2163 .cwiseMin(activation_max);
2164 } else {
2165 reference_ops::BroadcastAdd4DSlow<T>(params, input1_shape, input1_data,
2166 input2_shape, input2_data,
2167 output_shape, output_data);
2168 }
2169 }
2170
2171 template <typename T>
2172 inline void BroadcastAddDispatch(
2173 const ArithmeticParams& params, const RuntimeShape& input1_shape,
2174 const T* input1_data, const RuntimeShape& input2_shape,
2175 const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2176 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
2177 return BroadcastAdd4DSlow(params, input1_shape, input1_data, input2_shape,
2178 input2_data, output_shape, output_data);
2179 }
2180
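  // Otherwise the shapes fit the five-fold broadcast pattern:
  // BinaryBroadcastFiveFold walks the folded dimensions and calls the
  // element-wise kernel for matching inner runs and the scalar-broadcast
  // kernel when one side collapses to a single value per run.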
2181 BinaryBroadcastFiveFold(
2182 params, input1_shape, input1_data, input2_shape, input2_data,
2183 output_shape, output_data,
2184 static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
2185 T*)>(AddElementwise),
2186 static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
2187 AddScalarBroadcast));
2188 }
2189
2190 inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
2191 const RuntimeShape& unswitched_input1_shape,
2192 const uint8* unswitched_input1_data,
2193 const RuntimeShape& unswitched_input2_shape,
2194 const uint8* unswitched_input2_data,
2195 const RuntimeShape& output_shape,
2196 uint8* output_data) {
2197 BroadcastAddDispatch(unswitched_params, unswitched_input1_shape,
2198 unswitched_input1_data, unswitched_input2_shape,
2199 unswitched_input2_data, output_shape, output_data);
2200 }
2201
2202 inline void BroadcastAddFivefold(const ArithmeticParams& params,
2203 const RuntimeShape& unswitched_input1_shape,
2204 const float* unswitched_input1_data,
2205 const RuntimeShape& unswitched_input2_shape,
2206 const float* unswitched_input2_data,
2207 const RuntimeShape& output_shape,
2208 float* output_data) {
2209 BroadcastAddDispatch(params, unswitched_input1_shape, unswitched_input1_data,
2210 unswitched_input2_shape, unswitched_input2_data,
2211 output_shape, output_data);
2212 }
2213
2214 inline void MulElementwise(int size, const ArithmeticParams& params,
2215 const float* input1_data, const float* input2_data,
2216 float* output_data) {
2217 const float output_activation_min = params.float_activation_min;
2218 const float output_activation_max = params.float_activation_max;
2219
2220 int i = 0;
2221 #ifdef USE_NEON
2222 const auto activation_min = vdupq_n_f32(output_activation_min);
2223 const auto activation_max = vdupq_n_f32(output_activation_max);
2224 for (; i <= size - 16; i += 16) {
2225 auto a10 = vld1q_f32(input1_data + i);
2226 auto a11 = vld1q_f32(input1_data + i + 4);
2227 auto a12 = vld1q_f32(input1_data + i + 8);
2228 auto a13 = vld1q_f32(input1_data + i + 12);
2229 auto a20 = vld1q_f32(input2_data + i);
2230 auto a21 = vld1q_f32(input2_data + i + 4);
2231 auto a22 = vld1q_f32(input2_data + i + 8);
2232 auto a23 = vld1q_f32(input2_data + i + 12);
2233 auto x0 = vmulq_f32(a10, a20);
2234 auto x1 = vmulq_f32(a11, a21);
2235 auto x2 = vmulq_f32(a12, a22);
2236 auto x3 = vmulq_f32(a13, a23);
2237
2238 x0 = vmaxq_f32(activation_min, x0);
2239 x1 = vmaxq_f32(activation_min, x1);
2240 x2 = vmaxq_f32(activation_min, x2);
2241 x3 = vmaxq_f32(activation_min, x3);
2242 x0 = vminq_f32(activation_max, x0);
2243 x1 = vminq_f32(activation_max, x1);
2244 x2 = vminq_f32(activation_max, x2);
2245 x3 = vminq_f32(activation_max, x3);
2246
2247 vst1q_f32(output_data + i, x0);
2248 vst1q_f32(output_data + i + 4, x1);
2249 vst1q_f32(output_data + i + 8, x2);
2250 vst1q_f32(output_data + i + 12, x3);
2251 }
2252 for (; i <= size - 4; i += 4) {
2253 auto a1 = vld1q_f32(input1_data + i);
2254 auto a2 = vld1q_f32(input2_data + i);
2255 auto x = vmulq_f32(a1, a2);
2256
2257 x = vmaxq_f32(activation_min, x);
2258 x = vminq_f32(activation_max, x);
2259
2260 vst1q_f32(output_data + i, x);
2261 }
2262 #endif // NEON
2263
2264 for (; i < size; i++) {
2265 auto x = input1_data[i] * input2_data[i];
2266 output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
2267 output_activation_max);
2268 }
2269 }
2270
2271 inline void Mul(const ArithmeticParams& params,
2272 const RuntimeShape& input1_shape, const float* input1_data,
2273 const RuntimeShape& input2_shape, const float* input2_data,
2274 const RuntimeShape& output_shape, float* output_data) {
2275 ruy::profiler::ScopeLabel label("Mul");
2276
2277 const int flat_size =
2278 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2279 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
2280 }
2281
2282 inline void Mul(const ArithmeticParams& params,
2283 const RuntimeShape& input1_shape, const int32* input1_data,
2284 const RuntimeShape& input2_shape, const int32* input2_data,
2285 const RuntimeShape& output_shape, int32* output_data) {
2286 ruy::profiler::ScopeLabel label("Mul/int32/activation");
2287
2288 const int flat_size =
2289 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2290 const int32 output_activation_min = params.quantized_activation_min;
2291 const int32 output_activation_max = params.quantized_activation_max;
2292 for (int i = 0; i < flat_size; ++i) {
2293 output_data[i] = ActivationFunctionWithMinMax(
2294 input1_data[i] * input2_data[i], output_activation_min,
2295 output_activation_max);
2296 }
2297 }
2298
2299 inline void MulNoActivation(const ArithmeticParams& params,
2300 const RuntimeShape& input1_shape,
2301 const int32* input1_data,
2302 const RuntimeShape& input2_shape,
2303 const int32* input2_data,
2304 const RuntimeShape& output_shape,
2305 int32* output_data) {
2306 ruy::profiler::ScopeLabel label("Mul/int32");
2307
2308 auto input1_map = MapAsVector(input1_data, input1_shape);
2309 auto input2_map = MapAsVector(input2_data, input2_shape);
2310 auto output_map = MapAsVector(output_data, output_shape);
2311 if (input1_shape == input2_shape) {
2312 output_map.array() = input1_map.array() * input2_map.array();
2313 } else if (input2_shape.FlatSize() == 1) {
2314 auto scalar = input2_data[0];
2315 output_map.array() = input1_map.array() * scalar;
2316 } else if (input1_shape.FlatSize() == 1) {
2317 auto scalar = input1_data[0];
2318 output_map.array() = scalar * input2_map.array();
2319 } else {
2320 reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data,
2321 input2_shape, input2_data, output_shape,
2322 output_data);
2323 }
2324 }
2325
2326 inline void Mul(const ArithmeticParams& params,
2327 const RuntimeShape& input1_shape, const int16* input1_data,
2328 const RuntimeShape& input2_shape, const int16* input2_data,
2329 const RuntimeShape& output_shape, int16* output_data) {
2330 ruy::profiler::ScopeLabel label("Mul/Int16/NoActivation");
2331 // This is a copy of the reference implementation. We do not currently have a
2332 // properly optimized version.
2333
2334 const int flat_size =
2335 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2336
2337 for (int i = 0; i < flat_size; i++) {
2338 // F0 uses 0 integer bits, range [-1, 1].
2339 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2340
2341 F0 unclamped_result =
2342 F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
2343 output_data[i] = unclamped_result.raw();
2344 }
2345 }
2346
2347 inline void Mul(const ArithmeticParams& params,
2348 const RuntimeShape& input1_shape, const int16* input1_data,
2349 const RuntimeShape& input2_shape, const int16* input2_data,
2350 const RuntimeShape& output_shape, uint8* output_data) {
2351 ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
2352 // This is a copy of the reference implementation. We do not currently have a
2353 // properly optimized version.
2354 const int32 output_activation_min = params.quantized_activation_min;
2355 const int32 output_activation_max = params.quantized_activation_max;
2356 const int32 output_offset = params.output_offset;
2357 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
2358
2359 const int flat_size =
2360 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2361
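  // The Q15 product is rescaled to the uint8 output by a rounding right shift
  // of 8; the activation bounds are applied in the offset-free domain before
  // output_offset is added back.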
2362 for (int i = 0; i < flat_size; i++) {
2363 // F0 uses 0 integer bits, range [-1, 1].
2364 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2365
2366 F0 unclamped_result =
2367 F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
2368 int16 rescaled_result =
2369 gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
2370 int16 clamped_result =
2371 std::min<int16>(output_activation_max - output_offset, rescaled_result);
2372 clamped_result =
2373 std::max<int16>(output_activation_min - output_offset, clamped_result);
2374 output_data[i] = output_offset + clamped_result;
2375 }
2376 }
2377
2378 // Element-wise mul that can often be used for inner loop of broadcast Mul as
2379 // well as the non-broadcast Mul.
2380 inline void MulElementwise(int size, const ArithmeticParams& params,
2381 const uint8* input1_data, const uint8* input2_data,
2382 uint8* output_data) {
2383 int i = 0;
2384 TFLITE_DCHECK_GT(params.input1_offset, -256);
2385 TFLITE_DCHECK_LT(params.input1_offset, 256);
2386 TFLITE_DCHECK_GT(params.input2_offset, -256);
2387 TFLITE_DCHECK_LT(params.input2_offset, 256);
2388 TFLITE_DCHECK_GT(params.output_offset, -256);
2389 TFLITE_DCHECK_LT(params.output_offset, 256);
2390 #ifdef USE_NEON
2391 const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
2392 const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
2393 const auto output_offset_vector = vdupq_n_s16(params.output_offset);
2394 const auto output_activation_min_vector =
2395 vdup_n_u8(params.quantized_activation_min);
2396 const auto output_activation_max_vector =
2397 vdup_n_u8(params.quantized_activation_max);
2398 const int left_shift = std::max(0, params.output_shift);
2399 const int right_shift = std::max(0, -params.output_shift);
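  // output_shift is split by sign: a positive shift becomes a plain left shift
  // before the fixed-point multiply, a negative one a rounding right shift
  // afterwards.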
2400 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
2401 for (; i <= size - 8; i += 8) {
2402 // We load / store 8 at a time, multiplying as two sets of 4 int32s.
2403 const auto input1_val_original = vld1_u8(input1_data + i);
2404 const auto input2_val_original = vld1_u8(input2_data + i);
2405 const auto input1_val_s16 =
2406 vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
2407 const auto input2_val_s16 =
2408 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2409 const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
2410 const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
2411
2412 const auto input1_val_low = vget_low_s16(input1_val);
2413 const auto input1_val_high = vget_high_s16(input1_val);
2414 const auto input2_val_low = vget_low_s16(input2_val);
2415 const auto input2_val_high = vget_high_s16(input2_val);
2416
2417 auto p1 = vmull_s16(input2_val_low, input1_val_low);
2418 auto p2 = vmull_s16(input2_val_high, input1_val_high);
2419
2420 p1 = vshlq_s32(p1, left_shift_vec);
2421 p2 = vshlq_s32(p2, left_shift_vec);
2422 p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
2423 p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
2424 using gemmlowp::RoundingDivideByPOT;
2425 p1 = RoundingDivideByPOT(p1, right_shift);
2426 p2 = RoundingDivideByPOT(p2, right_shift);
2427
2428 const auto p1_narrowed = vqmovn_s32(p1);
2429 const auto p2_narrowed = vqmovn_s32(p2);
2430 const auto p =
2431 vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
2432 const auto clamped =
2433 vmax_u8(output_activation_min_vector,
2434 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
2435 vst1_u8(output_data + i, clamped);
2436 }
2437 #endif  // USE_NEON
2438
2439 for (; i < size; ++i) {
2440 const int32 input1_val = params.input1_offset + input1_data[i];
2441 const int32 input2_val = params.input2_offset + input2_data[i];
2442 const int32 unclamped_result =
2443 params.output_offset +
2444 MultiplyByQuantizedMultiplier(input1_val * input2_val,
2445 params.output_multiplier,
2446 params.output_shift);
2447 const int32 clamped_output =
2448 std::min(params.quantized_activation_max,
2449 std::max(params.quantized_activation_min, unclamped_result));
2450 output_data[i] = static_cast<uint8>(clamped_output);
2451 }
2452 }
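
// Sketch of how the quantized parameters consumed above are typically derived
// by the calling kernel (illustrative, not taken verbatim from any particular
// op):
//
//   const double real_multiplier =
//       input1_scale * input2_scale / output_scale;
//   int32 output_multiplier;
//   int output_shift;
//   QuantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
//
// params.output_multiplier / params.output_shift then drive both the
// MultiplyByQuantizedMultiplier call in the scalar tail and the
// vqrdmulhq_n_s32 + RoundingDivideByPOT pair in the NEON body.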
2453
2454 // Broadcast mul that can often be used for inner loop of broadcast Mul.
2455 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
2456 const uint8 broadcast_value,
2457 const uint8* input2_data, uint8* output_data) {
2458 const int16 input1_val = params.input1_offset + broadcast_value;
2459
2460 int i = 0;
2461 TFLITE_DCHECK_GT(params.input1_offset, -256);
2462 TFLITE_DCHECK_LT(params.input1_offset, 256);
2463 TFLITE_DCHECK_GT(params.input2_offset, -256);
2464 TFLITE_DCHECK_LT(params.input2_offset, 256);
2465 TFLITE_DCHECK_GT(params.output_offset, -256);
2466 TFLITE_DCHECK_LT(params.output_offset, 256);
2467 #ifdef USE_NEON
2468 const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
2469 const auto output_offset_vector = vdupq_n_s16(params.output_offset);
2470 const auto output_activation_min_vector =
2471 vdup_n_u8(params.quantized_activation_min);
2472 const auto output_activation_max_vector =
2473 vdup_n_u8(params.quantized_activation_max);
2474 const int left_shift = std::max(0, params.output_shift);
2475 const int right_shift = std::max(0, -params.output_shift);
2476 const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
2477 for (; i <= size - 8; i += 8) {
2478 // We load / store 8 at a time, multiplying as two sets of 4 int32s.
2479 const auto input2_val_original = vld1_u8(input2_data + i);
2480 const auto input2_val_s16 =
2481 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2482 const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
2483
2484 const auto input2_val_low = vget_low_s16(input2_val);
2485 const auto input2_val_high = vget_high_s16(input2_val);
2486
2487 auto p1 = vmull_n_s16(input2_val_low, input1_val);
2488 auto p2 = vmull_n_s16(input2_val_high, input1_val);
2489
2490 p1 = vshlq_s32(p1, left_shift_vec);
2491 p2 = vshlq_s32(p2, left_shift_vec);
2492 p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
2493 p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
2494 using gemmlowp::RoundingDivideByPOT;
2495 p1 = RoundingDivideByPOT(p1, right_shift);
2496 p2 = RoundingDivideByPOT(p2, right_shift);
2497
2498 const auto p1_narrowed = vmovn_s32(p1);
2499 const auto p2_narrowed = vmovn_s32(p2);
2500 const auto p =
2501 vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
2502 const auto clamped =
2503 vmax_u8(output_activation_min_vector,
2504 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
2505 vst1_u8(output_data + i, clamped);
2506 }
2507 #endif  // USE_NEON
2508
2509 for (; i < size; ++i) {
2510 const int32 input2_val = params.input2_offset + input2_data[i];
2511 const int32 unclamped_result =
2512 params.output_offset +
2513 MultiplyByQuantizedMultiplier(input1_val * input2_val,
2514 params.output_multiplier,
2515 params.output_shift);
2516 const int32 clamped_output =
2517 std::min(params.quantized_activation_max,
2518 std::max(params.quantized_activation_min, unclamped_result));
2519 output_data[i] = static_cast<uint8>(clamped_output);
2520 }
2521 }
2522
2523 // Broadcast mul that can often be used for inner loop of broadcast Mul.
2524 // This function will handle scalar_value (LHS) * vector_values (RHS).
2525 // Since this is the float path, the quantization fields of params are not
2525 // used here; only the float activation bounds matter.
2526 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
2527 const float broadcast_value,
2528 const float* input2_data, float* output_data) {
2529 int i = 0;
2530 #ifdef USE_NEON
2531 const float32x4_t output_activation_min_vector =
2532 vdupq_n_f32(params.float_activation_min);
2533 const float32x4_t output_activation_max_vector =
2534 vdupq_n_f32(params.float_activation_max);
2535 const float32x4_t broadcast_value_dup = vdupq_n_f32(broadcast_value);
2536 for (; i <= size - 4; i += 4) {
2537 const float32x4_t input2_val_original = vld1q_f32(input2_data + i);
2538
2539 const float32x4_t output =
2540 vmulq_f32(input2_val_original, broadcast_value_dup);
2541
2542 const float32x4_t clamped =
2543 vmaxq_f32(output_activation_min_vector,
2544 vminq_f32(output_activation_max_vector, output));
2545 vst1q_f32(output_data + i, clamped);
2546 }
2547 #endif  // USE_NEON
2548
2549 for (; i < size; ++i) {
2550 float x = broadcast_value * input2_data[i];
2551 output_data[i] = ActivationFunctionWithMinMax(
2552 x, params.float_activation_min, params.float_activation_max);
2553 }
2554 }
2555
2556 inline void Mul(const ArithmeticParams& params,
2557 const RuntimeShape& input1_shape, const uint8* input1_data,
2558 const RuntimeShape& input2_shape, const uint8* input2_data,
2559 const RuntimeShape& output_shape, uint8* output_data) {
2560 TFLITE_DCHECK_LE(params.quantized_activation_min,
2561 params.quantized_activation_max);
2562 ruy::profiler::ScopeLabel label("Mul/8bit");
2563 const int flat_size =
2564 MatchingElementsSize(input1_shape, input2_shape, output_shape);
2565
2566 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
2567 }
2568
2569 template <typename T>
2570 inline void BroadcastMulDispatch(
2571 const ArithmeticParams& params, const RuntimeShape& input1_shape,
2572 const T* input1_data, const RuntimeShape& input2_shape,
2573 const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2574 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
2575 return BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
2576 input2_data, output_shape, output_data);
2577 }
2578
2579 BinaryBroadcastFiveFold(
2580 params, input1_shape, input1_data, input2_shape, input2_data,
2581 output_shape, output_data,
2582 static_cast<void (*)(int, const ArithmeticParams&, const T*, const T*,
2583 T*)>(MulElementwise),
2584 static_cast<void (*)(int, const ArithmeticParams&, T, const T*, T*)>(
2585 MulSimpleBroadcast));
2586 }
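
// Dispatch example for BroadcastMulDispatch above (shapes are illustrative):
// multiplying a [1, 1, 1, 8] tensor by a [4, 16, 16, 8] tensor is a
// fivefold-broadcastable case, so BinaryBroadcastFiveFold repeatedly calls
// MulElementwise (both operands vary along the run) or MulSimpleBroadcast
// (one operand is a single repeated value) on contiguous slices. Shapes that
// do not fit the fivefold pattern take the BroadcastMul4DSlow path instead.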
2587
2588 inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
2589 const RuntimeShape& unswitched_input1_shape,
2590 const uint8* unswitched_input1_data,
2591 const RuntimeShape& unswitched_input2_shape,
2592 const uint8* unswitched_input2_data,
2593 const RuntimeShape& output_shape,
2594 uint8* output_data) {
2595 BroadcastMulDispatch(unswitched_params, unswitched_input1_shape,
2596 unswitched_input1_data, unswitched_input2_shape,
2597 unswitched_input2_data, output_shape, output_data);
2598 }
2599
2600 inline void BroadcastMulFivefold(const ArithmeticParams& params,
2601 const RuntimeShape& unswitched_input1_shape,
2602 const float* unswitched_input1_data,
2603 const RuntimeShape& unswitched_input2_shape,
2604 const float* unswitched_input2_data,
2605 const RuntimeShape& output_shape,
2606 float* output_data) {
2607 BroadcastMulDispatch(params, unswitched_input1_shape, unswitched_input1_data,
2608 unswitched_input2_shape, unswitched_input2_data,
2609 output_shape, output_data);
2610 }
2611
2612 // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
2613 // dimensionality if the runtime code does a single loop over one dimension
2614 // that handles broadcasting as the base case. The code generator would then
2615 // generate max(D1, D2) nested for loops.
2616 // TODO(benoitjacob): BroadcastDiv is intentionally duplicated from
2617 // reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
2618 // is no longer referenced in this file, move NdArrayDesc<T> from types.h to
2619 // reference_ops.h.
2620 template <typename T, int N = 5>
2621 void BroadcastDivSlow(const ArithmeticParams& params,
2622 const RuntimeShape& unextended_input1_shape,
2623 const T* input1_data,
2624 const RuntimeShape& unextended_input2_shape,
2625 const T* input2_data,
2626 const RuntimeShape& unextended_output_shape,
2627 T* output_data) {
2628 ruy::profiler::ScopeLabel label("BroadcastDivSlow");
2629 T output_activation_min;
2630 T output_activation_max;
2631 GetActivationParams(params, &output_activation_min, &output_activation_max);
2632
2633 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
2634 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
2635 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
2636
2637 NdArrayDesc<N> desc1;
2638 NdArrayDesc<N> desc2;
2639 NdArrayDesc<N> output_desc;
2640 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
2641 unextended_input2_shape, &desc1, &desc2);
2642 CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
2643 &output_desc);
2644
2645   // In TensorFlow, the dimensions are canonically named (batch_number, row,
2646   // col, channel), with extents (batches, height, width, depth), with the
2647   // trailing dimension changing most rapidly (the channel dimension has the
2648   // smallest stride, typically 1 element).
2649   //
2650   // In generated C code, we store arrays with the dimensions reversed. The
2651   // first dimension has the smallest stride.
2652   //
2653   // We name our variables by their TensorFlow convention, but generate C code
2654   // nesting loops such that the innermost loop has the smallest stride for the
2655   // best cache behavior.
2656 auto div_func = [&](int indexes[N]) {
2657 output_data[SubscriptToIndex(output_desc, indexes)] =
2658 ActivationFunctionWithMinMax(
2659 input1_data[SubscriptToIndex(desc1, indexes)] /
2660 input2_data[SubscriptToIndex(desc2, indexes)],
2661 output_activation_min, output_activation_max);
2662 };
2663 NDOpsHelper<N>(output_desc, div_func);
2664 }
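
// Broadcast example for BroadcastDivSlow above (shapes are illustrative):
// dividing a [2, 1, 3] tensor by a [1, 4, 1] tensor produces a [2, 4, 3]
// output. NdArrayDescsForElementwiseBroadcast gives the broadcast dimensions
// an effective stride of zero, and NDOpsHelper invokes div_func once per
// output element, with SubscriptToIndex mapping the shared index back into
// each input.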
2665
2666 // BroadcastDiv is intentionally duplicated from reference_ops.h.
2667 // For more details see the comment above the generic version of
2668 // BroadcastDivSlow.
2669 template <int N = 5>
2670 inline void BroadcastDivSlow(const ArithmeticParams& params,
2671 const RuntimeShape& unextended_input1_shape,
2672 const uint8* input1_data,
2673 const RuntimeShape& unextended_input2_shape,
2674 const uint8* input2_data,
2675 const RuntimeShape& unextended_output_shape,
2676 uint8* output_data) {
2677 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), N);
2678 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), N);
2679 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), N);
2680
2681 NdArrayDesc<N> desc1;
2682 NdArrayDesc<N> desc2;
2683 NdArrayDesc<N> output_desc;
2684 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
2685 unextended_input2_shape, &desc1, &desc2);
2686 CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
2687 &output_desc);
2688
2689 TFLITE_DCHECK_GT(params.input1_offset, -256);
2690 TFLITE_DCHECK_LT(params.input1_offset, 256);
2691 TFLITE_DCHECK_GT(params.input2_offset, -256);
2692 TFLITE_DCHECK_LT(params.input2_offset, 256);
2693 TFLITE_DCHECK_GT(params.output_offset, -256);
2694 TFLITE_DCHECK_LT(params.output_offset, 256);
2695
2696 auto div_func = [&](int indexes[N]) {
2697 const int32 input1_val =
2698 params.input1_offset + input1_data[SubscriptToIndex(desc1, indexes)];
2699 const int32 input2_val =
2700 params.input2_offset + input2_data[SubscriptToIndex(desc2, indexes)];
2701 TFLITE_DCHECK_NE(input2_val, 0);
2702 int recip_shift;
2703 const int32 input2_inv =
2704 (input2_val > 0) ? GetReciprocal(input2_val, 31, &recip_shift)
2705 : -GetReciprocal(-input2_val, 31, &recip_shift);
2706 const int headroom = CountLeadingSignBits(input1_val);
2707 const int32 unscaled_quotient = MultiplyByQuantizedMultiplierGreaterThanOne(
2708 input1_val, input2_inv, headroom);
2709 const int total_shift = params.output_shift - recip_shift - headroom;
2710 const int32 unclamped_result =
2711 params.output_offset +
2712 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2713 unscaled_quotient, params.output_multiplier, total_shift);
2714 const int32 clamped_output =
2715 std::min(params.quantized_activation_max,
2716 std::max(params.quantized_activation_min, unclamped_result));
2717 output_data[SubscriptToIndex(output_desc, indexes)] =
2718 static_cast<uint8>(clamped_output);
2719 };
2720 NDOpsHelper<N>(output_desc, div_func);
2721 }
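
// Net effect of the quantized division above, per element and up to rounding:
//
//   output ~= output_offset +
//             ((input1 + input1_offset) / (input2 + input2_offset)) *
//                 output_multiplier * 2^output_shift
//
// The reciprocal of the denominator is computed per element via GetReciprocal,
// and recip_shift together with the numerator's headroom is folded into the
// final shift so that intermediates stay within int32.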
2722
2723 template <typename T>
2724 inline void SubWithActivation(
2725 const ArithmeticParams& params, const RuntimeShape& input1_shape,
2726 const T* input1_data, const RuntimeShape& input2_shape,
2727 const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
2728 ruy::profiler::ScopeLabel label("SubWithActivation_optimized");
2729 TFLITE_DCHECK_EQ(input1_shape.FlatSize(), input2_shape.FlatSize());
2730 auto input1_map = MapAsVector(input1_data, input1_shape);
2731 auto input2_map = MapAsVector(input2_data, input2_shape);
2732 auto output_map = MapAsVector(output_data, output_shape);
2733 T activation_min, activation_max;
2734 GetActivationParams(params, &activation_min, &activation_max);
2735 output_map.array() = (input1_map.array() - input2_map.array())
2736 .cwiseMin(activation_max)
2737 .cwiseMax(activation_min);
2738 }
2739
2740 inline void SubNonBroadcast(const ArithmeticParams& params,
2741 const RuntimeShape& input1_shape,
2742 const float* input1_data,
2743 const RuntimeShape& input2_shape,
2744 const float* input2_data,
2745 const RuntimeShape& output_shape,
2746 float* output_data) {
2747 ruy::profiler::ScopeLabel label("SubNonBroadcast");
2748 SubWithActivation<float>(params, input1_shape, input1_data, input2_shape,
2749 input2_data, output_shape, output_data);
2750 }
2751
2752 template <typename T>
2753 void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
2754 const T* input1_data, const RuntimeShape& input2_shape,
2755 const T* input2_data, const RuntimeShape& output_shape,
2756 T* output_data) {
2757 ruy::profiler::ScopeLabel label("Sub");
2758
2759 auto input1_map = MapAsVector(input1_data, input1_shape);
2760 auto input2_map = MapAsVector(input2_data, input2_shape);
2761 auto output_map = MapAsVector(output_data, output_shape);
2762 if (input1_shape == input2_shape) {
2763 output_map.array() = input1_map.array() - input2_map.array();
2764 } else if (input1_shape.FlatSize() == 1) {
2765 auto scalar = input1_data[0];
2766 output_map.array() = scalar - input2_map.array();
2767 } else if (input2_shape.FlatSize() == 1) {
2768 auto scalar = input2_data[0];
2769 output_map.array() = input1_map.array() - scalar;
2770 } else {
2771 BroadcastSubSlow(params, input1_shape, input1_data, input2_shape,
2772 input2_data, output_shape, output_data);
2773 }
2774 }
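
// Example of the dispatch above (shapes are illustrative): subtracting a
// single-element [1] tensor from a [2, 3] tensor takes the
// input2_shape.FlatSize() == 1 branch and is evaluated as one Eigen
// array-minus-scalar expression; only genuinely mismatched shapes pay for
// BroadcastSubSlow.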
2775
2776 inline void LstmCell(
2777 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2778 const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
2779 const float* prev_activ_data, const RuntimeShape& weights_shape,
2780 const float* weights_data, const RuntimeShape& unextended_bias_shape,
2781 const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
2782 const float* prev_state_data,
2783 const RuntimeShape& unextended_output_state_shape, float* output_state_data,
2784 const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
2785 const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
2786 const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data,
2787 CpuBackendContext* cpu_backend_context) {
2788 ruy::profiler::ScopeLabel label("LstmCell");
2789 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2790 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2791 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2792 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2793 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2794 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2795 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2796 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
2797 const RuntimeShape input_shape =
2798 RuntimeShape::ExtendedShape(4, unextended_input_shape);
2799 const RuntimeShape prev_activ_shape =
2800 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
2801 const RuntimeShape bias_shape =
2802 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
2803 const RuntimeShape prev_state_shape =
2804 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
2805 const RuntimeShape output_state_shape =
2806 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
2807 const RuntimeShape output_activ_shape =
2808 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
2809 const RuntimeShape concat_temp_shape =
2810 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
2811 const RuntimeShape activ_temp_shape =
2812 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
2813 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2814
2815 const int weights_dim_count = weights_shape.DimensionsCount();
2816 MatchingDim( // batches
2817 input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
2818 output_state_shape, 0, output_activ_shape, 0);
2819 MatchingDim( // height
2820 input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
2821 output_state_shape, 1, output_activ_shape, 1);
2822 MatchingDim( // width
2823 input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
2824 output_state_shape, 2, output_activ_shape, 2);
2825 const int input_depth = input_shape.Dims(3);
2826 const int prev_activ_depth = prev_activ_shape.Dims(3);
2827 const int total_input_depth = prev_activ_depth + input_depth;
2828 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
2829 total_input_depth);
2830 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
2831 const int intern_activ_depth =
2832 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
2833 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
2834 intern_activ_depth * total_input_depth);
2835 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
2836 const int output_depth =
2837 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
2838 3, output_activ_shape, 3);
2839 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
2840
2841 // Concatenate prev_activ and input data together
2842 std::vector<float const*> concat_input_arrays_data;
2843 std::vector<RuntimeShape const*> concat_input_arrays_shapes;
2844 concat_input_arrays_data.push_back(input_data);
2845 concat_input_arrays_data.push_back(prev_activ_data);
2846 concat_input_arrays_shapes.push_back(&input_shape);
2847 concat_input_arrays_shapes.push_back(&prev_activ_shape);
2848 tflite::ConcatenationParams concat_params;
2849 concat_params.axis = 3;
2850 concat_params.inputs_count = concat_input_arrays_data.size();
2851 Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
2852 &(concat_input_arrays_data[0]), concat_temp_shape,
2853 concat_temp_data);
2854
2855 // Fully connected
2856 tflite::FullyConnectedParams fc_params;
2857 fc_params.float_activation_min = std::numeric_limits<float>::lowest();
2858 fc_params.float_activation_max = std::numeric_limits<float>::max();
2859 fc_params.lhs_cacheable = false;
2860 fc_params.rhs_cacheable = false;
2861 FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
2862 weights_data, bias_shape, bias_data, activ_temp_shape,
2863 activ_temp_data, cpu_backend_context);
2864
2865 // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
2866 // operations.
2867 ArrayMap<float> activ_temp_map =
2868 MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
2869 auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
2870 activ_temp_map.cols());
2871 auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
2872 activ_temp_map.cols());
2873 auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
2874 activ_temp_map.cols());
2875 auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
2876 activ_temp_map.cols());
2877 ArrayMap<const float> prev_state_map =
2878 MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
2879 ArrayMap<float> output_state_map =
2880 MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
2881 ArrayMap<float> output_activ_map =
2882 MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
2883
2884 // Combined memory state and final output calculation
2885 ruy::profiler::ScopeLabel label2("MemoryStateAndFinalOutput");
2886 output_state_map =
2887 input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2888 new_input_sm.tanh() +
2889 forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2890 prev_state_map;
2891 output_activ_map =
2892 output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
2893 output_state_map.tanh();
2894 }
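
// In equation form, the Eigen block expressions above compute, per cell:
//
//   new_state = logistic(i) * tanh(g) + logistic(f) * prev_state
//   output    = logistic(o) * tanh(new_state)
//
// where i, g, f and o are the four consecutive row blocks of the
// fully-connected output (input gate, new input, forget gate, output gate).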
2895
2896 template <int StateIntegerBits>
2897 inline void LstmCell(
2898 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
2899 const uint8* input_data_uint8,
2900 const RuntimeShape& unextended_prev_activ_shape,
2901 const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
2902 const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
2903 const int32* bias_data_int32,
2904 const RuntimeShape& unextended_prev_state_shape,
2905 const int16* prev_state_data_int16,
2906 const RuntimeShape& unextended_output_state_shape,
2907 int16* output_state_data_int16,
2908 const RuntimeShape& unextended_output_activ_shape,
2909 uint8* output_activ_data_uint8,
2910 const RuntimeShape& unextended_concat_temp_shape,
2911 uint8* concat_temp_data_uint8,
2912 const RuntimeShape& unextended_activ_temp_shape,
2913 int16* activ_temp_data_int16, CpuBackendContext* cpu_backend_context) {
2914 ruy::profiler::ScopeLabel label(
2915 "LstmCell/quantized (8bit external, 16bit internal)");
2916 int32 weights_zero_point = params.weights_zero_point;
2917 int32 accum_multiplier = params.accum_multiplier;
2918 int accum_shift = params.accum_shift;
2919 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2920 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
2921 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
2922 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
2923 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
2924 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
2925 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
2926 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
2927 const RuntimeShape input_shape =
2928 RuntimeShape::ExtendedShape(4, unextended_input_shape);
2929 const RuntimeShape prev_activ_shape =
2930 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
2931 const RuntimeShape bias_shape =
2932 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
2933 const RuntimeShape prev_state_shape =
2934 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
2935 const RuntimeShape output_state_shape =
2936 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
2937 const RuntimeShape output_activ_shape =
2938 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
2939 const RuntimeShape concat_temp_shape =
2940 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
2941 const RuntimeShape activ_temp_shape =
2942 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
2943 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
2944
2945 // Gather dimensions information, and perform consistency checks.
2946 const int weights_dim_count = weights_shape.DimensionsCount();
2947 const int outer_size = MatchingFlatSizeSkipDim(
2948 input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
2949 output_activ_shape);
2950 const int input_depth = input_shape.Dims(3);
2951 const int prev_activ_depth = prev_activ_shape.Dims(3);
2952 const int total_input_depth = prev_activ_depth + input_depth;
2953 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
2954 total_input_depth);
2955 const int intern_activ_depth =
2956 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
2957 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
2958 intern_activ_depth * total_input_depth);
2959 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
2960 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
2961 const int output_depth =
2962 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
2963 3, output_activ_shape, 3);
2964 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
2965 const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
2966 const int fc_output_depth =
2967 MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
2968 const int fc_accum_depth = total_input_depth;
2969 TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
2970
2971 // Depth-concatenate prev_activ and input data together.
2972 uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
2973 prev_activ_data_uint8};
2974 const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
2975 &prev_activ_shape};
2976 tflite::ConcatenationParams concat_params;
2977 concat_params.axis = 3;
2978 concat_params.inputs_count = 2;
2979 Concatenation(concat_params, concat_input_arrays_shapes,
2980 concat_input_arrays_data, concat_temp_shape,
2981 concat_temp_data_uint8);
2982
2983 // Implementation of the fully connected node inside the LSTM cell.
2984   // The operands are 8-bit integers, the accumulators are internally 32-bit
2985   // integers, and the output is 16-bit fixed-point with 3 integer bits, so
2986 // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
2987 // is explained in the function comment above.
2988 cpu_backend_gemm::MatrixParams<uint8> lhs_params;
2989 lhs_params.rows = fc_output_depth;
2990 lhs_params.cols = fc_accum_depth;
2991 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
2992 lhs_params.zero_point = weights_zero_point;
2993 cpu_backend_gemm::MatrixParams<uint8> rhs_params;
2994 rhs_params.rows = fc_accum_depth;
2995 rhs_params.cols = fc_batches;
2996 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
2997 rhs_params.zero_point = 128;
2998 cpu_backend_gemm::MatrixParams<int16> dst_params;
2999 dst_params.rows = fc_output_depth;
3000 dst_params.cols = fc_batches;
3001 dst_params.order = cpu_backend_gemm::Order::kColMajor;
3002 dst_params.zero_point = 0;
3003 cpu_backend_gemm::GemmParams<int32, int16> gemm_params;
3004 gemm_params.bias = bias_data_int32;
3005 gemm_params.multiplier_fixedpoint = accum_multiplier;
3006 gemm_params.multiplier_exponent = accum_shift;
3007 cpu_backend_gemm::Gemm(
3008 lhs_params, weights_data_uint8, rhs_params, concat_temp_data_uint8,
3009 dst_params, activ_temp_data_int16, gemm_params, cpu_backend_context);
3010
3011 // Rest of the LSTM cell: tanh and logistic math functions, and some adds
3012 // and muls, all done in 16-bit fixed-point.
3013 const int16* input_gate_input_ptr = activ_temp_data_int16;
3014 const int16* input_modulation_gate_input_ptr =
3015 activ_temp_data_int16 + output_depth;
3016 const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
3017 const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
3018 const int16* prev_state_ptr = prev_state_data_int16;
3019 int16* output_state_data_ptr = output_state_data_int16;
3020 uint8* output_activ_data_ptr = output_activ_data_uint8;
3021
3022 for (int b = 0; b < outer_size; ++b) {
3023 int c = 0;
3024 #ifdef GEMMLOWP_NEON
3025 for (; c <= output_depth - 8; c += 8) {
3026 // Define the fixed-point data types that we will use here. All use
3027 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3028 // They only differ by the number of integral vs. fractional bits,
3029 // determining the range of values that they can represent.
3030 //
3031 // F0 uses 0 integer bits, range [-1, 1].
3032 // This is the return type of math functions such as tanh, logistic,
3033 // whose range is in [-1, 1].
3034 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
3035 // F3 uses 3 integer bits, range [-8, 8].
3036 // This is the range of the previous fully-connected node's output,
3037 // which is our input here.
3038 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
3039 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3040 // 2^StateIntegerBits]. It's used to represent the internal state, whose
3041 // number of integer bits is currently dictated by the model. See comment
3042 // on the StateIntegerBits template parameter above.
3043 using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
3044 // Implementation of input gate, using fixed-point logistic function.
3045 F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
3046 input_gate_input_ptr += 8;
3047 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3048 // Implementation of input modulation gate, using fixed-point tanh
3049 // function.
3050 F3 input_modulation_gate_input =
3051 F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
3052 input_modulation_gate_input_ptr += 8;
3053 F0 input_modulation_gate_output =
3054 gemmlowp::tanh(input_modulation_gate_input);
3055 // Implementation of forget gate, using fixed-point logistic function.
3056 F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
3057 forget_gate_input_ptr += 8;
3058 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3059 // Implementation of output gate, using fixed-point logistic function.
3060 F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
3061 output_gate_input_ptr += 8;
3062 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3063 // Implementation of internal multiplication nodes, still in fixed-point.
3064 F0 input_times_input_modulation =
3065 input_gate_output * input_modulation_gate_output;
3066 FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
3067 prev_state_ptr += 8;
3068 FS prev_state_times_forget_state = forget_gate_output * prev_state;
3069 // Implementation of internal addition node, saturating.
3070 FS new_state = gemmlowp::SaturatingAdd(
3071 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3072 prev_state_times_forget_state);
3073 // Implementation of last internal Tanh node, still in fixed-point.
3074 // Since a Tanh fixed-point implementation is specialized for a given
3075       // number of integer bits, and each specialization can have a substantial
3076 // code size, and we already used above a Tanh on an input with 3 integer
3077 // bits, and per the table in the above function comment there is no
3078 // significant accuracy to be lost by clamping to [-8, +8] for a
3079 // 3-integer-bits representation, let us just do that. This helps people
3080 // porting this to targets where code footprint must be minimized.
3081 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3082 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3083 // Store the new internal state back to memory, as 16-bit integers.
3084 // Note: here we store the original value with StateIntegerBits, not
3085 // the rescaled 3-integer-bits value fed to tanh.
3086 vst1q_s16(output_state_data_ptr, new_state.raw());
3087 output_state_data_ptr += 8;
3088 // Down-scale the output activations to 8-bit integers, saturating,
3089 // and store back to memory.
3090 int16x8_t rescaled_output_activ =
3091 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3092 int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
3093 uint8x8_t uint8_output_activ =
3094 vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
3095 vst1_u8(output_activ_data_ptr, uint8_output_activ);
3096 output_activ_data_ptr += 8;
3097 }
3098 #endif
3099 for (; c < output_depth; ++c) {
3100 // Define the fixed-point data types that we will use here. All use
3101 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3102 // They only differ by the number of integral vs. fractional bits,
3103 // determining the range of values that they can represent.
3104 //
3105 // F0 uses 0 integer bits, range [-1, 1].
3106 // This is the return type of math functions such as tanh, logistic,
3107 // whose range is in [-1, 1].
3108 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3109 // F3 uses 3 integer bits, range [-8, 8].
3110 // This is the range of the previous fully-connected node's output,
3111 // which is our input here.
3112 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
3113 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3114 // 2^StateIntegerBits]. It's used to represent the internal state, whose
3115 // number of integer bits is currently dictated by the model. See comment
3116 // on the StateIntegerBits template parameter above.
3117 using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
3118 // Implementation of input gate, using fixed-point logistic function.
3119 F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
3120 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3121 // Implementation of input modulation gate, using fixed-point tanh
3122 // function.
3123 F3 input_modulation_gate_input =
3124 F3::FromRaw(*input_modulation_gate_input_ptr++);
3125 F0 input_modulation_gate_output =
3126 gemmlowp::tanh(input_modulation_gate_input);
3127 // Implementation of forget gate, using fixed-point logistic function.
3128 F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
3129 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3130 // Implementation of output gate, using fixed-point logistic function.
3131 F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
3132 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3133 // Implementation of internal multiplication nodes, still in fixed-point.
3134 F0 input_times_input_modulation =
3135 input_gate_output * input_modulation_gate_output;
3136 FS prev_state = FS::FromRaw(*prev_state_ptr++);
3137 FS prev_state_times_forget_state = forget_gate_output * prev_state;
3138 // Implementation of internal addition node, saturating.
3139 FS new_state = gemmlowp::SaturatingAdd(
3140 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3141 prev_state_times_forget_state);
3142 // Implementation of last internal Tanh node, still in fixed-point.
3143 // Since a Tanh fixed-point implementation is specialized for a given
3144       // number of integer bits, and each specialization can have a substantial
3145 // code size, and we already used above a Tanh on an input with 3 integer
3146 // bits, and per the table in the above function comment there is no
3147 // significant accuracy to be lost by clamping to [-8, +8] for a
3148 // 3-integer-bits representation, let us just do that. This helps people
3149 // porting this to targets where code footprint must be minimized.
3150 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3151 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3152 // Store the new internal state back to memory, as 16-bit integers.
3153 // Note: here we store the original value with StateIntegerBits, not
3154 // the rescaled 3-integer-bits value fed to tanh.
3155 *output_state_data_ptr++ = new_state.raw();
3156 // Down-scale the output activations to 8-bit integers, saturating,
3157 // and store back to memory.
3158 int16 rescaled_output_activ =
3159 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3160 int16 clamped_output_activ =
3161 std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
3162 *output_activ_data_ptr++ = 128 + clamped_output_activ;
3163 }
3164 input_gate_input_ptr += 3 * output_depth;
3165 input_modulation_gate_input_ptr += 3 * output_depth;
3166 forget_gate_input_ptr += 3 * output_depth;
3167 output_gate_input_ptr += 3 * output_depth;
3168 }
3169 }
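
// Worked example for the final downscale above (values are illustrative): an
// output activation of 0.5 in Q0.15 has raw value 16384.
// RoundingDivideByPOT(16384, 8) gives 64, clamping to [-128, 127] leaves it
// unchanged, and adding the fixed zero point of 128 stores 192 in the uint8
// output, consistent with the zero_point of 128 used for the uint8 GEMM
// operands above.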
3170
3171 inline int NodeOffset(int b, int h, int w, int height, int width) {
3172 return (b * height + h) * width + w;
3173 }
3174
3175 inline bool AveragePool(const PoolParams& params,
3176 const RuntimeShape& input_shape,
3177 const float* input_data,
3178 const RuntimeShape& output_shape, float* output_data) {
3179 ruy::profiler::ScopeLabel label("AveragePool");
3180 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3181 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3182 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3183 const int input_height = input_shape.Dims(1);
3184 const int input_width = input_shape.Dims(2);
3185 const int output_height = output_shape.Dims(1);
3186 const int output_width = output_shape.Dims(2);
3187 const int stride_height = params.stride_height;
3188 const int stride_width = params.stride_width;
3189
3190 if (stride_height == 0) return false;
3191 if (stride_width == 0) return false;
3192
3193 // TODO(benoitjacob) make this a proper reference impl without Eigen!
3194 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3195 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3196 // TODO(benoitjacob) get rid of the dynamic memory allocation here!
3197 Eigen::VectorXf out_count(out_mat.cols());
3198 out_count.setZero();
3199 // Prefill the output to 0.
3200 out_mat.setZero();
3201 for (int b = 0; b < batches; ++b) {
3202 for (int h = 0; h < input_height; ++h) {
3203 for (int w = 0; w < input_width; ++w) {
3204 // (h_start, h_end) * (w_start, w_end) is the range that the input
3205 // vector projects to.
3206 int hpad = h + params.padding_values.height;
3207 int wpad = w + params.padding_values.width;
3208 int h_start = (hpad < params.filter_height)
3209 ? 0
3210 : (hpad - params.filter_height) / stride_height + 1;
3211 int h_end = std::min(hpad / stride_height + 1, output_height);
3212 int w_start = (wpad < params.filter_width)
3213 ? 0
3214 : (wpad - params.filter_width) / stride_width + 1;
3215 int w_end = std::min(wpad / stride_width + 1, output_width);
3216 // compute elementwise sum
3217 for (int ph = h_start; ph < h_end; ++ph) {
3218 for (int pw = w_start; pw < w_end; ++pw) {
3219 int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
3220 out_mat.col(out_offset) +=
3221 in_mat.col(NodeOffset(b, h, w, input_height, input_width));
3222 out_count(out_offset)++;
3223 }
3224 }
3225 }
3226 }
3227 }
3228 // Divide the output by the actual number of elements being averaged over
3229 TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
3230 out_mat.array().rowwise() /= out_count.transpose().array();
3231
3232 const int flat_size = output_shape.FlatSize();
3233 for (int i = 0; i < flat_size; ++i) {
3234 output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3235 params.float_activation_min,
3236 params.float_activation_max);
3237 }
3238
3239 return true;
3240 }
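
// Note on the loop structure above: the float AveragePool is written in
// "forward" mode. Each input position (b, h, w) is visited once and its
// column is added to every output window that covers it, while out_count
// tracks how many inputs landed in each window; dividing by out_count at the
// end yields the average. For example (illustrative), with a 3x3 filter and
// stride 1, interior outputs divide by 9 while edge outputs divide by the
// smaller number of valid taps.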
3241
3242 inline bool AveragePool(const PoolParams& params,
3243 const RuntimeShape& input_shape,
3244 const uint8* input_data,
3245 const RuntimeShape& output_shape, uint8* output_data) {
3246 ruy::profiler::ScopeLabel label("AveragePool/8bit");
3247
3248 // Here, and in other pooling ops, in order to maintain locality of reference,
3249 // to minimize some recalculations, and to load into NEON vector registers, we
3250   // use an inner loop down the depth. Since depths can be large, and we would
3251   // otherwise need arbitrarily large temporary storage, we divide the work up
3252   // into depth tranches inside the batch loop.
3253 static constexpr int kPoolingAccTrancheSize = 256;
3254
3255 TFLITE_DCHECK_LE(params.quantized_activation_min,
3256 params.quantized_activation_max);
3257 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3258 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3259 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3260 const int depth = MatchingDim(input_shape, 3, output_shape, 3);
3261 const int input_height = input_shape.Dims(1);
3262 const int input_width = input_shape.Dims(2);
3263 const int output_height = output_shape.Dims(1);
3264 const int output_width = output_shape.Dims(2);
3265 const int stride_height = params.stride_height;
3266 const int stride_width = params.stride_width;
3267
3268 uint32 acc[kPoolingAccTrancheSize];
3269 for (int batch = 0; batch < batches; ++batch) {
3270 // We proceed through the depth in tranches (see comment above). The
3271 // depth_base is the depth at the beginning of the tranche. The
3272 // tranche_depth is the depth dimension of the tranche.
3273 for (int depth_base = 0; depth_base < depth;
3274 depth_base += kPoolingAccTrancheSize) {
3275 const int tranche_depth =
3276 std::min(depth - depth_base, kPoolingAccTrancheSize);
3277 for (int out_y = 0; out_y < output_height; ++out_y) {
3278 for (int out_x = 0; out_x < output_width; ++out_x) {
3279 const int in_x_origin =
3280 (out_x * stride_width) - params.padding_values.width;
3281 const int in_y_origin =
3282 (out_y * stride_height) - params.padding_values.height;
3283 const int filter_x_start = std::max(0, -in_x_origin);
3284 const int filter_x_end =
3285 std::min(params.filter_width, input_width - in_x_origin);
3286 const int filter_y_start = std::max(0, -in_y_origin);
3287 const int filter_y_end =
3288 std::min(params.filter_height, input_height - in_y_origin);
3289 const int filter_count =
3290 (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
3291 if (filter_count == 0) return false;
3292 memset(acc, 0, tranche_depth * sizeof(acc[0]));
3293 const uint8* input_ptr =
3294 input_data + depth_base +
3295 depth * (in_x_origin +
3296 input_width * (in_y_origin + input_height * batch));
3297 for (int fy = filter_y_start; fy < filter_y_end; fy++) {
3298 const uint8* input_row_ptr =
3299 input_ptr + depth * (fy * input_width + filter_x_start);
3300 for (int fx = filter_x_start; fx < filter_x_end; fx++) {
3301 const uint8* input_channel_ptr = input_row_ptr;
3302 int channel = 0;
3303 #ifdef USE_NEON
3304 for (; channel <= tranche_depth - 16; channel += 16) {
3305 uint16x4_t acc_reg[4];
3306 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
3307 input_channel_ptr += 16;
3308 acc_reg[0] = vget_low_u16(vmovl_u8(vget_low_u8(input_reg)));
3309 acc_reg[1] = vget_high_u16(vmovl_u8(vget_low_u8(input_reg)));
3310 acc_reg[2] = vget_low_u16(vmovl_u8(vget_high_u8(input_reg)));
3311 acc_reg[3] = vget_high_u16(vmovl_u8(vget_high_u8(input_reg)));
3312 for (int i = 0; i < 4; i++) {
3313 vst1q_u32(
3314 acc + channel + 4 * i,
3315 vaddw_u16(vld1q_u32(acc + channel + 4 * i), acc_reg[i]));
3316 }
3317 }
3318 for (; channel <= tranche_depth - 8; channel += 8) {
3319 uint16x4_t acc_reg[2];
3320 uint16x8_t input_reg = vmovl_u8(vld1_u8(input_channel_ptr));
3321 input_channel_ptr += 8;
3322 acc_reg[0] = vget_low_u16(input_reg);
3323 acc_reg[1] = vget_high_u16(input_reg);
3324 for (int i = 0; i < 2; i++) {
3325 vst1q_u32(
3326 acc + channel + 4 * i,
3327 vaddw_u16(vld1q_u32(acc + channel + 4 * i), acc_reg[i]));
3328 }
3329 }
3330 #endif
3331 for (; channel < tranche_depth; ++channel) {
3332 acc[channel] += *input_channel_ptr++;
3333 }
3334 input_row_ptr += depth;
3335 }
3336 }
3337 uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
3338 out_x, depth_base);
3339 int channel = 0;
3340 #ifdef USE_NEON
3341 #define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
3342 if (filter_count == FILTER_COUNT) { \
3343 for (; channel <= tranche_depth - 8; channel += 8) { \
3344 uint16 buf[8]; \
3345 for (int i = 0; i < 8; i++) { \
3346 buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
3347 } \
3348 uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
3349 buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \
3350 buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \
3351 vst1_u8(output_ptr + channel, buf8); \
3352 } \
3353 }
3354 AVGPOOL_DIVIDING_BY(9)
3355 AVGPOOL_DIVIDING_BY(15)
3356 #undef AVGPOOL_DIVIDING_BY
3357 for (; channel <= tranche_depth - 8; channel += 8) {
3358 uint16 buf[8];
3359 for (int i = 0; i < 8; i++) {
3360 buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
3361 }
3362 uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
3363 buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max));
3364 buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min));
3365 vst1_u8(output_ptr + channel, buf8);
3366 }
3367 #endif
3368 for (; channel < tranche_depth; ++channel) {
3369 uint16 a = (acc[channel] + filter_count / 2) / filter_count;
3370 a = std::max<uint16>(a, params.quantized_activation_min);
3371 a = std::min<uint16>(a, params.quantized_activation_max);
3372 output_ptr[channel] = static_cast<uint8>(a);
3373 }
3374 }
3375 }
3376 }
3377 }
3378 return true;
3379 }
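
// Arithmetic example for the uint8 averaging above (values are illustrative):
// with filter_count == 9 and an accumulated sum of 25, the rounded average is
// (25 + 9 / 2) / 9 == 3. The AVGPOOL_DIVIDING_BY macro specializes this
// division for common filter sizes so the compiler can reduce it to a
// multiplication by a constant. With depth == 600 and the tranche size of
// 256, the depth loop runs tranches of 256, 256 and 88 channels.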
3380
3381 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
3382 const float* input_data, const RuntimeShape& output_shape,
3383 float* output_data) {
3384 ruy::profiler::ScopeLabel label("MaxPool");
3385 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3386 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3387 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3388 const int input_height = input_shape.Dims(1);
3389 const int input_width = input_shape.Dims(2);
3390 const int output_height = output_shape.Dims(1);
3391 const int output_width = output_shape.Dims(2);
3392 const int stride_height = params.stride_height;
3393 const int stride_width = params.stride_width;
3394
3395 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3396 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3397 // Prefill the output to minimum representable float value
3398 out_mat.setConstant(std::numeric_limits<float>::lowest());
3399 for (int b = 0; b < batches; ++b) {
3400 for (int h = 0; h < input_height; ++h) {
3401 for (int w = 0; w < input_width; ++w) {
3402 // (h_start, h_end) * (w_start, w_end) is the range that the input
3403 // vector projects to.
3404 int hpad = h + params.padding_values.height;
3405 int wpad = w + params.padding_values.width;
3406 int h_start = (hpad < params.filter_height)
3407 ? 0
3408 : (hpad - params.filter_height) / stride_height + 1;
3409 int h_end = std::min(hpad / stride_height + 1, output_height);
3410 int w_start = (wpad < params.filter_width)
3411 ? 0
3412 : (wpad - params.filter_width) / stride_width + 1;
3413 int w_end = std::min(wpad / stride_width + 1, output_width);
3414 // compute elementwise sum
3415 for (int ph = h_start; ph < h_end; ++ph) {
3416 for (int pw = w_start; pw < w_end; ++pw) {
3417 int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
3418 out_mat.col(out_offset) =
3419 out_mat.col(out_offset)
3420 .cwiseMax(in_mat.col(
3421 NodeOffset(b, h, w, input_height, input_width)));
3422 }
3423 }
3424 }
3425 }
3426 }
3427 const int flat_size = output_shape.FlatSize();
3428 for (int i = 0; i < flat_size; ++i) {
3429 output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3430 params.float_activation_min,
3431 params.float_activation_max);
3432 }
3433 }
3434
3435 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
3436 const uint8* input_data, const RuntimeShape& output_shape,
3437 uint8* output_data) {
3438 ruy::profiler::ScopeLabel label("MaxPool/8bit");
3439
3440 // Here, and in other pooling ops, in order to maintain locality of reference,
3441 // to minimize some recalculations, and to load into NEON vector registers, we
3442   // use an inner loop down the depth. Since depths can be large, and we would
3443   // otherwise need arbitrarily large temporary storage, we divide the work up
3444   // into depth tranches inside the batch loop.
3445 static constexpr int kPoolingAccTrancheSize = 256;
3446
3447 TFLITE_DCHECK_LE(params.quantized_activation_min,
3448 params.quantized_activation_max);
3449 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3450 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3451 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3452 const int depth = MatchingDim(input_shape, 3, output_shape, 3);
3453 const int input_height = input_shape.Dims(1);
3454 const int input_width = input_shape.Dims(2);
3455 const int output_height = output_shape.Dims(1);
3456 const int output_width = output_shape.Dims(2);
3457 const int stride_height = params.stride_height;
3458 const int stride_width = params.stride_width;
3459
3460 uint8 acc[kPoolingAccTrancheSize];
3461 for (int batch = 0; batch < batches; ++batch) {
3462 // We proceed through the depth in tranches (see comment above). The
3463 // depth_base is the depth at the beginning of the tranche. The
3464 // tranche_depth is the depth dimension of the tranche.
3465 for (int depth_base = 0; depth_base < depth;
3466 depth_base += kPoolingAccTrancheSize) {
3467 const int tranche_depth =
3468 std::min(depth - depth_base, kPoolingAccTrancheSize);
3469 for (int out_y = 0; out_y < output_height; ++out_y) {
3470 for (int out_x = 0; out_x < output_width; ++out_x) {
3471 const int in_x_origin =
3472 (out_x * stride_width) - params.padding_values.width;
3473 const int in_y_origin =
3474 (out_y * stride_height) - params.padding_values.height;
3475 const int filter_x_start = std::max(0, -in_x_origin);
3476 const int filter_x_end =
3477 std::min(params.filter_width, input_width - in_x_origin);
3478 const int filter_y_start = std::max(0, -in_y_origin);
3479 const int filter_y_end =
3480 std::min(params.filter_height, input_height - in_y_origin);
3481 memset(acc, 0, tranche_depth * sizeof(acc[0]));
3482 const uint8* input_ptr =
3483 input_data + depth_base +
3484 depth * (in_x_origin +
3485 input_width * (in_y_origin + input_height * batch));
3486 for (int fy = filter_y_start; fy < filter_y_end; fy++) {
3487 const uint8* input_row_ptr =
3488 input_ptr + depth * (fy * input_width + filter_x_start);
3489 for (int fx = filter_x_start; fx < filter_x_end; fx++) {
3490 const uint8* input_channel_ptr = input_row_ptr;
3491 int channel = 0;
3492 #ifdef USE_NEON
3493 for (; channel <= tranche_depth - 16; channel += 16) {
3494 uint8x16_t acc_reg = vld1q_u8(acc + channel);
3495 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
3496 input_channel_ptr += 16;
3497 acc_reg = vmaxq_u8(acc_reg, input_reg);
3498 vst1q_u8(acc + channel, acc_reg);
3499 }
3500
3501 for (; channel <= tranche_depth - 8; channel += 8) {
3502 uint8x8_t acc_reg = vld1_u8(acc + channel);
3503 uint8x8_t input_reg = vld1_u8(input_channel_ptr);
3504 input_channel_ptr += 8;
3505 acc_reg = vmax_u8(acc_reg, input_reg);
3506 vst1_u8(acc + channel, acc_reg);
3507 }
3508 #endif
3509 for (; channel < tranche_depth; ++channel) {
3510 acc[channel] = std::max(acc[channel], *input_channel_ptr++);
3511 }
3512 input_row_ptr += depth;
3513 }
3514 }
3515 uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
3516 out_x, depth_base);
3517 int channel = 0;
3518 #ifdef USE_NEON
3519 for (; channel <= tranche_depth - 16; channel += 16) {
3520 uint8x16_t a = vld1q_u8(acc + channel);
3521 a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max));
3522 a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min));
3523 vst1q_u8(output_ptr + channel, a);
3524 }
3525 for (; channel <= tranche_depth - 8; channel += 8) {
3526 uint8x8_t a = vld1_u8(acc + channel);
3527 a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max));
3528 a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min));
3529 vst1_u8(output_ptr + channel, a);
3530 }
3531 #endif
3532 for (; channel < tranche_depth; ++channel) {
3533 uint8 a = acc[channel];
3534 a = std::max<uint8>(a, params.quantized_activation_min);
3535 a = std::min<uint8>(a, params.quantized_activation_max);
3536 output_ptr[channel] = static_cast<uint8>(a);
3537 }
3538 }
3539 }
3540 }
3541 }
3542 }
3543
3544 inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
3545 const float* input_data, const RuntimeShape& output_shape,
3546 float* output_data) {
3547 ruy::profiler::ScopeLabel label("L2Pool");
3548 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3549 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3550 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3551 const int input_height = input_shape.Dims(1);
3552 const int input_width = input_shape.Dims(2);
3553 const int output_height = output_shape.Dims(1);
3554 const int output_width = output_shape.Dims(2);
3555 const int stride_height = params.stride_height;
3556 const int stride_width = params.stride_width;
3557 // Actually carry out L2 Pool. Code is written in forward mode: we go through
3558 // the input values once, and write to all the pooled regions they map to.
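// For example, a pooled region covering the inputs {1, 2, 3, 4} produces
// sqrt((1 + 4 + 9 + 16) / 4) ~= 2.74.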
3559 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3560 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3561 Eigen::VectorXf in_square(in_mat.rows());
3562 Eigen::VectorXf out_count(out_mat.cols());
3563 out_count.setZero();
3564 // Prefill the output to 0.
3565 out_mat.setZero();
3566 for (int b = 0; b < batches; ++b) {
3567 for (int h = 0; h < input_height; ++h) {
3568 for (int w = 0; w < input_width; ++w) {
3569 // (h_start, h_end) * (w_start, w_end) is the range that the input
3570 // vector projects to.
3571 const int hpad = h + params.padding_values.height;
3572 const int wpad = w + params.padding_values.width;
3573 const int h_start =
3574 (hpad < params.filter_height)
3575 ? 0
3576 : (hpad - params.filter_height) / stride_height + 1;
3577 const int h_end = std::min(hpad / stride_height + 1, output_height);
3578 const int w_start =
3579 (wpad < params.filter_width)
3580 ? 0
3581 : (wpad - params.filter_width) / stride_width + 1;
3582 const int w_end = std::min(wpad / stride_width + 1, output_width);
3583 // pre-compute square
3584 const int in_offset = w + input_width * (h + input_height * b);
3585 in_square =
3586 in_mat.col(in_offset).array() * in_mat.col(in_offset).array();
3587 // compute elementwise sum of squares
3588 for (int ph = h_start; ph < h_end; ++ph) {
3589 for (int pw = w_start; pw < w_end; ++pw) {
3590 const int out_offset = pw + output_width * (ph + output_height * b);
3591 out_mat.col(out_offset) += in_square;
3592 out_count(out_offset)++;
3593 }
3594 }
3595 }
3596 }
3597 }
3598
3599 out_count = out_count.array().inverse();
3600 out_mat =
3601 (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
3602
3603 const int flat_size = output_shape.FlatSize();
3604 for (int i = 0; i < flat_size; ++i) {
3605 output_data[i] = ActivationFunctionWithMinMax(output_data[i],
3606 params.float_activation_min,
3607 params.float_activation_max);
3608 }
3609 }
3610
3611 inline void LocalResponseNormalization(
3612 const tflite::LocalResponseNormalizationParams& op_params,
3613 const RuntimeShape& input_shape, const float* input_data,
3614 const RuntimeShape& output_shape, float* output_data) {
3615 ruy::profiler::ScopeLabel label("LocalResponseNormalization");
3616 MatchingFlatSize(input_shape, output_shape);
3617
3618 const auto data_in = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3619 auto data_out = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3620
3621 // Carry out local response normalization, vector by vector.
3622 // Since the data are stored column-major, row-wise operations would likely
3623 // not be memory efficient anyway, so we do an explicit for loop over
3624 // the columns.
3625 const int double_range = op_params.range * 2;
3626 Eigen::VectorXf padded_square(data_in.rows() + double_range);
3627 padded_square.setZero();
3628 const float bias = op_params.bias;
3629 for (int r = 0; r < data_in.cols(); ++r) {
3630 // Do local response normalization for data_in(:, r)
3631 // First, compute the squares and store them in a buffer for repeated use.
3632 padded_square.block(op_params.range, 0, data_in.rows(), 1) =
3633 data_in.col(r).cwiseProduct(data_in.col(r)) * op_params.alpha;
3634 // Then, compute the scale and write it to data_out.
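// The loop below keeps a running sum over a window of 2 * range + 1 squared
// values centered at i (zero-padded at the borders): each step adds the
// entry entering the window and subtracts the one leaving it. For example,
// with range = 2, data_out(i, r) = bias + sum of padded_square(i .. i + 4).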
3635 float accumulated_scale = 0;
3636 for (int i = 0; i < double_range; ++i) {
3637 accumulated_scale += padded_square(i);
3638 }
3639 for (int i = 0; i < data_in.rows(); ++i) {
3640 accumulated_scale += padded_square(i + double_range);
3641 data_out(i, r) = bias + accumulated_scale;
3642 accumulated_scale -= padded_square(i);
3643 }
3644 }
3645
3646 // In a few cases, the pow computation could benefit from speedups.
3647 if (op_params.beta == 1) {
3648 data_out.array() = data_in.array() * data_out.array().inverse();
3649 } else if (op_params.beta == 0.5f) {
3650 data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
3651 } else {
3652 data_out.array() = data_in.array() * data_out.array().pow(-op_params.beta);
3653 }
3654 }
3655
3656 inline void SoftmaxImpl(const SoftmaxParams& params,
3657 const RuntimeShape& input_shape,
3658 const float* input_data,
3659 const RuntimeShape& output_shape, float* output_data,
3660 int start_batch, int end_batch) {
3661 ruy::profiler::ScopeLabel label("Softmax/Impl");
3662 MatchingFlatSize(input_shape, output_shape);
3663
3664 const int logit_size = input_shape.Dims(input_shape.DimensionsCount() - 1);
3665 const MatrixMap<const float> in_mat(input_data + logit_size * start_batch,
3666 logit_size, end_batch - start_batch);
3667 MatrixMap<float> out_mat(output_data + logit_size * start_batch, logit_size,
3668 end_batch - start_batch);
3669 // Compute the exponential first, removing the max coefficient for numerical
3670 // stability.
3671 out_mat =
3672 (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
3673 // We are separating out the exp function so that exp can be vectorized.
3674 out_mat = out_mat.array().exp();
3675 // Normalize to get the activations.
3676 Eigen::Array<float, 1, Eigen::Dynamic> scale =
3677 out_mat.array().colwise().sum().inverse();
3678 out_mat.array().rowwise() *= scale;
3679 }
3680
3681 struct SoftmaxWorkerTask : cpu_backend_threadpool::Task {
3682 SoftmaxWorkerTask(const SoftmaxParams& params,
3683 const RuntimeShape& input_shape, const float* input_data,
3684 const RuntimeShape& output_shape, float* output_data,
3685 int start_batch, int end_batch)
3686 : params(params),
3687 input_shape(input_shape),
3688 input_data(input_data),
3689 output_shape(output_shape),
3690 output_data(output_data),
3691 start_batch(start_batch),
3692 end_batch(end_batch) {}
3693 void Run() override {
3694 SoftmaxImpl(params, input_shape, input_data, output_shape, output_data,
3695 start_batch, end_batch);
3696 }
3697
3698 private:
3699 const tflite::SoftmaxParams& params;
3700 const RuntimeShape& input_shape;
3701 const float* input_data;
3702 const RuntimeShape& output_shape;
3703 float* output_data;
3704 int start_batch;
3705 int end_batch;
3706 };
3707
3708 inline void Softmax(const SoftmaxParams& params,
3709 const RuntimeShape& input_shape, const float* input_data,
3710 const RuntimeShape& output_shape, float* output_data,
3711 CpuBackendContext* cpu_backend_context = nullptr) {
3712 ruy::profiler::ScopeLabel label("Softmax");
3713
3714 // We picture the softmax input as a 2-D matrix: the last dim is the logit
3715 // dim, and the remaining dims are flattened into the batch dim of that matrix.
3716 const int batch_size =
3717 FlatSizeSkipDim(input_shape, input_shape.DimensionsCount() - 1);
3718 constexpr int kMinBatchPerThread = 8;
3719 int thread_count = batch_size / kMinBatchPerThread;
3720 thread_count = thread_count > 0 ? thread_count : 1;
3721 const int capped_thread_count =
3722 cpu_backend_context == nullptr
3723 ? 1
3724 : std::min(thread_count, cpu_backend_context->max_num_threads());
3725 if (capped_thread_count == 1) {
3726 SoftmaxImpl(params, input_shape, input_data, output_shape, output_data, 0,
3727 batch_size);
3728 } else {
3729 std::vector<SoftmaxWorkerTask> tasks;
3730 // TODO(b/131746020) don't create new heap allocations every time.
3731 // At least we make it a single heap allocation by using reserve().
3732 tasks.reserve(capped_thread_count);
3733 int batch_start = 0;
3734 for (int i = 0; i < capped_thread_count; ++i) {
3735 // Try to distribute the tasks as evenly as possible.
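// For example, with batch_size = 21 and capped_thread_count = 2, the first
// task gets batches [0, 10) and the second gets [10, 21).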
3736 int batch_end =
3737 batch_start + (batch_size - batch_start) / (capped_thread_count - i);
3738 tasks.emplace_back(params, input_shape, input_data, output_shape,
3739 output_data, batch_start, batch_end);
3740 batch_start = batch_end;
3741 }
3742 cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
3743 cpu_backend_context);
3744 }
3745 }
3746
3747 template <typename T>
3748 inline int32_t QuantizeSoftmaxOutput(float prob_rescaled, int32_t zero_point) {
3749 const int32_t prob_rnd = static_cast<int32_t>(std::round(prob_rescaled));
3750 return prob_rnd + zero_point;
3751 }
3752
3753 #if !__aarch64__
3754 // On ARM64, rounding is faster than add + truncation, so this specialization is not needed there.
3755 template <>
3756 inline int32_t QuantizeSoftmaxOutput<uint8_t>(float prob_rescaled,
3757 int32_t zero_point) {
3758 return static_cast<int32_t>(prob_rescaled + 0.5f);
3759 }
3760 #endif
3761
3762 inline void PopulateSoftmaxLookupTable(SoftmaxParams* data, float input_scale,
3763 float beta) {
3764 const float scale = -input_scale * beta;
3765 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3766 for (int32_t val = 0; val <= max_uint8; ++val) {
3767 data->table[max_uint8 - val] = expf(scale * val);
3768 }
3769 }
3770
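// Uses the table built by PopulateSoftmaxLookupTable above. Since
// table[255 - v] = exp(-input_scale * beta * v), indexing the table at
// (255 - max_val) + input_q yields exp(input_scale * beta * (input_q - max_val)),
// i.e. the exponential already shifted by the row maximum for numerical
// stability.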
3771 template <typename In, typename Out>
3772 inline void Softmax(const SoftmaxParams& params,
3773 const RuntimeShape& input_shape, const In* input_data,
3774 const RuntimeShape& output_shape, Out* output_data) {
3775 const int trailing_dim = input_shape.DimensionsCount() - 1;
3776 const int excluding_last_dim =
3777 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3778 const int last_dim =
3779 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3780
3781 const int32_t clamp_max = std::numeric_limits<Out>::max();
3782 const int32_t clamp_min = std::numeric_limits<Out>::min();
3783 for (int i = 0; i < excluding_last_dim; ++i) {
3784 int32_t max_val = std::numeric_limits<In>::min();
3785 // Find max quantized value.
3786 for (int j = 0; j < last_dim; ++j) {
3787 max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
3788 }
3789
3790 float sum_exp = 0.0f;
3791 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3792 const float* table_offset = &params.table[max_uint8 - max_val];
3793 // Calculate normalizer sum(exp(x)).
3794 for (int j = 0; j < last_dim; ++j) {
3795 sum_exp += table_offset[input_data[j]];
3796 }
3797
3798 const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
3799 // Normalize and quantize probabilities.
3800 for (int j = 0; j < last_dim; ++j) {
3801 const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
3802 const int32_t prob_quantized =
3803 QuantizeSoftmaxOutput<Out>(prob_rescaled, params.zero_point);
3804 output_data[j] = static_cast<Out>(
3805 std::max(std::min(clamp_max, prob_quantized), clamp_min));
3806 }
3807 input_data += last_dim;
3808 output_data += last_dim;
3809 }
3810 }
3811
3812 // Here's the softmax LUT optimization strategy:
3813 // For softmax, we can do some mathematically equivalent transformations:
3814 //
3815 // softmax(x) = e^x / sum(e^x, 0...n), which equals
3816 // softmax(x) = e^(x - CONST) / sum(e^(x - CONST), 0...n)
3817 //
3818 // For quantization, `x` in our case is (input_q - input_zp) * input_s
3819 // For uint8 case (int8 can be handled similarly), the range is [0, 255]
3820 //
3821 // so if we let
3822 // CONST = (255 - input_zp) * input_s
3823 // then we will have:
3824 // softmax(x) = e^((input_q - 255) * input_s) --------- (1)
3825 // /
3826 // sum(e^(input_q - 255) * input_s, 0...n) -------- (2)
3827 //
3828 // the good thing about (1) is it's within the range of (0, 1), so we can
3829 // approximate its result with uint16.
3830 // (1) = uint8_out * 1 / 2^16.
3831 //
3832 // so (1) is lookup_uint8_table(input_q) * 1 / 2^16.
3833 // then (2) is essentially the following:
3834 // sum(lookup_uint8_table(input_q), 0...n) / 2^16.
3835 //
3836 // since (output_q - output_zp) * output_s = softmax(x)
3837 // output_q = lookup_uint8_table(input_q)
3838 // /
3839 // (sum(lookup_uint8_table(input_q), 0...n) * output_s)
3840 // +
3841 // output_zp
3842 //
3843 // We could further improve the performance by using uint8 instead of
3844 // uint16, but then we may lose some accuracy, so we need to pay attention
3845 // to that.
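// For example, when val == 255 (i.e. the input equals the row max),
// exp(scale * (val - 255)) == 1.0, so temp == 65535 and the two tables store
// 0xFF / 0xFF; reassembled as (part1 << 8) + part2 this is 65535 ~= e^0 * 2^16.
// An entry whose exponential is 0.5 is stored as 0x80 / 0x00 (i.e. 32768).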
3846 inline void PopulateSoftmaxUInt8LookupTable(SoftmaxParams* data,
3847 float input_scale, float beta) {
3848 const float scale = input_scale * beta;
3849 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3850 const int32_t max_uint16 = std::numeric_limits<uint16_t>::max();
3851
3852 for (int32_t val = 0; val <= max_uint8; ++val) {
3853 float input_to_exp = scale * (val - max_uint8);
3854 int32_t temp = static_cast<int>(expf(input_to_exp) * max_uint16 + 0.5);
3855 temp = std::min(max_uint16, temp);
3856 uint8_t part1 = temp >> 8;
3857 uint8_t part2 = temp & 0xff;
3858 data->uint8_table1[val] = static_cast<uint8_t>(part1);
3859 data->uint8_table2[val] = static_cast<uint8_t>(part2);
3860 }
3861 }
3862
3863 inline int FindMaxValue(int size, const uint8_t* input_data, uint8_t offset) {
3864 int32_t max_val = std::numeric_limits<uint8_t>::min();
3865 int j = 0;
3866 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3867 uint8x16_t max_val_dup = vdupq_n_u8(max_val);
3868 uint8x16_t offset_dup = vdupq_n_u8(offset);
3869 for (; j <= size - 16; j += 16) {
3870 uint8x16_t input_value = vld1q_u8(input_data + j);
3871 input_value = veorq_u8(input_value, offset_dup);
3872 max_val_dup = vmaxq_u8(input_value, max_val_dup);
3873 }
3874 max_val = std::max(max_val, static_cast<int32>(vmaxvq_u8(max_val_dup)));
3875 #endif
3876
3877 for (; j < size; ++j) {
3878 max_val = std::max(max_val, static_cast<int32_t>(input_data[j] ^ offset));
3879 }
3880 return max_val;
3881 }
3882
3883 #ifdef USE_NEON
3884 // Value_to_store layout:
3885 // [high_high, high_low, low_high, low_low].
3886 inline void StoreValue(int32x4x4_t value_to_store, int8_t* output) {
3887 const int16x8_t result_1 = vcombine_s16(vqmovn_s32(value_to_store.val[1]),
3888 vqmovn_s32(value_to_store.val[0]));
3889 const int16x8_t result_2 = vcombine_s16(vqmovn_s32(value_to_store.val[3]),
3890 vqmovn_s32(value_to_store.val[2]));
3891 const int8x16_t result =
3892 vcombine_s8(vqmovn_s16(result_2), vqmovn_s16(result_1));
3893 vst1q_s8(output, result);
3894 }
3895
3896 // Value_to_store layout:
3897 // [high_high, high_low, low_high, low_low].
3898 inline void StoreValue(int32x4x4_t value_to_store, uint8_t* output) {
3899 const uint16x8_t result_1 =
3900 vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[1])),
3901 vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[0])));
3902 const uint16x8_t result_2 =
3903 vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[3])),
3904 vqmovn_u32(vreinterpretq_u32_s32(value_to_store.val[2])));
3905 const uint8x16_t result =
3906 vcombine_u8(vqmovn_u16(result_2), vqmovn_u16(result_1));
3907 vst1q_u8(output, result);
3908 }
3909
3910 #endif
3911
3912 template <typename In, typename Out>
3913 inline void SoftmaxInt8LUT(const SoftmaxParams& params,
3914 const RuntimeShape& input_shape,
3915 const In* input_data,
3916 const RuntimeShape& output_shape, Out* output_data) {
3917 const int trailing_dim = input_shape.DimensionsCount() - 1;
3918 const int excluding_last_dim =
3919 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
3920 const int last_dim =
3921 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
3922
3923 const int32_t clamp_max = std::numeric_limits<Out>::max();
3924 const int32_t clamp_min = std::numeric_limits<Out>::min();
3925
3926 // Offset is used to interpret the input data "correctly".
3927 // If the input is uint8, the data will be unchanged.
3928 // If the input is int8, it is reinterpreted as uint8, so the offset must
3929 // be applied; e.g.,
3930 // int8 127 becomes 255 in uint8 after the offset is applied.
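// The XOR with 0x80 maps int8 values to uint8 while preserving their order:
// int8 -128 (bits 0x80) becomes 0x00 and int8 127 (bits 0x7F) becomes 0xFF.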
3931 uint8_t offset = 0;
3932 if (std::is_same<In, int8>::value) {
3933 offset = 0x80;
3934 }
3935
3936 const uint8_t* input_data_uint = reinterpret_cast<const uint8_t*>(input_data);
3937
3938 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3939 // This code uses ARM64-only instructions.
3940 // TODO(b/143709993): Port to ARMv7
3941
3942 // Load the tables into registers. (4*4 128-bit registers)
3943 uint8x16x4_t table1[4];
3944 table1[0] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 0);
3945 table1[1] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 1);
3946 table1[2] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 2);
3947 table1[3] = vld1q_u8_x4(params.uint8_table1 + 16 * 4 * 3);
3948
3949 uint8x16x4_t table2[4];
3950 table2[0] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 0);
3951 table2[1] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 1);
3952 table2[2] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 2);
3953 table2[3] = vld1q_u8_x4(params.uint8_table2 + 16 * 4 * 3);
3954 #endif
3955
3956 for (int i = 0; i < excluding_last_dim; ++i) {
3957 // Find max quantized value.
3958 int32_t max_val = FindMaxValue(last_dim, input_data_uint, offset);
3959
3960 int32 sum_exp = 0;
3961 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
3962 const uint8_t table_offset = max_uint8 - max_val;
3963
3964 // Calculate normalizer sum(exp(x)).
3965 int sum_j = 0;
3966 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
3967 uint8x16_t table_offset_dup = vdupq_n_u8(table_offset);
3968 uint8x16_t offset_dup = vdupq_n_u8(offset);
3969 uint32x4_t sum_4 = vdupq_n_u32(0);
3970 const int multiplier_shift = 8;
3971 for (; sum_j <= last_dim - 16; sum_j += 16) {
3972 uint8x16_t input_value = vld1q_u8(input_data_uint + sum_j);
3973 input_value = veorq_u8(input_value, offset_dup);
3974 input_value = vaddq_u8(input_value, table_offset_dup);
3975
3976 const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
3977 const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
3978
3979 uint16x8_t exp_value1 =
3980 vshll_n_u8(vget_high_u8(output1), multiplier_shift);
3981 uint16x8_t exp_value2 =
3982 vshll_n_u8(vget_low_u8(output1), multiplier_shift);
3983
3984 exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
3985 exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
3986
3987 sum_4 = vpadalq_u16(sum_4, exp_value1);
3988 sum_4 = vpadalq_u16(sum_4, exp_value2);
3989 }
3990 int temp = vgetq_lane_u32(sum_4, 0) + vgetq_lane_u32(sum_4, 1) +
3991 vgetq_lane_u32(sum_4, 2) + vgetq_lane_u32(sum_4, 3);
3992 sum_exp += temp;
3993
3994 #endif
3995 for (; sum_j < last_dim; ++sum_j) {
3996 const uint8_t index = (input_data_uint[sum_j] ^ offset) + table_offset;
3997
3998 uint8_t part1 = params.uint8_table1[index];
3999 uint8_t part2 = params.uint8_table2[index];
4000 sum_exp += ((part1 << 8) + part2);
4001 }
4002
4003 const float inv_sum_exp = 1.0f / (sum_exp * params.scale);
4004
4005 int32 multiplier, shift;
4006 QuantizeMultiplier(inv_sum_exp, &multiplier, &shift);
4007
4008 // Normalize and quantize probabilities.
4009 int j = 0;
4010 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
4011 const int32x4_t output_zp_dup = vdupq_n_s32(params.zero_point);
4012 const int32x4_t max_val_dup = vdupq_n_s32(clamp_max);
4013 const int32x4_t min_val_dup = vdupq_n_s32(clamp_min);
4014
4015 for (; j <= last_dim - 16; j += 16) {
4016 uint8x16_t input_value = vld1q_u8(input_data_uint + j);
4017 input_value = veorq_u8(input_value, offset_dup);
4018 input_value = vaddq_u8(input_value, table_offset_dup);
4019
4020 const uint8x16_t output1 = aarch64_lookup_vector(table1, input_value);
4021 const uint8x16_t output2 = aarch64_lookup_vector(table2, input_value);
4022
4023 uint16x8_t exp_value1 =
4024 vshll_n_u8(vget_high_u8(output1), multiplier_shift);
4025 uint16x8_t exp_value2 =
4026 vshll_n_u8(vget_low_u8(output1), multiplier_shift);
4027
4028 exp_value1 = vaddw_u8(exp_value1, vget_high_u8(output2));
4029 exp_value2 = vaddw_u8(exp_value2, vget_low_u8(output2));
4030
4031 int32x4x4_t output_value;
4032 output_value.val[0] =
4033 vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value1)));
4034 output_value.val[1] =
4035 vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value1)));
4036 output_value.val[2] =
4037 vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(exp_value2)));
4038 output_value.val[3] =
4039 vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(exp_value2)));
4040
4041 int32x4x4_t temp_val =
4042 MultiplyByQuantizedMultiplier4Rows(output_value, multiplier, shift);
4043
4044 temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
4045 temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
4046 temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
4047 temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
4048
4049 temp_val.val[0] =
4050 vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
4051 temp_val.val[1] =
4052 vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
4053 temp_val.val[2] =
4054 vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
4055 temp_val.val[3] =
4056 vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
4057
4058 StoreValue(temp_val, output_data + j);
4059 }
4060 #endif
4061 for (; j < last_dim; ++j) {
4062 const uint8_t index = (input_data_uint[j] ^ offset) + table_offset;
4063 const uint8_t part1 = params.uint8_table1[index];
4064 const uint8_t part2 = params.uint8_table2[index];
4065 const int32_t exp_value = (part1 << 8) + part2;
4066 const int32_t output_value =
4067 MultiplyByQuantizedMultiplier(exp_value, multiplier, shift);
4068
4069 output_data[j] = static_cast<Out>(std::max(
4070 std::min(clamp_max, output_value + params.zero_point), clamp_min));
4071 }
4072 input_data_uint += last_dim;
4073 output_data += last_dim;
4074 }
4075 }
4076
4077 inline void LogSoftmax(const SoftmaxParams& params,
4078 const RuntimeShape& input_shape, const float* input_data,
4079 const RuntimeShape& output_shape, float* output_data) {
4080 ruy::profiler::ScopeLabel label("LogSoftmax");
4081 const int trailing_dim = input_shape.DimensionsCount() - 1;
4082 const int outer_size =
4083 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4084 const int depth =
4085 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4086
4087 for (int i = 0; i < outer_size; ++i) {
4088 VectorMap<const float> block_input(input_data + i * depth, depth, 1);
4089 VectorMap<float> block_output(output_data + i * depth, depth, 1);
4090 // Find max element value which we'll use to ensure numerical stability
4091 // taking advantage of the following equality:
4092 // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
4093 const float max = block_input.maxCoeff();
4094 const float log_sum = std::log((block_input.array() - max).exp().sum());
4095 block_output = block_input.array() - max - log_sum;
4096 }
4097 }
4098
4099 // Backwards compatibility. Less optimized than below version.
4100 inline void LogSoftmax(const SoftmaxParams& params,
4101 const RuntimeShape& input_shape, const uint8* input_data,
4102 const RuntimeShape& output_shape, uint8* output_data) {
4103 reference_ops::LogSoftmax(params, input_shape, input_data, output_shape,
4104 output_data);
4105 }
4106
4107 // Compute LogSoftmax as (x - x_max) - ln(sum(e^(x_i - x_max)...))
4108 // as done in tf.nn.log_softmax to prevent underflow and overflow.
4109 // This is in contrast to just log(softmax(x)).
4110 //
4111 // To handle quantization, first dequantize the inputs (by computing
4112 // e^(input scale * val), where we ignore the zero point since it cancels
4113 // out during subtraction due to the ln) and rescale to int8 at the end.
4114 //
4115 // Notably this makes use of float and is intended as the optimized
4116 // form for quantized execution on CPU. For a fully integer version,
4117 // see the reference op.
4118 //
4119 // TODO(tflite): notes for optimization:
4120 // 1) See if e^ is also bottleneck in the reference fully-integer
4121 // version and apply lookup there and compare.
4122 template <typename T>
4123 inline void LogSoftmax(const SoftmaxParams& params, float input_scale,
4124 const RuntimeShape& input_shape, const T* input_data,
4125 const RuntimeShape& output_shape, T* output_data) {
4126 ruy::profiler::ScopeLabel label("LogSoftmax");
4127 const int trailing_dim = input_shape.DimensionsCount() - 1;
4128 const int excluding_last_dim =
4129 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4130 const int last_dim =
4131 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4132
4133 const int32_t clamp_max = std::numeric_limits<T>::max();
4134 const int32_t clamp_min = std::numeric_limits<T>::min();
4135
4136 for (int i = 0; i < excluding_last_dim; ++i) {
4137 T max_val = std::numeric_limits<T>::min();
4138 // Find max quantized value.
4139 for (int j = 0; j < last_dim; ++j) {
4140 max_val = std::max(max_val, input_data[j]);
4141 }
4142
4143 float sum_exp = 0.0f;
4144 const int32_t max_uint8 = std::numeric_limits<uint8>::max();
4145 // Offset into table to compute exp(scale*(x - xmax)) instead of
4146 // exp(scale*(x)) to prevent overflow.
4147 const float* table_offset = &params.table[max_uint8 - max_val];
4148 // Calculate sum(exp(scale*(x - x_max))).
4149 for (int j = 0; j < last_dim; ++j) {
4150 sum_exp += table_offset[input_data[j]];
4151 }
4152 const float log_sum_exp = std::log(sum_exp);
4153
4154 // params.scale is the output scale.
4155 const float scale = input_scale / params.scale;
4156 const float precomputed =
4157 (input_scale * max_val + log_sum_exp) / params.scale;
4158 for (int j = 0; j < last_dim; ++j) {
4159 // Equivalent to (input_scale * (input_data[j] - max_val) - log_sum_exp) /
4160 // output_scale.
4161 const float log_prob = scale * input_data[j] - precomputed;
4162
4163 // TODO(tflite): look into better solution.
4164 // Use std::rint over std::round (which is used in
4165 // FakeQuant) since it's multiple times faster on tested arm32.
4166 const int32_t prob_quantized = std::rint(log_prob) + params.zero_point;
4167 output_data[j] = static_cast<T>(
4168 std::max(std::min(clamp_max, prob_quantized), clamp_min));
4169 }
4170 input_data += last_dim;
4171 output_data += last_dim;
4172 }
4173 }
4174
4175 inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
4176 const RuntimeShape& output_shape, float* output_data) {
4177 ruy::profiler::ScopeLabel label("Logistic");
4178 auto input_map = MapAsVector(input_data, input_shape);
4179 auto output_map = MapAsVector(output_data, output_shape);
4180 output_map.array() =
4181 input_map.array().unaryExpr(Eigen::internal::scalar_logistic_op<float>());
4182 }
4183
4184 // Convenience version that allows, for example, generated-code calls to be
4185 // uniform between data types.
4186 inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
4187 const float* input_data, const RuntimeShape& output_shape,
4188 float* output_data) {
4189 // Drop params: not needed.
4190 Logistic(input_shape, input_data, output_shape, output_data);
4191 }
4192
4193 inline void Logistic(const LogisticParams& params,
4194 const RuntimeShape& input_shape, const int16* input_data,
4195 const RuntimeShape& output_shape, int16* output_data) {
4196 ruy::profiler::ScopeLabel label("Logistic/Int16");
4197 const int flat_size = MatchingFlatSize(input_shape, output_shape);
4198
4202 int c = 0;
4203 const int16* input_data_ptr = input_data;
4204 int16* output_data_ptr = output_data;
4205 #ifdef GEMMLOWP_NEON
4206 {
4207 // F0 uses 0 integer bits, range [-1, 1].
4208 // This is the return type of math functions such as tanh, logistic,
4209 // whose range is in [-1, 1].
4210 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
4211 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4212 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
4213
4214 for (; c <= flat_size - 16; c += 16) {
4215 F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
4216 F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
4217 F0 output0 = gemmlowp::logistic(input0);
4218 F0 output1 = gemmlowp::logistic(input1);
4219 vst1q_s16(output_data_ptr, output0.raw());
4220 vst1q_s16(output_data_ptr + 8, output1.raw());
4221
4222 input_data_ptr += 16;
4223 output_data_ptr += 16;
4224 }
4225 for (; c <= flat_size - 8; c += 8) {
4226 F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
4227 F0 output = gemmlowp::logistic(input);
4228 vst1q_s16(output_data_ptr, output.raw());
4229
4230 input_data_ptr += 8;
4231 output_data_ptr += 8;
4232 }
4233 }
4234 #endif
4235 #ifdef GEMMLOWP_SSE4
4236 {
4237 // F0 uses 0 integer bits, range [-1, 1].
4238 // This is the return type of math functions such as tanh, logistic,
4239 // whose range is in [-1, 1].
4240 using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
4241 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4242 using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
4243
4244 for (; c <= flat_size - 16; c += 16) {
4245 F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4246 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4247 F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4248 reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
4249 F0 output0 = gemmlowp::logistic(input0);
4250 F0 output1 = gemmlowp::logistic(input1);
4251 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4252 output0.raw().v);
4253 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4254 output1.raw().v);
4255 input_data_ptr += 16;
4256 output_data_ptr += 16;
4257 }
4258 for (; c <= flat_size - 8; c += 8) {
4259 F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4260 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4261 F0 output = gemmlowp::logistic(input);
4262 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4263 output.raw().v);
4264 input_data_ptr += 8;
4265 output_data_ptr += 8;
4266 }
4267 }
4268 #endif
4269
4270 {
4271 // F0 uses 0 integer bits, range [-1, 1].
4272 // This is the return type of math functions such as tanh, logistic,
4273 // whose range is in [-1, 1].
4274 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
4275 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4276 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
4277
4278 for (; c < flat_size; ++c) {
4279 F3 input = F3::FromRaw(*input_data_ptr);
4280 F0 output = gemmlowp::logistic(input);
4281 *output_data_ptr = output.raw();
4282
4283 ++input_data_ptr;
4284 ++output_data_ptr;
4285 }
4286 }
4287 }
4288
4289 inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
4290 const RuntimeShape& output_shape, float* output_data) {
4291 ruy::profiler::ScopeLabel label("Tanh");
4292 auto input_map = MapAsVector(input_data, input_shape);
4293 auto output_map = MapAsVector(output_data, output_shape);
4294 output_map.array() = input_map.array().tanh();
4295 }
4296
4297 // Convenience version that allows, for example, generated-code calls to be
4298 // uniform between data types.
4299 inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
4300 const float* input_data, const RuntimeShape& output_shape,
4301 float* output_data) {
4302 // Drop params: not needed.
4303 Tanh(input_shape, input_data, output_shape, output_data);
4304 }
4305
4306 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
4307 const int16* input_data, const RuntimeShape& output_shape,
4308 int16* output_data) {
4309 ruy::profiler::ScopeLabel label("Tanh/Int16");
4310 const int input_left_shift = params.input_left_shift;
4311 // Support for shifts is limited until we have a parameterized version of
4312 // SaturatingRoundingMultiplyByPOT().
4313 TFLITE_DCHECK_GE(input_left_shift, 0);
4314 TFLITE_DCHECK_LE(input_left_shift, 1);
4315
4316 const int flat_size = MatchingFlatSize(input_shape, output_shape);
4317
4318 int c = 0;
4319 const int16* input_data_ptr = input_data;
4320 int16* output_data_ptr = output_data;
4321 #ifdef GEMMLOWP_NEON
4322 {
4323 // F0 uses 0 integer bits, range [-1, 1].
4324 // This is the return type of math functions such as tanh, logistic,
4325 // whose range is in [-1, 1].
4326 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
4327 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4328 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
4329
4330 if (input_left_shift == 0) {
4331 for (; c <= flat_size - 16; c += 16) {
4332 F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
4333 F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
4334 F0 output0 = gemmlowp::tanh(input0);
4335 F0 output1 = gemmlowp::tanh(input1);
4336 vst1q_s16(output_data_ptr, output0.raw());
4337 vst1q_s16(output_data_ptr + 8, output1.raw());
4338
4339 input_data_ptr += 16;
4340 output_data_ptr += 16;
4341 }
4342 for (; c <= flat_size - 8; c += 8) {
4343 F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
4344 F0 output = gemmlowp::tanh(input);
4345 vst1q_s16(output_data_ptr, output.raw());
4346
4347 input_data_ptr += 8;
4348 output_data_ptr += 8;
4349 }
4350 } else {
4351 for (; c <= flat_size - 16; c += 16) {
4352 F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4353 vld1q_s16(input_data_ptr)));
4354 F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4355 vld1q_s16(input_data_ptr + 8)));
4356 F0 output0 = gemmlowp::tanh(input0);
4357 F0 output1 = gemmlowp::tanh(input1);
4358 vst1q_s16(output_data_ptr, output0.raw());
4359 vst1q_s16(output_data_ptr + 8, output1.raw());
4360
4361 input_data_ptr += 16;
4362 output_data_ptr += 16;
4363 }
4364 for (; c <= flat_size - 8; c += 8) {
4365 F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4366 vld1q_s16(input_data_ptr)));
4367 F0 output = gemmlowp::tanh(input);
4368 vst1q_s16(output_data_ptr, output.raw());
4369
4370 input_data_ptr += 8;
4371 output_data_ptr += 8;
4372 }
4373 }
4374 }
4375 #endif
4376 #ifdef GEMMLOWP_SSE4
4377 {
4378 // F0 uses 0 integer bits, range [-1, 1].
4379 // This is the return type of math functions such as tanh, logistic,
4380 // whose range is in [-1, 1].
4381 using F0 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 0>;
4382 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4383 using F3 = gemmlowp::FixedPoint<gemmlowp::int16x8_m128i, 3>;
4384
4385 if (input_left_shift == 0) {
4386 for (; c <= flat_size - 16; c += 16) {
4387 F3 input0 = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4388 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4389 F3 input1 = F3::FromRaw(gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4390 reinterpret_cast<const __m128i*>(input_data_ptr + 8))));
4391 F0 output0 = gemmlowp::tanh(input0);
4392 F0 output1 = gemmlowp::tanh(input1);
4393 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4394 output0.raw().v);
4395 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4396 output1.raw().v);
4397
4398 input_data_ptr += 16;
4399 output_data_ptr += 16;
4400 }
4401 for (; c <= flat_size - 8; c += 8) {
4402 F3 input = F3::FromRaw(gemmlowp::to_int16x8_m128i(
4403 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_data_ptr))));
4404 F0 output = gemmlowp::tanh(input);
4405 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4406 output.raw().v);
4407 input_data_ptr += 8;
4408 output_data_ptr += 8;
4409 }
4410 } else {
4411 for (; c <= flat_size - 16; c += 16) {
4412 F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4413 gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4414 reinterpret_cast<const __m128i*>(input_data_ptr)))));
4415 F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4416 gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4417 reinterpret_cast<const __m128i*>(input_data_ptr + 8)))));
4418 F0 output0 = gemmlowp::tanh(input0);
4419 F0 output1 = gemmlowp::tanh(input1);
4420 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4421 output0.raw().v);
4422 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr + 8),
4423 output1.raw().v);
4424
4425 input_data_ptr += 16;
4426 output_data_ptr += 16;
4427 }
4428 for (; c <= flat_size - 8; c += 8) {
4429 F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
4430 gemmlowp::to_int16x8_m128i(_mm_loadu_si128(
4431 reinterpret_cast<const __m128i*>(input_data_ptr)))));
4432 F0 output = gemmlowp::tanh(input);
4433 _mm_storeu_si128(reinterpret_cast<__m128i*>(output_data_ptr),
4434 output.raw().v);
4435 input_data_ptr += 8;
4436 output_data_ptr += 8;
4437 }
4438 }
4439 }
4440 #endif
4441
4442 {
4443 // F0 uses 0 integer bits, range [-1, 1].
4444 // This is the return type of math functions such as tanh, logistic,
4445 // whose range is in [-1, 1].
4446 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
4447 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4448 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
4449
4450 if (input_left_shift == 0) {
4451 for (; c < flat_size; ++c) {
4452 F3 input = F3::FromRaw(*input_data_ptr);
4453 F0 output = gemmlowp::tanh(input);
4454 *output_data_ptr = output.raw();
4455
4456 ++input_data_ptr;
4457 ++output_data_ptr;
4458 }
4459 } else {
4460 for (; c < flat_size; ++c) {
4461 F3 input = F3::FromRaw(
4462 gemmlowp::SaturatingRoundingMultiplyByPOT<1>(*input_data_ptr));
4463 F0 output = gemmlowp::tanh(input);
4464 *output_data_ptr = output.raw();
4465
4466 ++input_data_ptr;
4467 ++output_data_ptr;
4468 }
4469 }
4470 }
4471 }
4472
4473 template <typename SrcT, typename DstT>
4474 inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
4475 const RuntimeShape& output_shape, DstT* output_data) {
4476 ruy::profiler::ScopeLabel label("Cast");
4477 auto input_map = MapAsVector(input_data, input_shape);
4478 auto output_map = MapAsVector(output_data, output_shape);
4479 output_map.array() = input_map.array().template cast<DstT>();
4480 }
4481
4482 inline void Floor(const RuntimeShape& input_shape, const float* input_data,
4483 const RuntimeShape& output_shape, float* output_data) {
4484 ruy::profiler::ScopeLabel label("Floor");
4485 auto input_map = MapAsVector(input_data, input_shape);
4486 auto output_map = MapAsVector(output_data, output_shape);
4487 output_map.array() = Eigen::floor(input_map.array());
4488 }
4489
4490 inline void Ceil(const RuntimeShape& input_shape, const float* input_data,
4491 const RuntimeShape& output_shape, float* output_data) {
4492 ruy::profiler::ScopeLabel label("Ceil");
4493 auto input_map = MapAsVector(input_data, input_shape);
4494 auto output_map = MapAsVector(output_data, output_shape);
4495 output_map.array() = Eigen::ceil(input_map.array());
4496 }
4497
4498 // Helper methods for BatchToSpaceND.
4499 // `spatial_index_dim` specifies post-crop offset index in this spatial
4500 // dimension, i.e. spatial offset introduced by flattening batch to spatial
4501 // dimension minus the crop size at beginning. `block_shape_dim` is the block
4502 // size in current dimension. `input_dim` and `output_dim` are input and output
4503 // size of BatchToSpaceND operation in current dimension.
4504 // Output start index is inclusive and end index is exclusive.
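// For example, with spatial_index_dim = -1 (crop of 1), block_shape_dim = 2,
// input_dim = 3 and output_dim = 5, the valid input range is [1, 3): input
// indices 1 and 2 map to output indices 1 and 3, both inside [0, 5).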
4505 inline void GetIndexRange(int spatial_index_dim, int block_shape_dim,
4506 int input_dim, int output_dim, int* start_index,
4507 int* end_index) {
4508 // (*start_index) * block_shape_dim is effectively rounded up to the next
4509 // multiple of block_shape_dim by the integer division.
4510 *start_index =
4511 std::max(0, (-spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
4512 // Similarly, (*end_index) * block_shape_dim is rounded up too (note that
4513 // end_index is exclusive).
4514 *end_index = std::min(
4515 input_dim,
4516 (output_dim - spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
4517 }
4518
4519 template <typename T>
4520 inline void BatchToSpaceND(
4521 const RuntimeShape& unextended_input1_shape, const T* input1_data,
4522 const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
4523 const RuntimeShape& unextended_input3_shape, const int32* crops_data,
4524 const RuntimeShape& unextended_output_shape, T* output_data) {
4525 ruy::profiler::ScopeLabel label("BatchToSpaceND");
4526
4527 TFLITE_DCHECK_GE(unextended_input1_shape.DimensionsCount(), 3);
4528 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
4529 TFLITE_DCHECK_EQ(unextended_input1_shape.DimensionsCount(),
4530 unextended_output_shape.DimensionsCount());
4531
4532 // Extends the input/output shape from 3D to 4D if needed, NHC -> NH1C.
4533 auto extend_shape = [](const RuntimeShape& shape) {
4534 if (shape.DimensionsCount() == 4) {
4535 return shape;
4536 }
4537 RuntimeShape new_shape(4, 1);
4538 new_shape.SetDim(0, shape.Dims(0));
4539 new_shape.SetDim(1, shape.Dims(1));
4540 new_shape.SetDim(3, shape.Dims(2));
4541 return new_shape;
4542 };
4543 const RuntimeShape input1_shape = extend_shape(unextended_input1_shape);
4544 const RuntimeShape output_shape = extend_shape(unextended_output_shape);
4545
4546 const int output_width = output_shape.Dims(2);
4547 const int output_height = output_shape.Dims(1);
4548 const int output_batch_size = output_shape.Dims(0);
4549
4550 const int depth = input1_shape.Dims(3);
4551 const int input_width = input1_shape.Dims(2);
4552 const int input_height = input1_shape.Dims(1);
4553 const int input_batch_size = input1_shape.Dims(0);
4554
4555 const int block_shape_height = block_shape_data[0];
4556 const int block_shape_width =
4557 unextended_input1_shape.DimensionsCount() == 4 ? block_shape_data[1] : 1;
4558 const int crops_top = crops_data[0];
4559 const int crops_left =
4560 unextended_input1_shape.DimensionsCount() == 4 ? crops_data[2] : 0;
4561
4562 for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
4563 const int out_batch = in_batch % output_batch_size;
4564 const int spatial_offset = in_batch / output_batch_size;
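// For example, with an input batch of 4, an output batch of 1 and a 2x2
// block, in_batch == 3 gives out_batch == 0 and spatial_offset == 3, i.e.
// block row 3 / 2 == 1 and block column 3 % 2 == 1.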
4565
4566 int in_h_start = 0;
4567 int in_h_end = 0;
4568 // GetIndexRange ensures start and end indices are in [0, output_height).
4569 GetIndexRange(spatial_offset / block_shape_width - crops_top,
4570 block_shape_height, input_height, output_height, &in_h_start,
4571 &in_h_end);
4572
4573 for (int in_h = in_h_start; in_h < in_h_end; ++in_h) {
4574 const int out_h = in_h * block_shape_height +
4575 spatial_offset / block_shape_width - crops_top;
4576 TFLITE_DCHECK_GE(out_h, 0);
4577 TFLITE_DCHECK_LT(out_h, output_height);
4578
4579 int in_w_start = 0;
4580 int in_w_end = 0;
4581 // GetIndexRange ensures start and end indices are in [0, output_width).
4582 GetIndexRange(spatial_offset % block_shape_width - crops_left,
4583 block_shape_width, input_width, output_width, &in_w_start,
4584 &in_w_end);
4585
4586 for (int in_w = in_w_start; in_w < in_w_end; ++in_w) {
4587 const int out_w = in_w * block_shape_width +
4588 spatial_offset % block_shape_width - crops_left;
4589 TFLITE_DCHECK_GE(out_w, 0);
4590 TFLITE_DCHECK_LT(out_w, output_width);
4591 T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
4592 const T* in =
4593 input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
4594 memcpy(out, in, depth * sizeof(T));
4595 }
4596 }
4597 }
4598 }
4599
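// Fills `num` elements of type T starting at `ptr` with `value`. For example,
// TypedMemset<float>(out, 1.5f, 8) copies the 4-byte pattern of 1.5f eight
// times element by element, while TypedMemset<uint8_t>(out, 7, 8) or a zero
// value falls back to a plain memset().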
4600 template <typename T>
4601 void TypedMemset(void* ptr, T value, size_t num) {
4602 // Optimization for common cases where memset() will suffice.
4603 if (value == 0 || std::is_same<T, uint8_t>::value) {
4604 memset(ptr, value, num * sizeof(T));
4605 } else {
4606 // Default implementation for cases where memset() will not preserve the
4607 // bytes, e.g., typically when sizeof(T) > sizeof(uint8_t).
4608 char* pos = static_cast<char*>(ptr);
4609 for (size_t i = 0; i < num; ++i) {
4610 memcpy(pos, &value, sizeof(T));
4611 pos = pos + sizeof(T);
4612 }
4613 }
4614 }
4615
4616 // This makes heavy use of Offset, along with conditional branches. There may be
4617 // opportunities for improvement.
4618 //
4619 // There are two versions of pad: Pad and PadV2. In PadV2 there is a second
4620 // scalar input that provides the padding value. Therefore pad_value_ptr can be
4621 // equivalent to a simple input1_data. For Pad, it should point to a zero
4622 // value.
4623 //
4624 // Note that two typenames are required, so that T=P=int32 is considered a
4625 // specialization distinct from P=int32.
4626 template <typename T, typename P>
4627 inline void PadImpl(const tflite::PadParams& op_params,
4628 const RuntimeShape& input_shape, const T* input_data,
4629 const P* pad_value_ptr, const RuntimeShape& output_shape,
4630 T* output_data) {
4631 ruy::profiler::ScopeLabel label("PadImpl");
4632 const int max_supported_dims = 5;
4633 const RuntimeShape ext_input_shape =
4634 RuntimeShape::ExtendedShape(max_supported_dims, input_shape);
4635 const RuntimeShape ext_output_shape =
4636 RuntimeShape::ExtendedShape(max_supported_dims, output_shape);
4637 TFLITE_DCHECK_LE(op_params.left_padding_count, max_supported_dims);
4638 TFLITE_DCHECK_LE(op_params.right_padding_count, max_supported_dims);
4639
4640 // Pad kernels are limited to max 5 dimensions. Copy inputs so we can pad them
4641 // to 5 dims (yes, we are "padding the padding").
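// For example, 4-D paddings {0, 1, 2, 0} are extended to the 5-D paddings
// {0, 0, 1, 2, 0}.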
4642 std::vector<int> left_padding_copy(max_supported_dims, 0);
4643 const int left_padding_extend =
4644 max_supported_dims - op_params.left_padding_count;
4645 for (int i = 0; i < op_params.left_padding_count; ++i) {
4646 left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
4647 }
4648 std::vector<int> right_padding_copy(max_supported_dims, 0);
4649 const int right_padding_extend =
4650 max_supported_dims - op_params.right_padding_count;
4651 for (int i = 0; i < op_params.right_padding_count; ++i) {
4652 right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
4653 }
4654
4655 const int output_batch = ext_output_shape.Dims(0);
4656 const int output_spatial_dim1 = ext_output_shape.Dims(1);
4657 const int output_spatial_dim2 = ext_output_shape.Dims(2);
4658 const int output_spatial_dim3 = ext_output_shape.Dims(3);
4659 const int output_channel = ext_output_shape.Dims(4);
4660
4661 const int left_b_padding = left_padding_copy[0];
4662 const int left_s1_padding = left_padding_copy[1];
4663 const int left_s2_padding = left_padding_copy[2];
4664 const int left_s3_padding = left_padding_copy[3];
4665 const int left_c_padding = left_padding_copy[4];
4666
4667 const int right_b_padding = right_padding_copy[0];
4668 const int right_s1_padding = right_padding_copy[1];
4669 const int right_s2_padding = right_padding_copy[2];
4670 const int right_s3_padding = right_padding_copy[3];
4671 const int right_c_padding = right_padding_copy[4];
4672
4673 const int input_depth = ext_input_shape.Dims(4);
4674 const T pad_value = *pad_value_ptr;
4675
4676 if (left_b_padding != 0) {
4677 TypedMemset<T>(output_data, pad_value,
4678 left_b_padding * output_spatial_dim1 * output_spatial_dim2 *
4679 output_spatial_dim3 * output_channel);
4680 }
4681 for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
4682 ++out_b) {
4683 if (left_s1_padding != 0) {
4684 TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, 0, 0, 0, 0),
4685 pad_value,
4686 left_s1_padding * output_spatial_dim2 *
4687 output_spatial_dim3 * output_channel);
4688 }
4689 for (int out_p = left_s1_padding;
4690 out_p < output_spatial_dim1 - right_s1_padding; ++out_p) {
4691 if (left_s2_padding != 0) {
4692 TypedMemset<T>(
4693 output_data + Offset(ext_output_shape, out_b, out_p, 0, 0, 0),
4694 pad_value, left_s2_padding * output_spatial_dim3 * output_channel);
4695 }
4696 for (int out_h = left_s2_padding;
4697 out_h < output_spatial_dim2 - right_s2_padding; ++out_h) {
4698 if (left_s3_padding != 0) {
4699 TypedMemset<T>(
4700 output_data + Offset(ext_output_shape, out_b, out_p, out_h, 0, 0),
4701 pad_value, left_s3_padding * output_channel);
4702 }
4703 for (int out_w = left_s3_padding;
4704 out_w < output_spatial_dim3 - right_s3_padding; ++out_w) {
4705 if (left_c_padding != 0) {
4706 TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, out_p,
4707 out_h, out_w, 0),
4708 pad_value, left_c_padding);
4709 }
4710
4711 T* out = output_data + Offset(ext_output_shape, out_b, out_p, out_h,
4712 out_w, left_c_padding);
4713 const T* in = input_data +
4714 Offset(ext_input_shape, out_b - left_b_padding,
4715 out_p - left_s1_padding, out_h - left_s2_padding,
4716 out_w - left_s3_padding, 0);
4717 memcpy(out, in, input_depth * sizeof(T));
4718
4719 if (right_c_padding != 0) {
4720 TypedMemset<T>(
4721 output_data + Offset(ext_output_shape, out_b, out_p, out_h,
4722 out_w, output_channel - right_c_padding),
4723 pad_value, right_c_padding);
4724 }
4725 }
4726 if (right_s3_padding != 0) {
4727 TypedMemset<T>(
4728 output_data + Offset(ext_output_shape, out_b, out_p, out_h,
4729 output_spatial_dim3 - right_s3_padding, 0),
4730 pad_value, right_s3_padding * output_channel);
4731 }
4732 }
4733 if (right_s2_padding != 0) {
4734 TypedMemset<T>(
4735 output_data + Offset(ext_output_shape, out_b, out_p,
4736 output_spatial_dim2 - right_s2_padding, 0, 0),
4737 pad_value, right_s2_padding * output_spatial_dim3 * output_channel);
4738 }
4739 }
4740 if (right_s1_padding != 0) {
4741 TypedMemset<T>(
4742 output_data + Offset(ext_output_shape, out_b,
4743 output_spatial_dim1 - right_s1_padding, 0, 0, 0),
4744 pad_value,
4745 right_s1_padding * output_spatial_dim2 * output_spatial_dim3 *
4746 output_channel);
4747 }
4748 }
4749 if (right_b_padding != 0) {
4750 TypedMemset<T>(
4751 output_data + Offset(ext_output_shape, output_batch - right_b_padding,
4752 0, 0, 0, 0),
4753 pad_value,
4754 right_b_padding * output_spatial_dim1 * output_spatial_dim2 *
4755 output_spatial_dim3 * output_channel);
4756 }
4757 }
4758
4759 template <typename T, typename P>
4760 inline void Pad(const tflite::PadParams& op_params,
4761 const RuntimeShape& input_shape, const T* input_data,
4762 const P* pad_value_ptr, const RuntimeShape& output_shape,
4763 T* output_data) {
4764 PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
4765 output_data);
4766 }
4767
4768 // The second (pad-value) input can be int32 when, say, the first is uint8.
4769 template <typename T>
4770 inline void Pad(const tflite::PadParams& op_params,
4771 const RuntimeShape& input_shape, const T* input_data,
4772 const int32* pad_value_ptr, const RuntimeShape& output_shape,
4773 T* output_data) {
4774 const T converted_pad_value = static_cast<T>(*pad_value_ptr);
4775 PadImpl(op_params, input_shape, input_data, &converted_pad_value,
4776 output_shape, output_data);
4777 }
4778
4779 // This version avoids conflicting template matching.
4780 template <>
4781 inline void Pad(const tflite::PadParams& op_params,
4782 const RuntimeShape& input_shape, const int32* input_data,
4783 const int32* pad_value_ptr, const RuntimeShape& output_shape,
4784 int32* output_data) {
4785 PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
4786 output_data);
4787 }
4788
4789 // TODO(b/117643175): Optimize. (This is an introductory copy of standard Pad.)
4790 //
4791 // This pad requires that (a) left and right paddings are in the 4D patterns
4792 // {0, h_pad, w_pad, 0}, and (b) memset can be used: *pad_value_ptr == 0 and/or
4793 // T is uint8.
4794 //
4795 // There are two versions of pad: Pad and PadV2. In PadV2 there is a second
4796 // scalar input that provides the padding value. Therefore pad_value_ptr can be
4797 // equivalent to a simple input1_data. For Pad, it should point to a zero
4798 // value.
4799 //
4800 // Note that two typenames are required, so that T=P=int32 is considered a
4801 // specialization distinct from P=int32.
4802 template <typename T, typename P>
4803 inline void PadImageStyleMemset(const tflite::PadParams& op_params,
4804 const RuntimeShape& input_shape,
4805 const T* input_data, const P* pad_value_ptr,
4806 const RuntimeShape& output_shape,
4807 T* output_data) {
4808 ruy::profiler::ScopeLabel label("PadImageStyle");
4809 const RuntimeShape ext_input_shape =
4810 RuntimeShape::ExtendedShape(4, input_shape);
4811 const RuntimeShape ext_output_shape =
4812 RuntimeShape::ExtendedShape(4, output_shape);
4813 TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
4814 TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
4815
4816 // Pad kernels are limited to max 4 dimensions. Copy inputs so we can pad them
4817 // to 4 dims (yes, we are "padding the padding").
4818 std::vector<int> left_padding_copy(4, 0);
4819 const int left_padding_extend = 4 - op_params.left_padding_count;
4820 for (int i = 0; i < op_params.left_padding_count; ++i) {
4821 left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
4822 }
4823 std::vector<int> right_padding_copy(4, 0);
4824 const int right_padding_extend = 4 - op_params.right_padding_count;
4825 for (int i = 0; i < op_params.right_padding_count; ++i) {
4826 right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
4827 }
4828 // The following padding restrictions are contractual requirements, and
4829 // embody what it means for a padding op to be "image-style".
4830 TFLITE_DCHECK_EQ(left_padding_copy[0], 0);
4831 TFLITE_DCHECK_EQ(left_padding_copy[3], 0);
4832 TFLITE_DCHECK_EQ(right_padding_copy[0], 0);
4833 TFLITE_DCHECK_EQ(right_padding_copy[3], 0);
4834
4835 const int batch = MatchingDim(ext_input_shape, 0, ext_output_shape, 0);
4836 const int output_height = ext_output_shape.Dims(1);
4837 const int output_width = ext_output_shape.Dims(2);
4838 const int input_height = ext_input_shape.Dims(1);
4839 const int input_width = ext_input_shape.Dims(2);
4840 const int depth = MatchingDim(ext_input_shape, 3, ext_output_shape, 3);
4841
4842 const int left_h_padding = left_padding_copy[1];
4843 const int left_w_padding = left_padding_copy[2];
4844 const int right_h_padding = right_padding_copy[1];
4845 const int right_w_padding = right_padding_copy[2];
4846
4847 TFLITE_DCHECK_EQ(output_height,
4848 input_height + left_h_padding + right_h_padding);
4849 TFLITE_DCHECK_EQ(output_width,
4850 input_width + left_w_padding + right_w_padding);
4851
4852 const T pad_value = *pad_value_ptr;
4853 const int top_block_size = left_h_padding * output_width * depth;
4854 const size_t num_top_block_bytes = top_block_size * sizeof(T);
4855 const int bottom_block_size = right_h_padding * output_width * depth;
4856 const size_t num_bottom_block_bytes = bottom_block_size * sizeof(T);
4857 const int left_blocks_size = left_w_padding * depth;
4858 const size_t num_left_block_bytes = left_blocks_size * sizeof(T);
4859 const int right_blocks_size = right_w_padding * depth;
4860 const size_t num_right_block_bytes = right_blocks_size * sizeof(T);
4861 const int inner_line_size = input_width * depth;
4862 const size_t num_inner_line_bytes = inner_line_size * sizeof(T);
4863
4864 if (input_height == 0) {
4865 memset(output_data, pad_value,
4866 num_top_block_bytes + num_bottom_block_bytes);
4867 } else {
4868 for (int i = 0; i < batch; ++i) {
4869 // For each image in the batch, apply the top padding, then iterate
4870 // through rows, then apply the bottom padding.
4871 //
4872 // By unwinding one iteration, we can combine the first left-margin
4873 // padding with the top padding, and the last right-margin padding with
4874 // the bottom padding.
4875 memset(output_data, pad_value,
4876 num_top_block_bytes + num_left_block_bytes);
4877 output_data += top_block_size + left_blocks_size;
4878 memcpy(output_data, input_data, num_inner_line_bytes);
4879 input_data += inner_line_size;
4880 output_data += inner_line_size;
4881 // One iteration unwound.
4882 // Unwinding this loop affords the opportunity to reorder the loop work
4883 // and hence combine memset() calls.
4884 //
4885 // Before unwinding:
4886 // for (int j = 0; j < input_height; ++j) {
4887 // // Pad on left, copy central data, pad on right.
4888 // memset(output_data, pad_value, num_left_block_bytes);
4889 // output_data += left_blocks_size;
4890 // memcpy(output_data, input_data, num_inner_line_bytes);
4891 // input_data += inner_line_size;
4892 // output_data += inner_line_size;
4893 // memset(output_data, pad_value, num_right_block_bytes);
4894 // output_data += right_blocks_size;
4895 // }
4896 for (int j = 1; j < input_height; ++j) {
4897 memset(output_data, pad_value,
4898 num_right_block_bytes + num_left_block_bytes);
4899 output_data += right_blocks_size + left_blocks_size;
4900 memcpy(output_data, input_data, num_inner_line_bytes);
4901 input_data += inner_line_size;
4902 output_data += inner_line_size;
4903 }
4904 memset(output_data, pad_value,
4905 num_right_block_bytes + num_bottom_block_bytes);
4906 output_data += right_blocks_size + bottom_block_size;
4907 }
4908 }
4909 }
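
// Illustrative usage sketch (added commentary, not from the original header).
// Shapes and values are hypothetical; they only show how the "image-style"
// restrictions above map onto PadParams:
//
//   // Pad a 1x2x3x1 uint8 tensor by one pixel on each spatial side.
//   tflite::PadParams op_params;
//   op_params.left_padding_count = 4;
//   op_params.right_padding_count = 4;
//   const int32 pads[4] = {0, 1, 1, 0};  // {batch, height, width, channel}
//   for (int i = 0; i < 4; ++i) {
//     op_params.left_padding[i] = pads[i];
//     op_params.right_padding[i] = pads[i];
//   }
//   const uint8 pad_value = 0;
//   PadImageStyleMemset(op_params, RuntimeShape({1, 2, 3, 1}), input,
//                       &pad_value, RuntimeShape({1, 4, 5, 1}), output);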
4910
4911 template <typename T, typename P>
4912 inline void PadImageStyle(const tflite::PadParams& op_params,
4913 const RuntimeShape& input_shape, const T* input_data,
4914 const P* pad_value_ptr,
4915 const RuntimeShape& output_shape, T* output_data) {
4916 reference_ops::PadImageStyle(op_params, input_shape, input_data,
4917 pad_value_ptr, output_shape, output_data);
4918 }
4919
4920 template <typename P>
4921 inline void PadImageStyle(const tflite::PadParams& op_params,
4922 const RuntimeShape& input_shape,
4923 const uint8* input_data, const P* pad_value_ptr,
4924 const RuntimeShape& output_shape,
4925 uint8* output_data) {
4926 PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
4927 output_shape, output_data);
4928 }
4929
4930 template <typename P>
4931 inline void PadImageStyle(const tflite::PadParams& op_params,
4932 const RuntimeShape& input_shape,
4933 const float* input_data, const P* pad_value_ptr,
4934 const RuntimeShape& output_shape,
4935 float* output_data) {
4936 const float converted_pad_value = static_cast<float>(*pad_value_ptr);
4937 if (converted_pad_value == 0.0f) {
4938 PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
4939 output_shape, output_data);
4940 } else {
4941 PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
4942 output_data);
4943 }
4944 }
4945
4946 template <typename T>
4947 inline void Slice(const tflite::SliceParams& op_params,
4948 const RuntimeShape& input_shape,
4949 const RuntimeShape& output_shape,
4950 SequentialTensorWriter<T>* writer) {
4951 ruy::profiler::ScopeLabel label("Slice");
4952 const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
4953 TFLITE_DCHECK_LE(op_params.begin_count, 5);
4954 TFLITE_DCHECK_LE(op_params.size_count, 5);
4955 const int begin_count = op_params.begin_count;
4956 const int size_count = op_params.size_count;
4957 // We front-pad the begin and size vectors.
4958 std::array<int, 5> start;
4959 std::array<int, 5> stop;
4960 for (int i = 0; i < 5; ++i) {
4961 int padded_i = 5 - i;
4962 start[i] =
4963 begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
4964 stop[i] =
4965 (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
4966 ? ext_shape.Dims(i)
4967 : start[i] + op_params.size[size_count - padded_i];
4968 }
4969
4970 for (int i0 = start[0]; i0 < stop[0]; ++i0) {
4971 for (int i1 = start[1]; i1 < stop[1]; ++i1) {
4972 for (int i2 = start[2]; i2 < stop[2]; ++i2) {
4973 for (int i3 = start[3]; i3 < stop[3]; ++i3) {
4974 const int len = stop[4] - start[4];
4975 if (len > 0)
4976 writer->WriteN(Offset(ext_shape, i0, i1, i2, i3, start[4]), len);
4977 }
4978 }
4979 }
4980 }
4981 }
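
// Worked example of the front-padding above (added commentary; values are
// illustrative): for a 2D slice with input_shape = {4, 6}, begin = {1, 2} and
// size = {2, -1}, the extended 5D shape is {1, 1, 1, 4, 6} and the loop
// computes start = {0, 0, 0, 1, 2}, stop = {1, 1, 1, 3, 6}, i.e. rows 1..2
// and columns 2..5 are copied.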
4982
4983 template <typename T>
4984 inline void Slice(const tflite::SliceParams& op_params,
4985 const RuntimeShape& input_shape, const T* input_data,
4986 const RuntimeShape& output_shape, T* output_data) {
4987 SequentialTensorWriter<T> writer(input_data, output_data);
4988 return Slice(op_params, input_shape, output_shape, &writer);
4989 }
4990
4991 template <typename T>
4992 inline void Slice(const tflite::SliceParams& op_params,
4993 const RuntimeShape& input_shape, const TfLiteTensor* input,
4994 const RuntimeShape& output_shape, TfLiteTensor* output) {
4995 SequentialTensorWriter<T> writer(input, output);
4996 return Slice(op_params, input_shape, output_shape, &writer);
4997 }
4998
4999 // Note: This implementation is only optimized for the case where the inner
5000 // stride == 1.
5001 template <typename T>
5002 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
5003 const RuntimeShape& unextended_input_shape,
5004 const RuntimeShape& unextended_output_shape,
5005 SequentialTensorWriter<T>* writer) {
5006 using strided_slice::LoopCondition;
5007 using strided_slice::StartForAxis;
5008 using strided_slice::StopForAxis;
5009
5010 ruy::profiler::ScopeLabel label("StridedSlice");
5011
5012 // Note that the output_shape is not used herein.
5013 tflite::StridedSliceParams params_copy = op_params;
5014
5015 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
5016 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
5017 const RuntimeShape input_shape =
5018 RuntimeShape::ExtendedShape(5, unextended_input_shape);
5019 const RuntimeShape output_shape =
5020 RuntimeShape::ExtendedShape(5, unextended_output_shape);
5021
5022 // Reverse and pad to 5 dimensions because that is what the runtime code
5023 // requires (i.e. all shapes must be 5D and are given backwards).
5024 strided_slice::StridedSlicePadIndices(&params_copy, 5);
5025
5026 const int start_0 = StartForAxis(params_copy, input_shape, 0);
5027 const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
5028 const int start_1 = StartForAxis(params_copy, input_shape, 1);
5029 const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
5030 const int start_2 = StartForAxis(params_copy, input_shape, 2);
5031 const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
5032 const int start_3 = StartForAxis(params_copy, input_shape, 3);
5033 const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
5034 const int start_4 = StartForAxis(params_copy, input_shape, 4);
5035 const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
5036 const bool inner_stride_is_1 = params_copy.strides[4] == 1;
5037
5038 for (int offset_0 = start_0 * input_shape.Dims(1),
5039 end_0 = stop_0 * input_shape.Dims(1),
5040 step_0 = params_copy.strides[0] * input_shape.Dims(1);
5041 !LoopCondition(offset_0, end_0, params_copy.strides[0]);
5042 offset_0 += step_0) {
5043 for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
5044 end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
5045 step_1 = params_copy.strides[1] * input_shape.Dims(2);
5046 !LoopCondition(offset_1, end_1, params_copy.strides[1]);
5047 offset_1 += step_1) {
5048 for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
5049 end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
5050 step_2 = params_copy.strides[2] * input_shape.Dims(3);
5051 !LoopCondition(offset_2, end_2, params_copy.strides[2]);
5052 offset_2 += step_2) {
5053 for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
5054 end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
5055 step_3 = params_copy.strides[3] * input_shape.Dims(4);
5056 !LoopCondition(offset_3, end_3, params_copy.strides[3]);
5057 offset_3 += step_3) {
5058 // When the stride is 1, the inner loop is equivalent to the
5059 // optimized slice inner loop. Otherwise, it is identical to the
5060 // strided_slice reference implementation inner loop.
5061 if (inner_stride_is_1) {
5062 const int len = stop_4 - start_4;
5063 if (len > 0) {
5064 writer->WriteN(offset_3 + start_4, len);
5065 }
5066 } else {
5067 for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
5068 !LoopCondition(offset_4, end_4, params_copy.strides[4]);
5069 offset_4 += params_copy.strides[4]) {
5070 writer->Write(offset_4);
5071 }
5072 }
5073 }
5074 }
5075 }
5076 }
5077 }
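
// Note on the offset arithmetic above (added commentary): for a 5D input of
// shape {D0, D1, D2, D3, D4}, the innermost index visited is
//   offset_3 + i4 = (((i0 * D1 + i1) * D2 + i2) * D3 + i3) * D4 + i4,
// i.e. the ordinary row-major linear offset, which is why writer->WriteN()
// can copy contiguous runs whenever the inner stride is 1.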
5078
5079 template <typename T>
5080 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
5081 const RuntimeShape& unextended_input_shape,
5082 const T* input_data,
5083 const RuntimeShape& unextended_output_shape,
5084 T* output_data) {
5085 SequentialTensorWriter<T> writer(input_data, output_data);
5086 StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
5087 &writer);
5088 }
5089
5090 template <typename T>
5091 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
5092 const RuntimeShape& unextended_input_shape,
5093 const TfLiteTensor* input,
5094 const RuntimeShape& unextended_output_shape,
5095 TfLiteTensor* output) {
5096 SequentialTensorWriter<T> writer(input, output);
5097 StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
5098 &writer);
5099 }
5100
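// Note (added commentary): Minimum and Maximum below broadcast a scalar
// second operand; only input2_data[0] is read.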
5101 template <typename T>
5102 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
5103 const T* input2_data, const RuntimeShape& output_shape,
5104 T* output_data) {
5105 ruy::profiler::ScopeLabel label("TensorFlowMinimum");
5106 auto input1_map = MapAsVector(input1_data, input1_shape);
5107 auto output_map = MapAsVector(output_data, output_shape);
5108 auto min_value = input2_data[0];
5109 output_map.array() = input1_map.array().min(min_value);
5110 }
5111
5112 // Convenience version that allows, for example, generated-code calls to be
5113 // the same as other binary ops.
5114 template <typename T>
5115 inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
5116 const RuntimeShape&, const T* input2_data,
5117 const RuntimeShape& output_shape, T* output_data) {
5118 // Drop shape of second input: not needed.
5119 Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
5120 }
5121
5122 template <typename T>
5123 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
5124 const T* input2_data, const RuntimeShape& output_shape,
5125 T* output_data) {
5126 ruy::profiler::ScopeLabel label("TensorFlowMaximum");
5127 auto input1_map = MapAsVector(input1_data, input1_shape);
5128 auto output_map = MapAsVector(output_data, output_shape);
5129 auto max_value = input2_data[0];
5130 output_map.array() = input1_map.array().max(max_value);
5131 }
5132
5133 // Convenience version that allows, for example, generated-code calls to be
5134 // the same as other binary ops.
5135 template <typename T>
5136 inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
5137 const RuntimeShape&, const T* input2_data,
5138 const RuntimeShape& output_shape, T* output_data) {
5139 // Drop shape of second input: not needed.
5140 Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
5141 }
5142
5143 template <typename T>
5144 void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
5145 const RuntimeShape& input_shape, const T* input_data,
5146 const RuntimeShape& filter_shape,
5147 const RuntimeShape& output_shape, T* im2col_data) {
5148 ruy::profiler::ScopeLabel label("TransposeIm2col");
5149 const int stride_width = params.stride_width;
5150 const int stride_height = params.stride_height;
5151 const int pad_width = params.padding_values.width;
5152 const int pad_height = params.padding_values.height;
5153 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5154 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
5155 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
5156 TFLITE_DCHECK(im2col_data);
5157
5158 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
5159 const int input_height = input_shape.Dims(1);
5160 const int input_width = input_shape.Dims(2);
5161 const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
5162 const int filter_height = filter_shape.Dims(1);
5163 const int filter_width = filter_shape.Dims(2);
5164 const int output_height = output_shape.Dims(1);
5165 const int output_width = output_shape.Dims(2);
5166 MatchingDim(output_shape, 3, filter_shape, 0); // output_depth
5167
5168 // Construct the MxN sized im2col matrix.
5169 // The rows, M, are sub-ordered B x H x W.
5170 const RuntimeShape row_shape({1, batches, output_height, output_width});
5171 // The columns, N, are sub-ordered Kh x Kw x Din
5172 const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
5173 // Use dimensions M and N to construct dims for indexing directly into im2col
5174 const RuntimeShape im2col_shape(
5175 {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
5176
5177 // Build the im2col matrix by looping through all the input pixels,
5178 // computing their influence on the output, rather than looping through all
5179 // the output pixels. We therefore must initialize the im2col array to zero.
5180 // This is potentially inefficient because we subsequently overwrite bytes
5181 // set here. However, in practice memset is very fast and its cost is negligible.
5182 memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
5183
5184 // Loop through the output batches
5185 for (int batch = 0; batch < batches; ++batch) {
5186 // Loop through input pixels one at a time.
5187 for (int in_y = 0; in_y < input_height; ++in_y) {
5188 for (int in_x = 0; in_x < input_width; ++in_x) {
5189 // Loop through the output pixels it will influence
5190 const int out_x_origin = (in_x * stride_width) - pad_width;
5191 const int out_y_origin = (in_y * stride_height) - pad_height;
5192 for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
5193 const int out_y = out_y_origin + filter_y;
5194 // Is output pixel within height bounds?
5195 if ((out_y >= 0) && (out_y < output_height)) {
5196 for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
5197 const int out_x = out_x_origin + filter_x;
5198 // Is output pixel within width bounds?
5199 if ((out_x >= 0) && (out_x < output_width)) {
5200 // Copy the input elements of this pixel
5201 T const* src =
5202 input_data + Offset(input_shape, batch, in_y, in_x, 0);
5203 int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
5204 int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
5205 T* dst = im2col_data +
5206 Offset(im2col_shape, 0, 0, row_offset, col_offset);
5207 memcpy(dst, src, input_depth * sizeof(T));
5208 }
5209 }
5210 }
5211 }
5212 }
5213 }
5214 }
5215 }
5216
5217 // Returns in 'im_data' (assumed to be zero-initialized) the image patch in storage
5218 // order (height, width, depth), constructed from patches in 'col_data', which
5219 // is required to be in storage order (out_height * out_width, filter_height,
5220 // filter_width, in_depth). Implementation by Yangqing Jia (jiayq).
5221 // Copied from //tensorflow/core/kernels/conv_grad_input_ops.cc
5222 template <typename T>
5223 void Col2im(const T* col_data, const int depth, const int height,
5224 const int width, const int filter_h, const int filter_w,
5225 const int pad_t, const int pad_l, const int pad_b, const int pad_r,
5226 const int stride_h, const int stride_w, T* im_data) {
5227 ruy::profiler::ScopeLabel label("Col2im");
5228 int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
5229 int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
5230 int h_pad = -pad_t;
5231 for (int h = 0; h < height_col; ++h) {
5232 int w_pad = -pad_l;
5233 for (int w = 0; w < width_col; ++w) {
5234 T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
5235 for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
5236 for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
5237 if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
5238 // TODO(andydavis) Vectorize this loop (if compiler does not).
5239 for (int i = 0; i < depth; ++i) {
5240 im_patch_data[i] += col_data[i];
5241 }
5242 }
5243 im_patch_data += depth;
5244 col_data += depth;
5245 }
5246 // Advance to the next image row: skip the (width - filter_w) columns outside the patch.
5247 im_patch_data += depth * (width - filter_w);
5248 }
5249 w_pad += stride_w;
5250 }
5251 h_pad += stride_h;
5252 }
5253 }
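
// Example of the patch-count arithmetic in Col2im above (added commentary;
// numbers are illustrative): with height = 5, pad_t = pad_b = 1, filter_h = 3
// and stride_h = 2,
//   height_col = (5 + 1 + 1 - 3) / 2 + 1 = 3,
// i.e. three vertically overlapping patches are accumulated back into the
// image.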
5254
5255 // TODO(b/188008864) Optimize this function by combining outer loops.
5256 template <typename T>
5257 void BiasAdd(T* im_data, const T* bias_data, const int batch_size,
5258 const int height, const int width, const int depth) {
5259 if (bias_data) {
5260 for (int n = 0; n < batch_size; ++n) {
5261 for (int h = 0; h < height; ++h) {
5262 for (int w = 0; w < width; ++w) {
5263 for (int d = 0; d < depth; ++d) {
5264 im_data[d] += bias_data[d];
5265 }
5266 im_data += depth;
5267 }
5268 }
5269 }
5270 }
5271 }
5272
5273 // TransposeConvV2 expects the weights in HWOI order.
5274 inline void TransposeConvV2(
5275 const ConvParams& params, const RuntimeShape& input_shape,
5276 const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
5277 const float* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
5278 const float* bias_data, const RuntimeShape& output_shape,
5279 float* const output_data, const RuntimeShape& col2im_shape,
5280 float* col2im_data, CpuBackendContext* cpu_backend_context) {
5281 ruy::profiler::ScopeLabel label("TransposeConvV2/float");
5282 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5283 TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
5284 TFLITE_DCHECK(col2im_data);
5285 TFLITE_DCHECK(hwoi_ordered_filter_data);
5286
5287 const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
5288 const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
5289 const int output_height = output_shape.Dims(1);
5290 const int output_width = output_shape.Dims(2);
5291 const int output_image_size = output_height * output_width;
5292 const int input_depth =
5293 MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
5294 const int output_depth =
5295 MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
5296 const int input_offset = input_image_size * input_depth;
5297 const int output_offset = output_image_size * output_depth;
5298
5299 const int filter_height = hwoi_ordered_filter_shape.Dims(0);
5300 const int filter_width = hwoi_ordered_filter_shape.Dims(1);
5301 const int padding_top = params.padding_values.height;
5302 const int padding_bottom =
5303 params.padding_values.height + params.padding_values.height_offset;
5304 const int padding_left = params.padding_values.width;
5305 const int padding_right =
5306 params.padding_values.width + params.padding_values.width_offset;
5307 const int stride_height = params.stride_height;
5308 const int stride_width = params.stride_width;
5309
5310 const int hwoi_ordered_filter_total_size =
5311 filter_height * filter_width * output_depth;
5312
5313 cpu_backend_gemm::MatrixParams<float> lhs_params;
5314 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
5315 lhs_params.rows = hwoi_ordered_filter_total_size;
5316 lhs_params.cols = input_depth;
5317 float* output_data_p = output_data;
5318 std::fill_n(output_data, output_offset * batch_size, 0.0f);
5319 for (int i = 0; i < batch_size; ++i) {
5320 cpu_backend_gemm::MatrixParams<float> rhs_params;
5321 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
5322 rhs_params.rows = input_depth;
5323 rhs_params.cols = input_image_size;
5324 cpu_backend_gemm::MatrixParams<float> dst_params;
5325 dst_params.order = cpu_backend_gemm::Order::kColMajor;
5326 dst_params.rows = hwoi_ordered_filter_total_size;
5327 dst_params.cols = input_image_size;
5328 cpu_backend_gemm::GemmParams<float, float> gemm_params;
5329 cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
5330 input_data + input_offset * i, dst_params,
5331 col2im_data, gemm_params, cpu_backend_context);
5332
5333 Col2im(col2im_data, output_depth, output_height, output_width,
5334 filter_height, filter_width, padding_top, padding_left,
5335 padding_bottom, padding_right, stride_height, stride_width,
5336 output_data_p);
5337 output_data_p += output_offset;
5338 }
5339 output_data_p = output_data;
5340 BiasAdd(output_data_p, bias_data, batch_size, output_height, output_width,
5341 output_depth);
5342 }
5343
5344 inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size,
5345 int32_t output_zp, int32_t* scratch, uint8_t* output) {
5346 ruy::profiler::ScopeLabel label("Quantize/uint8");
5347 int i = 0;
5348 const int32_t output_min = std::numeric_limits<uint8_t>::min();
5349 const int32_t output_max = std::numeric_limits<uint8_t>::max();
5350
5351 #ifdef USE_NEON
5352 const int32x4_t output_zp_dup = vdupq_n_s32(output_zp);
5353 const int32x4_t max_val_dup = vdupq_n_s32(output_max);
5354 const int32x4_t min_val_dup = vdupq_n_s32(output_min);
5355
5356 using gemmlowp::RoundingDivideByPOT;
5357 using gemmlowp::SaturatingRoundingDoublingHighMul;
5358
5359 for (; i <= total_size - 16; i += 16) {
5360 int32x4x4_t scratch_val;
5361 scratch_val.val[0] = vld1q_s32(scratch + i);
5362 scratch_val.val[1] = vld1q_s32(scratch + i + 4);
5363 scratch_val.val[2] = vld1q_s32(scratch + i + 8);
5364 scratch_val.val[3] = vld1q_s32(scratch + i + 12);
5365
5366 int32x4x4_t temp_val =
5367 MultiplyByQuantizedMultiplier4Rows(scratch_val, multiplier, shift);
5368
5369 temp_val.val[0] = vaddq_s32(temp_val.val[0], output_zp_dup);
5370 temp_val.val[1] = vaddq_s32(temp_val.val[1], output_zp_dup);
5371 temp_val.val[2] = vaddq_s32(temp_val.val[2], output_zp_dup);
5372 temp_val.val[3] = vaddq_s32(temp_val.val[3], output_zp_dup);
5373
5374 temp_val.val[0] =
5375 vmaxq_s32(vminq_s32(temp_val.val[0], max_val_dup), min_val_dup);
5376 temp_val.val[1] =
5377 vmaxq_s32(vminq_s32(temp_val.val[1], max_val_dup), min_val_dup);
5378 temp_val.val[2] =
5379 vmaxq_s32(vminq_s32(temp_val.val[2], max_val_dup), min_val_dup);
5380 temp_val.val[3] =
5381 vmaxq_s32(vminq_s32(temp_val.val[3], max_val_dup), min_val_dup);
5382
5383 const uint16x8_t result_1 =
5384 vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[0])),
5385 vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[1])));
5386 const uint16x8_t result_2 =
5387 vcombine_u16(vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[2])),
5388 vqmovn_u32(vreinterpretq_u32_s32(temp_val.val[3])));
5389 const uint8x16_t result =
5390 vcombine_u8(vqmovn_u16(result_1), vqmovn_u16(result_2));
5391 vst1q_u8(output + i, result);
5392 }
5393 #endif
5394 for (; i < total_size; ++i) {
5395 int32_t temp = MultiplyByQuantizedMultiplier(scratch[i], multiplier, shift);
5396 temp += output_zp;
5397 if (temp > output_max) {
5398 temp = output_max;
5399 }
5400 if (temp < output_min) {
5401 temp = output_min;
5402 }
5403 output[i] = static_cast<uint8_t>(temp);
5404 }
5405 }
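
// Numeric sketch of the scalar tail above (added commentary; values are
// arbitrary): with multiplier = 1073741824 (0.5 in Q31), shift = 0 and
// output_zp = 128, an accumulator of 100 maps to
//   MultiplyByQuantizedMultiplier(100, 1073741824, 0) = 50,
// and the stored uint8 value is 50 + 128 = 178 after clamping to [0, 255].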
5406
5407 inline void Quantize(const int32_t* multiplier, const int32_t* shift,
5408 int32_t channel_size, int32_t total_size,
5409 int32_t output_zp, int32_t output_min, int32_t output_max,
5410 int32_t* scratch, int8_t* output) {
5411 ruy::profiler::ScopeLabel label("Quantize/int8");
5412
5413 // Here we're trying to quantize the raw accumulators:
5414 // output_channels
5415 // data data data data data
5416 // rows data data data data data
5417 // data data data data data
5418 // ....
5419 //
5420 // In order to minimize the reload of the multipliers & shifts, once we load
5421 // the multipliers & shifts, we load & quantize the raw accumulators for every
5422 // row.
5423 #ifdef USE_NEON
5424 const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
5425 const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
5426 const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
5427 const int32x4_t zeros = vdupq_n_s32(0);
5428 #endif
5429
5430 TFLITE_DCHECK_EQ(total_size % channel_size, 0);
5431 const int32_t rows = total_size / channel_size;
5432
5433 int c = 0;
5434
5435 #ifdef USE_NEON
5436 using gemmlowp::RoundingDivideByPOT;
5437 for (; c <= channel_size - 8; c += 8) {
5438 int32x4_t out_shift_1 = vld1q_s32(shift + c);
5439 int32x4_t out_shift_2 = vld1q_s32(shift + c + 4);
5440 int32x4_t left_shift_1 = vmaxq_s32(out_shift_1, zeros);
5441 int32x4_t left_shift_2 = vmaxq_s32(out_shift_2, zeros);
5442
5443 // Right shift will be performed as left shift with negative values.
5444 int32x4_t right_shift_1 = vminq_s32(out_shift_1, zeros);
5445 int32x4_t right_shift_2 = vminq_s32(out_shift_2, zeros);
5446
5447 int32x4_t out_mul_1 = vld1q_s32(multiplier + c);
5448 int32x4_t out_mul_2 = vld1q_s32(multiplier + c + 4);
5449 for (int n = 0; n < rows; ++n) {
5450 int loc = n * channel_size + c;
5451 int32x4_t acc_1 = vld1q_s32(scratch + loc);
5452 int32x4_t acc_2 = vld1q_s32(scratch + loc + 4);
5453
5454 // Saturating Rounding Doubling High Mul.
5455 acc_1 = vshlq_s32(acc_1, left_shift_1);
5456 acc_1 = vqrdmulhq_s32(acc_1, out_mul_1);
5457 acc_2 = vshlq_s32(acc_2, left_shift_2);
5458 acc_2 = vqrdmulhq_s32(acc_2, out_mul_2);
5459
5460 // Rounding Dividing By POT.
5461 acc_1 = vrshlq_s32(acc_1, right_shift_1);
5462 acc_2 = vrshlq_s32(acc_2, right_shift_2);
5463
5464 // Add the output offset.
5465 acc_1 = vaddq_s32(acc_1, output_offset_vec);
5466 acc_2 = vaddq_s32(acc_2, output_offset_vec);
5467
5468 // Apply the activation function.
5469 acc_1 = vmaxq_s32(acc_1, output_activation_min_vec);
5470 acc_1 = vminq_s32(acc_1, output_activation_max_vec);
5471 acc_2 = vmaxq_s32(acc_2, output_activation_min_vec);
5472 acc_2 = vminq_s32(acc_2, output_activation_max_vec);
5473
5474 // Saturating cast to int8 and store to destination.
5475 const int16x4_t acc_s16_1 = vqmovn_s32(acc_1);
5476 const int16x4_t acc_s16_2 = vqmovn_s32(acc_2);
5477 const int16x8_t res_s16 = vcombine_s16(acc_s16_1, acc_s16_2);
5478 const int8x8_t res_s8 = vqmovn_s16(res_s16);
5479 vst1_s8(output + loc, res_s8);
5480 }
5481 }
5482
5483 #endif // USE_NEON
5484 // Handle leftover values, one by one. This is very slow.
5485 for (; c < channel_size; c++) {
5486 for (int n = 0; n < rows; ++n) {
5487 int loc = n * channel_size + c;
5488 int32 acc = scratch[loc];
5489 acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
5490 acc += output_zp;
5491 acc = std::max(acc, output_min);
5492 acc = std::min(acc, output_max);
5493 output[loc] = static_cast<int8>(acc);
5494 }
5495 }
5496 }
5497
5498 // TransposeConvV2 expects the weights in HWOI order.
5499 inline void TransposeConvV2(
5500 const ConvParams& params, const RuntimeShape& input_shape,
5501 const uint8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
5502 const uint8_t* hwoi_ordered_filter_data, const RuntimeShape& bias_shape,
5503 const int32* bias_data, const RuntimeShape& output_shape,
5504 uint8_t* output_data, const RuntimeShape& col2im_shape,
5505 int32_t* col2im_data, int32_t* scratch_data,
5506 CpuBackendContext* cpu_backend_context) {
5507 ruy::profiler::ScopeLabel label("TransposeConvV2/uint8");
5508 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5509 TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
5510 TFLITE_DCHECK(col2im_data);
5511 TFLITE_DCHECK(hwoi_ordered_filter_data);
5512
5513 const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
5514 const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
5515 const int output_height = output_shape.Dims(1);
5516 const int output_width = output_shape.Dims(2);
5517 const int output_image_size = output_height * output_width;
5518 const int input_depth =
5519 MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
5520 const int output_depth =
5521 MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
5522 const int input_offset = input_image_size * input_depth;
5523 const int output_offset = output_image_size * output_depth;
5524
5525 const int filter_height = hwoi_ordered_filter_shape.Dims(0);
5526 const int filter_width = hwoi_ordered_filter_shape.Dims(1);
5527 const int padding_top = params.padding_values.height;
5528 const int padding_bottom =
5529 params.padding_values.height + params.padding_values.height_offset;
5530 const int padding_left = params.padding_values.width;
5531 const int padding_right =
5532 params.padding_values.width + params.padding_values.width_offset;
5533 const int stride_height = params.stride_height;
5534 const int stride_width = params.stride_width;
5535
5536 const int hwoi_ordered_filter_total_size =
5537 filter_height * filter_width * output_depth;
5538
5539 cpu_backend_gemm::MatrixParams<uint8_t> lhs_params;
5540 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
5541 lhs_params.rows = hwoi_ordered_filter_total_size;
5542 lhs_params.cols = input_depth;
5543 lhs_params.zero_point = -params.weights_offset;
5544
5545 int32_t* scratch_data_p = scratch_data;
5546 std::fill_n(scratch_data, output_offset * batch_size, static_cast<int32>(0));
5547 for (int i = 0; i < batch_size; ++i) {
5548 cpu_backend_gemm::MatrixParams<uint8_t> rhs_params;
5549 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
5550 rhs_params.rows = input_depth;
5551 rhs_params.cols = input_image_size;
5552 rhs_params.zero_point = -params.input_offset;
5553
5554 cpu_backend_gemm::MatrixParams<int32_t> dst_params;
5555 dst_params.order = cpu_backend_gemm::Order::kColMajor;
5556 dst_params.rows = hwoi_ordered_filter_total_size;
5557 dst_params.cols = input_image_size;
5558
5559 cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
5560 cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
5561 input_data + input_offset * i, dst_params,
5562 col2im_data, gemm_params, cpu_backend_context);
5563
5564 Col2im(col2im_data, output_depth, output_height, output_width,
5565 filter_height, filter_width, padding_top, padding_left,
5566 padding_bottom, padding_right, stride_height, stride_width,
5567 scratch_data_p);
5568
5569 scratch_data_p += output_offset;
5570 }
5571 scratch_data_p = scratch_data;
5572 BiasAdd(scratch_data_p, bias_data, batch_size, output_height, output_width,
5573 output_depth);
5574
5575 Quantize(params.output_multiplier, params.output_shift,
5576 output_shape.FlatSize(), params.output_offset, scratch_data,
5577 output_data);
5578 }
5579
5580 // Integer-only version of ResizeNearestNeighbor. Since scales are represented
5581 // in fixed-point and thus approximated, |in_x| or |in_y| may differ from the
5582 // reference version. Debug checks are in place to test if this occurs.
5583 // NOTE: If align_corners or half_pixel_centers is true, we use the reference
5584 // version.
5585 inline void ResizeNearestNeighbor(
5586 const tflite::ResizeNearestNeighborParams& op_params,
5587 const RuntimeShape& unextended_input_shape, const uint8* input_data,
5588 const RuntimeShape& output_size_shape, const int32* output_size_data,
5589 const RuntimeShape& unextended_output_shape, uint8* output_data) {
5590 if (op_params.align_corners || op_params.half_pixel_centers) {
5591 // TODO(b/149823713): Add support for align_corners & half_pixel_centers in
5592 // this kernel.
5593 reference_ops::ResizeNearestNeighbor(
5594 op_params, unextended_input_shape, input_data, output_size_shape,
5595 output_size_data, unextended_output_shape, output_data);
5596 return;
5597 }
5598 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
5599 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
5600
5601 const RuntimeShape input_shape =
5602 RuntimeShape::ExtendedShape(4, unextended_input_shape);
5603 const RuntimeShape output_shape =
5604 RuntimeShape::ExtendedShape(4, unextended_output_shape);
5605
5606 int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
5607 int32 input_height = input_shape.Dims(1);
5608 int32 input_width = input_shape.Dims(2);
5609 int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
5610
5611 // The TensorFlow version of this op allows resizing along the width and
5612 // height axes only.
5613 TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
5614 int32 output_height = output_size_data[0];
5615 int32 output_width = output_size_data[1];
5616
5617 // Convert scales to fixed-point with 16 fractional bits. We add 1 as an
5618 // error factor and to avoid zero scales. For example, with input_height = 1,
5619 // output_height = 3, the float scaling factor would be non-zero at 1/3.
5620 // With fixed-point, this is zero.
5621 int32 height_scale = (input_height << 16) / output_height + 1;
5622 int32 width_scale = (input_width << 16) / output_width + 1;
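  // Numeric illustration (added commentary; values are arbitrary): with
  // input_height = 3 and output_height = 7, height_scale = (3 << 16) / 7 + 1
  // = 28087, so for y = 5, in_y = (5 * 28087) >> 16 = 2, which matches
  // floor(5 * 3 / 7) = 2 from the float reference.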
5623
5624 const int col_offset = input_shape.Dims(3);
5625 const int row_offset = input_shape.Dims(2) * col_offset;
5626 const int batch_offset = input_shape.Dims(1) * row_offset;
5627
5628 const uint8* input_ptr = input_data;
5629 uint8* output_ptr = output_data;
5630 for (int b = 0; b < batches; ++b) {
5631 for (int y = 0; y < output_height; ++y) {
5632 int32 in_y = std::min((y * height_scale) >> 16, input_height - 1);
5633 // Check offset calculation is the same as the reference version. See
5634 // function comment for details. We check using a non-float version of:
5635 // TFLITE_DCHECK_EQ(in_y, std::floor(y * (static_cast<float>(input_height)
5636 // / output_height)));
5637 TFLITE_DCHECK_LT(y * input_height, output_height + in_y * output_height);
5638 TFLITE_DCHECK_GE(y * input_height, in_y * output_height);
5639 const uint8* y_input_ptr = input_ptr + in_y * row_offset;
5640 for (int x = 0; x < output_width; ++x) {
5641 int32 in_x = std::min((x * width_scale) >> 16, input_width - 1);
5642 // Check offset calculation is the same as the reference version. See
5643 // function comment for details. We check using a non-float version of:
5644 //                  TFLITE_DCHECK_EQ(in_x,
5645 //                                   std::floor(x * (static_cast<float>(input_width)
5646 //                                       / output_width)));
5647 TFLITE_DCHECK_LT(x * input_width, output_width + in_x * output_width);
5648 TFLITE_DCHECK_GE(x * input_width, in_x * output_width);
5649 const uint8* x_input_ptr = y_input_ptr + in_x * col_offset;
5650 memcpy(output_ptr, x_input_ptr, depth);
5651 output_ptr += depth;
5652 }
5653 }
5654 input_ptr += batch_offset;
5655 }
5656 }
5657
5658 template <typename input_type, typename output_type>
5659 inline void Requantize(const input_type* input_data, int32_t size,
5660 int32_t effective_scale_multiplier,
5661 int32_t effective_scale_shift, int32_t input_zeropoint,
5662 int32_t output_zeropoint, output_type* output_data) {
5663 reference_ops::Requantize(input_data, size, effective_scale_multiplier,
5664 effective_scale_shift, input_zeropoint,
5665 output_zeropoint, output_data);
5666 }
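
// Requantization, illustrated (added commentary; values are arbitrary): for
// int8 -> uint8 with identical input and output scales, the effective scale
// of 1.0 is encoded as multiplier = 1073741824 with shift = 1. An input of -5
// with input_zeropoint = 0 and output_zeropoint = 128 then becomes
//   MultiplyByQuantizedMultiplier(-5, 1073741824, 1) + 128 = 123,
// clamped to [0, 255].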
5667
5668 template <>
5669 inline void Requantize<int8_t, uint8_t>(const int8_t* input_data, int32_t size,
5670 int32_t effective_scale_multiplier,
5671 int32_t effective_scale_shift,
5672 int32_t input_zeropoint,
5673 int32_t output_zeropoint,
5674 uint8_t* output_data) {
5675 ruy::profiler::ScopeLabel label("Requantize/Int8ToUint8");
5676
5677 static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
5678 static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
5679
5680 int i = 0;
5681 #ifdef USE_NEON
5682 // Constants.
5683 const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5684 const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5685 const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5686 const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5687
5688 for (; i <= size - 16; i += 16) {
5689 const int8x16_t input_vec = vld1q_s8(input_data + i);
5690 const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
5691 const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
5692 int32x4x4_t input;
5693 input.val[0] = vmovl_s16(vget_low_s16(first_half));
5694 input.val[1] = vmovl_s16(vget_high_s16(first_half));
5695 input.val[2] = vmovl_s16(vget_low_s16(second_half));
5696 input.val[3] = vmovl_s16(vget_high_s16(second_half));
5697 input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5698 input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5699 input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5700 input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5701
5702 int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5703 input, effective_scale_multiplier, effective_scale_shift);
5704
5705 result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5706 result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5707 result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5708 result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5709 result.val[0] =
5710 vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5711 result.val[1] =
5712 vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5713 result.val[2] =
5714 vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5715 result.val[3] =
5716 vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5717
5718 const uint32x4_t result_val_1_unsigned =
5719 vreinterpretq_u32_s32(result.val[0]);
5720 const uint32x4_t result_val_2_unsigned =
5721 vreinterpretq_u32_s32(result.val[1]);
5722 const uint32x4_t result_val_3_unsigned =
5723 vreinterpretq_u32_s32(result.val[2]);
5724 const uint32x4_t result_val_4_unsigned =
5725 vreinterpretq_u32_s32(result.val[3]);
5726
5727 const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
5728 const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
5729 const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
5730 const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
5731 const uint16x8_t output_first_half =
5732 vcombine_u16(narrowed_val_1, narrowed_val_2);
5733 const uint16x8_t output_second_half =
5734 vcombine_u16(narrowed_val_3, narrowed_val_4);
5735 const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
5736 const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
5737 const uint8x16_t narrowed_result =
5738 vcombine_u8(narrowed_first_half, narrowed_second_half);
5739 vst1q_u8(output_data + i, narrowed_result);
5740 }
5741
5742 #endif
5743 for (; i < size; ++i) {
5744 const int32_t input = input_data[i] - input_zeropoint;
5745 const int32_t output =
5746 MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5747 effective_scale_shift) +
5748 output_zeropoint;
5749 const int32_t clamped_output =
5750 std::max(std::min(output, kMaxOutput), kMinOutput);
5751 output_data[i] = static_cast<uint8_t>(clamped_output);
5752 }
5753 }
5754
5755 template <>
5756 inline void Requantize<uint8_t, int8_t>(const uint8_t* input_data, int32_t size,
5757 int32_t effective_scale_multiplier,
5758 int32_t effective_scale_shift,
5759 int32_t input_zeropoint,
5760 int32_t output_zeropoint,
5761 int8_t* output_data) {
5762 ruy::profiler::ScopeLabel label("Requantize/Uint8ToInt8");
5763
5764 static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
5765 static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
5766
5767 int i = 0;
5768 #ifdef USE_NEON
5769 // Constants.
5770 const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5771 const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5772 const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5773 const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5774
5775 for (; i <= size - 16; i += 16) {
5776 const uint8x16_t input_vec = vld1q_u8(input_data + i);
5777 const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
5778 const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
5779 int32x4x4_t input;
5780 input.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
5781 input.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
5782 input.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
5783 input.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
5784 input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5785 input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5786 input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5787 input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5788
5789 int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5790 input, effective_scale_multiplier, effective_scale_shift);
5791
5792 result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5793 result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5794 result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5795 result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5796 result.val[0] =
5797 vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5798 result.val[1] =
5799 vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5800 result.val[2] =
5801 vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5802 result.val[3] =
5803 vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5804
5805 const int16x4_t narrowed_val_1 = vqmovn_s32(result.val[0]);
5806 const int16x4_t narrowed_val_2 = vqmovn_s32(result.val[1]);
5807 const int16x4_t narrowed_val_3 = vqmovn_s32(result.val[2]);
5808 const int16x4_t narrowed_val_4 = vqmovn_s32(result.val[3]);
5809 const int16x8_t output_first_half =
5810 vcombine_s16(narrowed_val_1, narrowed_val_2);
5811 const int16x8_t output_second_half =
5812 vcombine_s16(narrowed_val_3, narrowed_val_4);
5813 const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
5814 const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
5815 const int8x16_t narrowed_result =
5816 vcombine_s8(narrowed_first_half, narrowed_second_half);
5817 vst1q_s8(output_data + i, narrowed_result);
5818 }
5819
5820 #endif
5821 for (; i < size; ++i) {
5822 const int32_t input = input_data[i] - input_zeropoint;
5823 const int32_t output =
5824 MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5825 effective_scale_shift) +
5826 output_zeropoint;
5827 const int32_t clamped_output =
5828 std::max(std::min(output, kMaxOutput), kMinOutput);
5829 output_data[i] = static_cast<int8_t>(clamped_output);
5830 }
5831 }
5832
5833 template <>
5834 inline void Requantize<int8_t, int8_t>(const int8_t* input_data, int32_t size,
5835 int32_t effective_scale_multiplier,
5836 int32_t effective_scale_shift,
5837 int32_t input_zeropoint,
5838 int32_t output_zeropoint,
5839 int8_t* output_data) {
5840 ruy::profiler::ScopeLabel label("Requantize/Int8ToInt8");
5841
5842 static constexpr int32_t kMinOutput = std::numeric_limits<int8_t>::min();
5843 static constexpr int32_t kMaxOutput = std::numeric_limits<int8_t>::max();
5844
5845 int i = 0;
5846 #ifdef USE_NEON
5847 // Constants.
5848 const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5849 const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5850 const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5851 const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5852
5853 for (; i <= size - 16; i += 16) {
5854 const int8x16_t input_vec = vld1q_s8(input_data + i);
5855 const int16x8_t first_half = vmovl_s8(vget_low_s8(input_vec));
5856 const int16x8_t second_half = vmovl_s8(vget_high_s8(input_vec));
5857 int32x4x4_t input;
5858 input.val[0] = vmovl_s16(vget_low_s16(first_half));
5859 input.val[1] = vmovl_s16(vget_high_s16(first_half));
5860 input.val[2] = vmovl_s16(vget_low_s16(second_half));
5861 input.val[3] = vmovl_s16(vget_high_s16(second_half));
5862
5863 input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5864 input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5865 input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5866 input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5867
5868 int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5869 input, effective_scale_multiplier, effective_scale_shift);
5870
5871 result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5872 result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5873 result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5874 result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5875 result.val[0] =
5876 vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5877 result.val[1] =
5878 vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5879 result.val[2] =
5880 vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5881 result.val[3] =
5882 vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5883
5884 const int16x4_t narrowed_val_1 = vqmovn_s32(result.val[0]);
5885 const int16x4_t narrowed_val_2 = vqmovn_s32(result.val[1]);
5886 const int16x4_t narrowed_val_3 = vqmovn_s32(result.val[2]);
5887 const int16x4_t narrowed_val_4 = vqmovn_s32(result.val[3]);
5888 const int16x8_t output_first_half =
5889 vcombine_s16(narrowed_val_1, narrowed_val_2);
5890 const int16x8_t output_second_half =
5891 vcombine_s16(narrowed_val_3, narrowed_val_4);
5892 const int8x8_t narrowed_first_half = vqmovn_s16(output_first_half);
5893 const int8x8_t narrowed_second_half = vqmovn_s16(output_second_half);
5894 const int8x16_t narrowed_result =
5895 vcombine_s8(narrowed_first_half, narrowed_second_half);
5896 vst1q_s8(output_data + i, narrowed_result);
5897 }
5898
5899 #endif
5900 for (; i < size; ++i) {
5901 const int32_t input = input_data[i] - input_zeropoint;
5902 const int32_t output =
5903 MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5904 effective_scale_shift) +
5905 output_zeropoint;
5906 const int32_t clamped_output =
5907 std::max(std::min(output, kMaxOutput), kMinOutput);
5908 output_data[i] = static_cast<int8_t>(clamped_output);
5909 }
5910 }
5911
5912 template <>
5913 inline void Requantize<uint8_t, uint8_t>(
5914 const uint8_t* input_data, int32_t size, int32_t effective_scale_multiplier,
5915 int32_t effective_scale_shift, int32_t input_zeropoint,
5916 int32_t output_zeropoint, uint8_t* output_data) {
5917 ruy::profiler::ScopeLabel label("Requantize/Uint8ToUint8");
5918
5919 static constexpr int32_t kMinOutput = std::numeric_limits<uint8_t>::min();
5920 static constexpr int32_t kMaxOutput = std::numeric_limits<uint8_t>::max();
5921
5922 int i = 0;
5923 #ifdef USE_NEON
5924 // Constants.
5925 const int32x4_t input_zero_point_dup = vdupq_n_s32(-input_zeropoint);
5926 const int32x4_t output_zero_point_dup = vdupq_n_s32(output_zeropoint);
5927 const int32x4_t min_val_dup = vdupq_n_s32(kMinOutput);
5928 const int32x4_t max_val_dup = vdupq_n_s32(kMaxOutput);
5929
5930 for (; i <= size - 16; i += 16) {
5931 const uint8x16_t input_vec = vld1q_u8(input_data + i);
5932 const uint16x8_t first_half = vmovl_u8(vget_low_u8(input_vec));
5933 const uint16x8_t second_half = vmovl_u8(vget_high_u8(input_vec));
5934 int32x4x4_t input;
5935 input.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(first_half)));
5936 input.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(first_half)));
5937 input.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(second_half)));
5938 input.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(second_half)));
5939 input.val[0] = vaddq_s32(input.val[0], input_zero_point_dup);
5940 input.val[1] = vaddq_s32(input.val[1], input_zero_point_dup);
5941 input.val[2] = vaddq_s32(input.val[2], input_zero_point_dup);
5942 input.val[3] = vaddq_s32(input.val[3], input_zero_point_dup);
5943
5944 int32x4x4_t result = MultiplyByQuantizedMultiplier4Rows(
5945 input, effective_scale_multiplier, effective_scale_shift);
5946
5947 result.val[0] = vaddq_s32(result.val[0], output_zero_point_dup);
5948 result.val[1] = vaddq_s32(result.val[1], output_zero_point_dup);
5949 result.val[2] = vaddq_s32(result.val[2], output_zero_point_dup);
5950 result.val[3] = vaddq_s32(result.val[3], output_zero_point_dup);
5951 result.val[0] =
5952 vmaxq_s32(vminq_s32(result.val[0], max_val_dup), min_val_dup);
5953 result.val[1] =
5954 vmaxq_s32(vminq_s32(result.val[1], max_val_dup), min_val_dup);
5955 result.val[2] =
5956 vmaxq_s32(vminq_s32(result.val[2], max_val_dup), min_val_dup);
5957 result.val[3] =
5958 vmaxq_s32(vminq_s32(result.val[3], max_val_dup), min_val_dup);
5959
5960 const uint32x4_t result_val_1_unsigned =
5961 vreinterpretq_u32_s32(result.val[0]);
5962 const uint32x4_t result_val_2_unsigned =
5963 vreinterpretq_u32_s32(result.val[1]);
5964 const uint32x4_t result_val_3_unsigned =
5965 vreinterpretq_u32_s32(result.val[2]);
5966 const uint32x4_t result_val_4_unsigned =
5967 vreinterpretq_u32_s32(result.val[3]);
5968
5969 const uint16x4_t narrowed_val_1 = vqmovn_u32(result_val_1_unsigned);
5970 const uint16x4_t narrowed_val_2 = vqmovn_u32(result_val_2_unsigned);
5971 const uint16x4_t narrowed_val_3 = vqmovn_u32(result_val_3_unsigned);
5972 const uint16x4_t narrowed_val_4 = vqmovn_u32(result_val_4_unsigned);
5973 const uint16x8_t output_first_half =
5974 vcombine_u16(narrowed_val_1, narrowed_val_2);
5975 const uint16x8_t output_second_half =
5976 vcombine_u16(narrowed_val_3, narrowed_val_4);
5977 const uint8x8_t narrowed_first_half = vqmovn_u16(output_first_half);
5978 const uint8x8_t narrowed_second_half = vqmovn_u16(output_second_half);
5979 const uint8x16_t narrowed_result =
5980 vcombine_u8(narrowed_first_half, narrowed_second_half);
5981 vst1q_u8(output_data + i, narrowed_result);
5982 }
5983
5984 #endif
5985 for (; i < size; ++i) {
5986 const int32_t input = input_data[i] - input_zeropoint;
5987 const int32_t output =
5988 MultiplyByQuantizedMultiplier(input, effective_scale_multiplier,
5989 effective_scale_shift) +
5990 output_zeropoint;
5991 const int32_t clamped_output =
5992 std::max(std::min(output, kMaxOutput), kMinOutput);
5993 output_data[i] = static_cast<uint8_t>(clamped_output);
5994 }
5995 }
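
// Worked example of the scalar requantization formula above (illustrative
// values only, assuming the usual Q31 convention for
// effective_scale_multiplier used by MultiplyByQuantizedMultiplier): with
// input_zeropoint = 128, output_zeropoint = 128, effective_scale_multiplier =
// 1 << 30 and effective_scale_shift = 0 (i.e. a real scale of 0.5), an input
// of 200 becomes (200 - 128) * 0.5 + 128 = 164, which already lies inside
// [0, 255] and is stored unchanged.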
5996
5997 inline void HardSwish(const RuntimeShape& input_shape, const float* input_data,
5998 const RuntimeShape& output_shape, float* output_data) {
5999 ruy::profiler::ScopeLabel label("HardSwish/Float");
6000 auto size = MatchingFlatSize(input_shape, output_shape);
6001 int i = 0;
6002 #ifdef USE_NEON
6003 const float32x4_t zero = vdupq_n_f32(0.0f);
6004 const float32x4_t three = vdupq_n_f32(3.0f);
6005 const float32x4_t six = vdupq_n_f32(6.0f);
6006 const float32x4_t one_sixth = vdupq_n_f32(1.0f / 6.0f);
6007
6008 for (; i <= size - 16; i += 16) {
6009 // 4x partially unrolled version of the loop below. Refer to its comments.
6010 const float32x4_t in_0 = vld1q_f32(input_data + i + 0);
6011 const float32x4_t in_1 = vld1q_f32(input_data + i + 4);
6012 const float32x4_t in_2 = vld1q_f32(input_data + i + 8);
6013 const float32x4_t in_3 = vld1q_f32(input_data + i + 12);
6014 const float32x4_t in_scaled_0 = vmulq_f32(in_0, one_sixth);
6015 const float32x4_t in_scaled_1 = vmulq_f32(in_1, one_sixth);
6016 const float32x4_t in_scaled_2 = vmulq_f32(in_2, one_sixth);
6017 const float32x4_t in_scaled_3 = vmulq_f32(in_3, one_sixth);
6018 const float32x4_t in_reluish_0 =
6019 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_0, three)));
6020 const float32x4_t in_reluish_1 =
6021 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_1, three)));
6022 const float32x4_t in_reluish_2 =
6023 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_2, three)));
6024 const float32x4_t in_reluish_3 =
6025 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in_3, three)));
6026 const float32x4_t product_0 = vmulq_f32(in_scaled_0, in_reluish_0);
6027 const float32x4_t product_1 = vmulq_f32(in_scaled_1, in_reluish_1);
6028 const float32x4_t product_2 = vmulq_f32(in_scaled_2, in_reluish_2);
6029 const float32x4_t product_3 = vmulq_f32(in_scaled_3, in_reluish_3);
6030 vst1q_f32(output_data + i + 0, product_0);
6031 vst1q_f32(output_data + i + 4, product_1);
6032 vst1q_f32(output_data + i + 8, product_2);
6033 vst1q_f32(output_data + i + 12, product_3);
6034 }
6035 for (; i <= size - 4; i += 4) {
6036 // The expression to be computed is:
6037 // out = one_sixth * in * min(six, max(zero, (in + three)))
6038 // We structure the AST to have two roughly balanced, independent branches:
6039 // - Multiplication: in_scaled = one_sixth * in.
6040 // - Addition and clamping: in_reluish = min(six, max(zero, (in + three))).
6041 // Then the remaining multiplication at the root of the tree.
6042 const float32x4_t in = vld1q_f32(input_data + i);
6043 const float32x4_t in_scaled = vmulq_f32(in, one_sixth);
6044 const float32x4_t in_reluish =
6045 vminq_f32(six, vmaxq_f32(zero, vaddq_f32(in, three)));
6046 const float32x4_t product = vmulq_f32(in_scaled, in_reluish);
6047 vst1q_f32(output_data + i, product);
6048 }
6049 #endif
6050 for (; i < size; i++) {
6051 const float in = input_data[i];
6052 output_data[i] =
6053 in * std::min(6.0f, std::max(0.0f, in + 3.0f)) * (1.0f / 6.0f);
6054 }
6055 }
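
// For reference, the scalar tail above evaluates the standard hard-swish
// definition hard_swish(x) = x * relu6(x + 3) / 6. For example,
// hard_swish(1.0f) = 1.0f * min(6, max(0, 4)) / 6 = 2/3; inputs at or below
// -3 map to 0, and inputs at or above +3 pass through unchanged.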
6056
6057 #ifdef USE_NEON
6058 inline void SaturateAndStore(int16x8_t src, std::uint8_t* dst) {
6059 // Narrow values down to 8 bit unsigned, saturating.
6060 uint8x8_t res8 = vqmovun_s16(src);
6061 // Store results to destination.
6062 vst1_u8(dst, res8);
6063 }
6064
6065 inline void SaturateAndStore(int16x8_t src, std::int8_t* dst) {
6066 // Narrow values down to 8 bit signed, saturating.
6067 int8x8_t res8 = vqmovn_s16(src);
6068 // Store results to destination.
6069 vst1_s8(dst, res8);
6070 }
6071 #endif
6072
6073 template <typename T>
6074 inline void HardSwish(const HardSwishParams& params,
6075 const RuntimeShape& input_shape, const T* input_data,
6076 const RuntimeShape& output_shape, T* output_data) {
6077 ruy::profiler::ScopeLabel label("HardSwish/Quantized");
6078
6079 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6080
6081 int i = 0;
6082 // This code heavily uses NEON saturating left shifts (vqshl*) with shift
6083 // amounts that can be zero, in which case we rely on the correct behavior
6084 // of a left shift by zero returning just its first operand unmodified.
6085 // Unfortunately, the Intel arm_neon_sse.h implementation of vqshl* is
6086 // buggy in the case of zero shift amounts, see b/137199585. That is why
6087 // this NEON code path is restricted to true ARM NEON, excluding
6088 // arm_neon_sse.h. Anyway, the arm_neon_sse.h implementation of saturating
6089 // left shifts is slow scalar code, so there may not be much benefit in
6090 // running that over just plain reference code.
6091 //
6092 // TODO(b/137199585): revisit when this is fixed.
6093 #ifdef __ARM_NEON
6094 const int16x8_t positive_reluish_multiplier_exponent_minus_one =
6095 vdupq_n_s16(std::max(0, params.reluish_multiplier_exponent - 1));
6096 const int16x8_t positive_reluish_multiplier_exponent_last_bit =
6097 vdupq_n_s16(params.reluish_multiplier_exponent > 0 ? 1 : 0);
6098 const int16x8_t negative_reluish_multiplier_exponent =
6099 vdupq_n_s16(std::min(0, params.reluish_multiplier_exponent));
6100 const int16x8_t constant_32767 = vdupq_n_s16(32767);
6101 const int16x8_t output_multiplier_exponent =
6102 vdupq_n_s16(params.output_multiplier_exponent);
6103 const int16x8_t output_zero_point = vdupq_n_s16(params.output_zero_point);
6104 // 4x unrolled version of the below NEON loop. Read that first.
6105 for (; i <= flat_size - 32; i += 32) {
6106 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6107 const int16x8x2_t input_value_0_1 =
6108 Load16AndSubtractZeroPoint(input_data + i, params.input_zero_point);
6109 const int16x8x2_t input_value_2_3 = Load16AndSubtractZeroPoint(
6110 input_data + i + 16, params.input_zero_point);
6111 const int16x8_t input_value_on_hires_input_scale_0 =
6112 vshlq_n_s16(input_value_0_1.val[0], 7);
6113 const int16x8_t input_value_on_hires_input_scale_1 =
6114 vshlq_n_s16(input_value_0_1.val[1], 7);
6115 const int16x8_t input_value_on_hires_input_scale_2 =
6116 vshlq_n_s16(input_value_2_3.val[0], 7);
6117 const int16x8_t input_value_on_hires_input_scale_3 =
6118 vshlq_n_s16(input_value_2_3.val[1], 7);
6119 const int16x8_t input_value_on_preshift_output_scale_0 =
6120 vqrdmulhq_n_s16(input_value_on_hires_input_scale_0,
6121 params.output_multiplier_fixedpoint_int16);
6122 const int16x8_t input_value_on_preshift_output_scale_1 =
6123 vqrdmulhq_n_s16(input_value_on_hires_input_scale_1,
6124 params.output_multiplier_fixedpoint_int16);
6125 const int16x8_t input_value_on_preshift_output_scale_2 =
6126 vqrdmulhq_n_s16(input_value_on_hires_input_scale_2,
6127 params.output_multiplier_fixedpoint_int16);
6128 const int16x8_t input_value_on_preshift_output_scale_3 =
6129 vqrdmulhq_n_s16(input_value_on_hires_input_scale_3,
6130 params.output_multiplier_fixedpoint_int16);
6131 int16x8_t reluish_value_0 = input_value_on_hires_input_scale_0;
6132 int16x8_t reluish_value_1 = input_value_on_hires_input_scale_1;
6133 int16x8_t reluish_value_2 = input_value_on_hires_input_scale_2;
6134 int16x8_t reluish_value_3 = input_value_on_hires_input_scale_3;
6135 reluish_value_0 = vqshlq_s16(
6136 reluish_value_0, positive_reluish_multiplier_exponent_minus_one);
6137 reluish_value_1 = vqshlq_s16(
6138 reluish_value_1, positive_reluish_multiplier_exponent_minus_one);
6139 reluish_value_2 = vqshlq_s16(
6140 reluish_value_2, positive_reluish_multiplier_exponent_minus_one);
6141 reluish_value_3 = vqshlq_s16(
6142 reluish_value_3, positive_reluish_multiplier_exponent_minus_one);
6143 reluish_value_0 = vqrdmulhq_n_s16(
6144 reluish_value_0, params.reluish_multiplier_fixedpoint_int16);
6145 reluish_value_1 = vqrdmulhq_n_s16(
6146 reluish_value_1, params.reluish_multiplier_fixedpoint_int16);
6147 reluish_value_2 = vqrdmulhq_n_s16(
6148 reluish_value_2, params.reluish_multiplier_fixedpoint_int16);
6149 reluish_value_3 = vqrdmulhq_n_s16(
6150 reluish_value_3, params.reluish_multiplier_fixedpoint_int16);
6151 reluish_value_0 = vqshlq_s16(reluish_value_0,
6152 positive_reluish_multiplier_exponent_last_bit);
6153 reluish_value_1 = vqshlq_s16(reluish_value_1,
6154 positive_reluish_multiplier_exponent_last_bit);
6155 reluish_value_2 = vqshlq_s16(reluish_value_2,
6156 positive_reluish_multiplier_exponent_last_bit);
6157 reluish_value_3 = vqshlq_s16(reluish_value_3,
6158 positive_reluish_multiplier_exponent_last_bit);
6159 reluish_value_0 =
6160 vrshlq_s16(reluish_value_0, negative_reluish_multiplier_exponent);
6161 reluish_value_1 =
6162 vrshlq_s16(reluish_value_1, negative_reluish_multiplier_exponent);
6163 reluish_value_2 =
6164 vrshlq_s16(reluish_value_2, negative_reluish_multiplier_exponent);
6165 reluish_value_3 =
6166 vrshlq_s16(reluish_value_3, negative_reluish_multiplier_exponent);
6167 reluish_value_0 = vrhaddq_s16(reluish_value_0, constant_32767);
6168 reluish_value_1 = vrhaddq_s16(reluish_value_1, constant_32767);
6169 reluish_value_2 = vrhaddq_s16(reluish_value_2, constant_32767);
6170 reluish_value_3 = vrhaddq_s16(reluish_value_3, constant_32767);
6171 const int16x8_t preshift_output_value_0 =
6172 vqdmulhq_s16(reluish_value_0, input_value_on_preshift_output_scale_0);
6173 const int16x8_t preshift_output_value_1 =
6174 vqdmulhq_s16(reluish_value_1, input_value_on_preshift_output_scale_1);
6175 const int16x8_t preshift_output_value_2 =
6176 vqdmulhq_s16(reluish_value_2, input_value_on_preshift_output_scale_2);
6177 const int16x8_t preshift_output_value_3 =
6178 vqdmulhq_s16(reluish_value_3, input_value_on_preshift_output_scale_3);
6179 int16x8_t output_value_0 =
6180 vrshlq_s16(preshift_output_value_0, output_multiplier_exponent);
6181 int16x8_t output_value_1 =
6182 vrshlq_s16(preshift_output_value_1, output_multiplier_exponent);
6183 int16x8_t output_value_2 =
6184 vrshlq_s16(preshift_output_value_2, output_multiplier_exponent);
6185 int16x8_t output_value_3 =
6186 vrshlq_s16(preshift_output_value_3, output_multiplier_exponent);
6187 output_value_0 = vaddq_s16(output_value_0, output_zero_point);
6188 output_value_1 = vaddq_s16(output_value_1, output_zero_point);
6189 output_value_2 = vaddq_s16(output_value_2, output_zero_point);
6190 output_value_3 = vaddq_s16(output_value_3, output_zero_point);
6191 SaturateAndStore(output_value_0, output_data + i);
6192 SaturateAndStore(output_value_1, output_data + i + 8);
6193 SaturateAndStore(output_value_2, output_data + i + 16);
6194 SaturateAndStore(output_value_3, output_data + i + 24);
6195 }
6196 // NEON version of reference_ops::HardSwish. Read that first.
6197 for (; i <= flat_size - 8; i += 8) {
6198 using cpu_backend_gemm::detail::Load8AndSubtractZeroPoint;
6199 const int16x8_t input_value =
6200 Load8AndSubtractZeroPoint(input_data + i, params.input_zero_point);
6201 const int16x8_t input_value_on_hires_input_scale =
6202 vshlq_n_s16(input_value, 7);
6203 const int16x8_t input_value_on_preshift_output_scale =
6204 vqrdmulhq_n_s16(input_value_on_hires_input_scale,
6205 params.output_multiplier_fixedpoint_int16);
6206 int16x8_t reluish_value = input_value_on_hires_input_scale;
6207 reluish_value = vqshlq_s16(reluish_value,
6208 positive_reluish_multiplier_exponent_minus_one);
6209 reluish_value = vqrdmulhq_n_s16(reluish_value,
6210 params.reluish_multiplier_fixedpoint_int16);
6211 reluish_value = vqshlq_s16(reluish_value,
6212 positive_reluish_multiplier_exponent_last_bit);
6213 reluish_value =
6214 vrshlq_s16(reluish_value, negative_reluish_multiplier_exponent);
6215 reluish_value = vrhaddq_s16(reluish_value, constant_32767);
6216 const int16x8_t preshift_output_value =
6217 vqdmulhq_s16(reluish_value, input_value_on_preshift_output_scale);
6218 int16x8_t output_value =
6219 vrshlq_s16(preshift_output_value, output_multiplier_exponent);
6220 output_value = vaddq_s16(output_value, output_zero_point);
6221 SaturateAndStore(output_value, output_data + i);
6222 }
6223 #endif
6224 // TODO(b/137208495): revisit when unit tests cover reference code.
6225 // Fall back to reference_ops::HardSwish. In general we have preferred
6226 // to duplicate such scalar code rather than call reference code to handle
6227 // leftovers, thinking that code duplication was not a big concern.
6228 // However, most of our unit tests happen to test only optimized code,
6229 // and the quantized HardSwish implementation is nontrivial enough that
6230 // I really want test coverage for the reference code.
6231 if (i < flat_size) {
6232 const RuntimeShape leftover_shape{flat_size - i};
6233 reference_ops::HardSwish(params, leftover_shape, input_data + i,
6234 leftover_shape, output_data + i);
6235 }
6236 }
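
// Note on parameters (summarizing what the kernel above reads; the values are
// typically precomputed when the op is prepared): params.input_zero_point and
// params.output_zero_point recenter the int8/uint8 data, the pair
// reluish_multiplier_fixedpoint_int16 / reluish_multiplier_exponent rescales
// the relu6(x + 3) term, and output_multiplier_fixedpoint_int16 /
// output_multiplier_exponent performs the final rescale to the output scale.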
6237
6238 template <typename T>
6239 inline void IntegerExponentPow(const ArithmeticParams& params,
6240 const RuntimeShape& unextended_base_shape,
6241 const T* base_data, const int exponent,
6242 const RuntimeShape& unextended_output_shape,
6243 T* output_data) {
6244 TFLITE_DCHECK_GE(exponent, 1);
6245 if (exponent == 1) {
6246 // copy data over.
6247 std::memcpy(output_data, base_data,
6248 unextended_base_shape.FlatSize() * sizeof(T));
6249 } else {
6250 IntegerExponentPow(params, unextended_base_shape, base_data, exponent / 2,
6251 unextended_output_shape, output_data);
6252 Mul(params, unextended_base_shape, output_data, unextended_base_shape,
6253 output_data, unextended_output_shape, output_data);
6254 if (exponent % 2 == 1) {
6255 Mul(params, unextended_base_shape, base_data, unextended_base_shape,
6256 output_data, unextended_output_shape, output_data);
6257 }
6258 }
6259 }
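
// IntegerExponentPow is exponentiation by squaring, so it needs only
// O(log2(exponent)) multiplications. Illustrative trace for exponent == 5:
// the recursive call with exponent 2 first copies the base (exponent 1), then
// squares it into base^2; the outer call squares that into base^4 and,
// because 5 is odd, multiplies once more by the base to obtain base^5.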
6260
6261 template <typename T>
6262 inline void BroadcastPow4D(const RuntimeShape& unextended_input1_shape,
6263 const T* input1_data,
6264 const RuntimeShape& unextended_input2_shape,
6265 const T* input2_data,
6266 const RuntimeShape& unextended_output_shape,
6267 T* output_data) {
6268 ruy::profiler::ScopeLabel label("PowBroadcast");
6269
6270 if (unextended_input2_shape.FlatSize() == 1) {
6271 static const float epsilon = 1e-5;
6272 const T exponent = input2_data[0];
6273 const int int_exponent = static_cast<int>(std::round(exponent));
6274 if ((std::abs(input2_data[0] - int_exponent) < epsilon) &&
6275 (int_exponent >= 1)) {
6276 ArithmeticParams params;
6277 if (std::is_same<T, float>::value) {
6278 params.float_activation_max = std::numeric_limits<float>::max();
6279 params.float_activation_min = std::numeric_limits<float>::lowest();
6280 } else if (std::is_same<T, int>::value) {
6281 params.quantized_activation_max = std::numeric_limits<int>::max();
6282 params.quantized_activation_min = std::numeric_limits<int>::lowest();
6283 }
6284 IntegerExponentPow(params, unextended_input1_shape, input1_data,
6285 int_exponent, unextended_output_shape, output_data);
6286 return;
6287 }
6288 }
6289 reference_ops::BroadcastPow4DSlow(unextended_input1_shape, input1_data,
6290 unextended_input2_shape, input2_data,
6291 unextended_output_shape, output_data);
6292 }
6293
6294 #ifdef USE_NEON
6295
6296 inline void ScaleWithNewZeroPoint(const int32x4_t input,
6297 const float32x4_t scale_dup,
6298 const float32x4_t zero_times_scale_dup,
6299 float32x4_t* output) {
6300 #ifdef __ARM_FEATURE_FMA
6301 *output = vfmaq_f32(zero_times_scale_dup, vcvtq_f32_s32(input), scale_dup);
6302 #else
6303 *output = vaddq_f32(vmulq_f32(vcvtq_f32_s32(input), scale_dup),
6304 zero_times_scale_dup);
6305 #endif
6306 }
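
// ScaleWithNewZeroPoint relies on the algebraic rewrite
//   (input - zero_point) * scale == input * scale + (-zero_point * scale),
// which lets callers hoist the constant (-zero_point * scale) into a single
// vector register (zero_times_scale_dup) and express the per-lane work as one
// multiply-add.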
6307
6308 #endif // USE_NEON
6309
6310 inline void Dequantize(const tflite::DequantizationParams& op_params,
6311 const RuntimeShape& input_shape,
6312 const uint8_t* input_data,
6313 const RuntimeShape& output_shape, float* output_data) {
6314 ruy::profiler::ScopeLabel label("Dequantize/Uint8");
6315 const int32 zero_point = op_params.zero_point;
6316 const double scale = op_params.scale;
6317 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6318
6319 int i = 0;
6320 #ifdef USE_NEON
6321 const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6322 const float32x4_t zero_times_scale_dup =
6323 vdupq_n_f32(static_cast<float>(-zero_point * scale));
6324 for (; i <= flat_size - 8; i += 8) {
6325 const uint8x8_t input_u8 = vld1_u8(input_data + i);
6326 const uint16x8_t input_u16 = vmovl_u8(input_u8);
6327 const int16x8_t input_s16 = vreinterpretq_s16_u16(input_u16);
6328 const int16x4_t input_s16_low = vget_low_s16(input_s16);
6329 const int16x4_t input_s16_high = vget_high_s16(input_s16);
6330 const int32x4_t val_low = vmovl_s16(input_s16_low);
6331 const int32x4_t val_high = vmovl_s16(input_s16_high);
6332
6333 float32x4_t result_low, result_high;
6334 ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6335 &result_low);
6336 ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6337 &result_high);
6338
6339 vst1q_f32(output_data + i, result_low);
6340 vst1q_f32(output_data + i + 4, result_high);
6341 }
6342 #endif // NEON
6343 for (; i < flat_size; ++i) {
6344 const int32 val = input_data[i];
6345 const float result = static_cast<float>(scale * (val - zero_point));
6346 output_data[i] = result;
6347 }
6348 }
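
// Worked example for the scalar tail above (illustrative values only): with
// zero_point = 128 and scale = 0.5, the stored uint8 value 200 dequantizes to
// (200 - 128) * 0.5 = 36.0f, and the value 128 dequantizes to exactly 0.0f.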
6349
6350 inline void Dequantize(const tflite::DequantizationParams& op_params,
6351 const RuntimeShape& input_shape,
6352 const int8_t* input_data,
6353 const RuntimeShape& output_shape, float* output_data) {
6354 ruy::profiler::ScopeLabel label("Dequantize/Int8");
6355 const int32 zero_point = op_params.zero_point;
6356 const double scale = op_params.scale;
6357 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6358
6359 int i = 0;
6360 #ifdef USE_NEON
6361 const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6362 const float32x4_t zero_times_scale_dup =
6363 vdupq_n_f32(static_cast<float>(-zero_point * scale));
6364 for (; i <= flat_size - 8; i += 8) {
6365 const int8x8_t input_s8 = vld1_s8(input_data + i);
6366 const int16x8_t input_s16 = vmovl_s8(input_s8);
6367 const int16x4_t input_s16_low = vget_low_s16(input_s16);
6368 const int16x4_t input_s16_high = vget_high_s16(input_s16);
6369 const int32x4_t val_low = vmovl_s16(input_s16_low);
6370 const int32x4_t val_high = vmovl_s16(input_s16_high);
6371
6372 float32x4_t result_low, result_high;
6373 ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6374 &result_low);
6375 ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6376 &result_high);
6377
6378 vst1q_f32(output_data + i, result_low);
6379 vst1q_f32(output_data + i + 4, result_high);
6380 }
6381 #endif // NEON
6382 for (; i < flat_size; ++i) {
6383 const int32 val = input_data[i];
6384 const float result = static_cast<float>(scale * (val - zero_point));
6385 output_data[i] = result;
6386 }
6387 }
6388
6389 inline void Dequantize(const tflite::DequantizationParams& op_params,
6390 const RuntimeShape& input_shape,
6391 const int16_t* input_data,
6392 const RuntimeShape& output_shape, float* output_data) {
6393 ruy::profiler::ScopeLabel label("Dequantize/Int16");
6394 const int32 zero_point = op_params.zero_point;
6395 const double scale = op_params.scale;
6396 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6397
6398 int i = 0;
6399 #ifdef USE_NEON
6400 const float32x4_t scale_dup = vdupq_n_f32(static_cast<float>(scale));
6401 const float32x4_t zero_times_scale_dup =
6402 vdupq_n_f32(static_cast<float>(-zero_point * scale));
6403 for (; i <= flat_size - 8; i += 8) {
6404 const int16x4_t input_s16_low = vld1_s16(input_data + i);
6405 const int16x4_t input_s16_high = vld1_s16(input_data + i + 4);
6406 const int32x4_t val_low = vmovl_s16(input_s16_low);
6407 const int32x4_t val_high = vmovl_s16(input_s16_high);
6408
6409 float32x4_t result_low, result_high;
6410 ScaleWithNewZeroPoint(val_low, scale_dup, zero_times_scale_dup,
6411 &result_low);
6412 ScaleWithNewZeroPoint(val_high, scale_dup, zero_times_scale_dup,
6413 &result_high);
6414
6415 vst1q_f32(output_data + i, result_low);
6416 vst1q_f32(output_data + i + 4, result_high);
6417 }
6418 #endif // NEON
6419 for (; i < flat_size; ++i) {
6420 const int32 val = input_data[i];
6421 const float result = static_cast<float>(scale * (val - zero_point));
6422 output_data[i] = result;
6423 }
6424 }
6425
6426 inline void Dequantize(const RuntimeShape& input_shape,
6427 const Eigen::half* input_data,
6428 const RuntimeShape& output_shape, float* output_data) {
6429 reference_ops::Dequantize(input_shape, input_data, output_shape, output_data);
6430 }
6431
6432 template <typename T>
6433 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6434 const RuntimeShape& input_shape,
6435 const float* input_data,
6436 const RuntimeShape& output_shape, T* output_data) {
6437 reference_ops::AffineQuantize(op_params, input_shape, input_data,
6438 output_shape, output_data);
6439 }
6440
6441 template <>
6442 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6443 const RuntimeShape& input_shape,
6444 const float* input_data,
6445 const RuntimeShape& output_shape,
6446 int8_t* output_data) {
6447 ruy::profiler::ScopeLabel label("Quantize/Int8");
6448 const int32 zero_point = op_params.zero_point;
6449 const double scale = static_cast<double>(op_params.scale);
6450 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6451 static constexpr int32 min_val = std::numeric_limits<int8_t>::min();
6452 static constexpr int32 max_val = std::numeric_limits<int8_t>::max();
6453
6454 int i = 0;
6455 #ifdef USE_NEON
6456 const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6457 const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6458 const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6459 const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6460
6461 for (; i <= flat_size - 8; i += 8) {
6462 const float* src_data_ptr = input_data + i;
6463 float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6464 float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6465
6466 input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6467 input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6468
6469 int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6470 int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6471
6472 casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6473 casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6474
6475 // Clamp the values to fit the target type's range.
6476 casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6477 casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6478 casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6479 casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6480
6481 const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
6482 const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
6483 const int16x8_t combined_val = vcombine_s16(narrowed_val_0, narrowed_val_1);
6484 const int8x8_t combined_val_narrowed = vmovn_s16(combined_val);
6485 vst1_s8(output_data + i, combined_val_narrowed);
6486 }
6487 #endif // NEON
6488
6489 for (; i < flat_size; ++i) {
6490 const float val = input_data[i];
6491 const int32 unclamped =
6492 static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6493 const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6494 output_data[i] = clamped;
6495 }
6496 }
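
// Worked example for the scalar tail above (illustrative values only): with
// scale = 0.5 and zero_point = -1, the input 2.3f maps to
// round(2.3 / 0.5) + (-1) = 5 - 1 = 4, while 100.0f maps to 200 - 1 = 199 and
// is then clamped to the int8 maximum of 127.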
6497
6498 template <>
6499 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6500 const RuntimeShape& input_shape,
6501 const float* input_data,
6502 const RuntimeShape& output_shape,
6503 uint8_t* output_data) {
6504 ruy::profiler::ScopeLabel label("Quantize/Uint8");
6505 const int32 zero_point = op_params.zero_point;
6506 const double scale = static_cast<double>(op_params.scale);
6507 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6508 static constexpr int32 min_val = std::numeric_limits<uint8_t>::min();
6509 static constexpr int32 max_val = std::numeric_limits<uint8_t>::max();
6510
6511 int i = 0;
6512 #ifdef USE_NEON
6513 const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6514 const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6515 const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6516 const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6517
6518 for (; i <= flat_size - 8; i += 8) {
6519 const float* src_data_ptr = input_data + i;
6520 float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6521 float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6522
6523 input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6524 input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6525
6526 int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6527 int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6528
6529 casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6530 casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6531
6532 // Clamp the values to fit the target type's range.
6533 casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6534 casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6535 casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6536 casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6537
6538 const uint16x4_t narrowed_val_0 = vqmovun_s32(casted_val_0);
6539 const uint16x4_t narrowed_val_1 = vqmovun_s32(casted_val_1);
6540 const uint16x8_t combined_val =
6541 vcombine_u16(narrowed_val_0, narrowed_val_1);
6542 const uint8x8_t combined_val_narrowed = vmovn_u16(combined_val);
6543 vst1_u8(output_data + i, combined_val_narrowed);
6544 }
6545 #endif // NEON
6546
6547 for (; i < flat_size; ++i) {
6548 const float val = input_data[i];
6549 const int32 unclamped =
6550 static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6551 const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6552 output_data[i] = clamped;
6553 }
6554 }
6555
6556 template <>
6557 inline void AffineQuantize(const tflite::QuantizationParams& op_params,
6558 const RuntimeShape& input_shape,
6559 const float* input_data,
6560 const RuntimeShape& output_shape,
6561 int16_t* output_data) {
6562 ruy::profiler::ScopeLabel label("Quantize/Int16");
6563 const int32 zero_point = op_params.zero_point;
6564 const double scale = static_cast<double>(op_params.scale);
6565 const int flat_size = MatchingFlatSize(input_shape, output_shape);
6566 static constexpr int32 min_val = std::numeric_limits<int16_t>::min();
6567 static constexpr int32 max_val = std::numeric_limits<int16_t>::max();
6568
6569 int i = 0;
6570 #ifdef USE_NEON
6571 const float32x4_t reverse_scale_dup = vdupq_n_f32(1.0f / scale);
6572 const int32x4_t zero_point_dup = vdupq_n_s32(zero_point);
6573 const int32x4_t min_val_dup = vdupq_n_s32(min_val);
6574 const int32x4_t max_val_dup = vdupq_n_s32(max_val);
6575
6576 for (; i <= flat_size - 8; i += 8) {
6577 const float* src_data_ptr = input_data + i;
6578 float32x4_t input_val_0 = vld1q_f32(src_data_ptr);
6579 float32x4_t input_val_1 = vld1q_f32(src_data_ptr + 4);
6580
6581 input_val_0 = vmulq_f32(input_val_0, reverse_scale_dup);
6582 input_val_1 = vmulq_f32(input_val_1, reverse_scale_dup);
6583
6584 int32x4_t casted_val_0 = RoundToNearest(input_val_0);
6585 int32x4_t casted_val_1 = RoundToNearest(input_val_1);
6586
6587 casted_val_0 = vaddq_s32(casted_val_0, zero_point_dup);
6588 casted_val_1 = vaddq_s32(casted_val_1, zero_point_dup);
6589
6590 // Clamp the values to fit the target type's range.
6591 casted_val_0 = vmaxq_s32(casted_val_0, min_val_dup);
6592 casted_val_1 = vmaxq_s32(casted_val_1, min_val_dup);
6593 casted_val_0 = vminq_s32(casted_val_0, max_val_dup);
6594 casted_val_1 = vminq_s32(casted_val_1, max_val_dup);
6595
6596 const int16x4_t narrowed_val_0 = vmovn_s32(casted_val_0);
6597 const int16x4_t narrowed_val_1 = vmovn_s32(casted_val_1);
6598 vst1_s16(output_data + i, narrowed_val_0);
6599 vst1_s16(output_data + i + 4, narrowed_val_1);
6600 }
6601 #endif // NEON
6602
6603 for (; i < flat_size; ++i) {
6604 const float val = input_data[i];
6605 const int32 unclamped =
6606 static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
6607 const int32 clamped = std::min(std::max(unclamped, min_val), max_val);
6608 output_data[i] = clamped;
6609 }
6610 }
6611
6612 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6613 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6614 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6615 #ifdef GEMMLOWP_NEON
6616
6617 inline int16x8x4_t SaturatingRounding(
6618 int16x8_t input_val_0, int16x8_t input_val_1, int16x8_t input_val_2,
6619 int16x8_t input_val_3, int input_left_shift, int input_multiplier) {
6620 // This performs what is expressed in the scalar code as
6621 // const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6622 // static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6623 // static_cast<int16>(input_multiplier));
6624 const int16x8_t left_shift_dup = vdupq_n_s16(input_left_shift);
6625 const int16x8_t input_val_shifted_0 = vshlq_s16(input_val_0, left_shift_dup);
6626 const int16x8_t input_val_shifted_1 = vshlq_s16(input_val_1, left_shift_dup);
6627 const int16x8_t input_val_shifted_2 = vshlq_s16(input_val_2, left_shift_dup);
6628 const int16x8_t input_val_shifted_3 = vshlq_s16(input_val_3, left_shift_dup);
6629 int16x8x4_t result;
6630 result.val[0] = vqrdmulhq_n_s16(input_val_shifted_0, input_multiplier);
6631 result.val[1] = vqrdmulhq_n_s16(input_val_shifted_1, input_multiplier);
6632 result.val[2] = vqrdmulhq_n_s16(input_val_shifted_2, input_multiplier);
6633 result.val[3] = vqrdmulhq_n_s16(input_val_shifted_3, input_multiplier);
6634 return result;
6635 }
6636
6637 // 4-bit fixed point is enough for logistic since logistic(16) agrees with
6638 // one to about 7 decimal places.
6639 inline int16x8x4_t FixedPoint4Logistic(int16x8x4_t input_val) {
6640 // Invoke gemmlowp::logistic on FixedPoint wrapping int16x8_t
6641 using FixedPoint4 = gemmlowp::FixedPoint<int16x8_t, 4>;
6642 using FixedPoint0 = gemmlowp::FixedPoint<int16x8_t, 0>;
6643 const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val.val[0]);
6644 const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val.val[1]);
6645 const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val.val[2]);
6646 const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val.val[3]);
6647
6648 // TODO(b/134622898) Implement a low accuracy version of logistic. In this
6649 // method, gemmlowp::logistic accounts for about 80% of the execution time.
6650 // The current implementation is roughly 12-bit accurate in the 16-bit
6651 // fixed point case, so there is still room for improvement before those
6652 // error bounds are reached.
6653 const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
6654 const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
6655 const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
6656 const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
6657
6658 // Divide by 2^7 as in the scalar code
6659 int16x8x4_t result;
6660 result.val[0] = vrshrq_n_s16(output_val_f0_0.raw(), 7);
6661 result.val[1] = vrshrq_n_s16(output_val_f0_1.raw(), 7);
6662 result.val[2] = vrshrq_n_s16(output_val_f0_2.raw(), 7);
6663 result.val[3] = vrshrq_n_s16(output_val_f0_3.raw(), 7);
6664 return result;
6665 }
6666
6667 // 4-bit fixed point is enough for tanh since tanh(16) agrees with one to
6668 // at least 11 decimal places.
6669 inline int16x8x4_t FixedPoint4Tanh(int16x8x4_t input_val) {
6670 // Invoke gemmlowp::tanh on FixedPoint wrapping int16x8_t
6671 using FixedPoint4 = gemmlowp::FixedPoint<int16x8_t, 4>;
6672 using FixedPoint0 = gemmlowp::FixedPoint<int16x8_t, 0>;
6673 const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val.val[0]);
6674 const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val.val[1]);
6675 const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val.val[2]);
6676 const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val.val[3]);
6677
6678 // TODO(b/134622898) Implement a low accuracy version of tanh. In this
6679 // method, gemmlowp::tanh accounts for about 80% of the execution time. The
6680 // current implementation is roughly 12-bit accurate in the 16-bit fixed
6681 // point case, so there is still room for improvement before those error
6682 // bounds are reached.
6683 const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
6684 const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
6685 const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
6686 const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
6687
6688 // Divide by 2^8 as in the scalar code
6689 int16x8x4_t result;
6690 result.val[0] = vrshrq_n_s16(output_val_f0_0.raw(), 8);
6691 result.val[1] = vrshrq_n_s16(output_val_f0_1.raw(), 8);
6692 result.val[2] = vrshrq_n_s16(output_val_f0_2.raw(), 8);
6693 result.val[3] = vrshrq_n_s16(output_val_f0_3.raw(), 8);
6694 return result;
6695 }
6696
6697 inline uint8x16x2_t CalculateUnsignedClampingWithRangeBitMasks(
6698 int16x8x2_t input_val, int16x8_t range_radius_dup,
6699 int16x8_t neg_range_radius_dup) {
6700 const uint16x8_t mask_rightclamp_0 =
6701 vcgtq_s16(input_val.val[0], range_radius_dup);
6702 const uint16x8_t mask_rightclamp_1 =
6703 vcgtq_s16(input_val.val[1], range_radius_dup);
6704
6705 const uint16x8_t mask_leftclamp_0 =
6706 vcgeq_s16(input_val.val[0], neg_range_radius_dup);
6707 const uint16x8_t mask_leftclamp_1 =
6708 vcgeq_s16(input_val.val[1], neg_range_radius_dup);
6709
6710 uint8x16x2_t result;
6711 result.val[0] = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
6712 vshrn_n_u16(mask_leftclamp_1, 8));
6713 result.val[1] = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
6714 vshrn_n_u16(mask_rightclamp_1, 8));
6715 return result;
6716 }
6717
6718 inline uint8x16x2_t CalculateSignedClampingWithRangeBitMasks(
6719 int16x8x2_t input_val, int16x8_t range_radius_dup,
6720 int16x8_t neg_range_radius_dup) {
6721 const uint16x8_t mask_rightclamp_0 =
6722 vcgtq_s16(input_val.val[0], range_radius_dup);
6723 const uint16x8_t mask_rightclamp_1 =
6724 vcgtq_s16(input_val.val[1], range_radius_dup);
6725
6726 const uint16x8_t mask_leftclamp_0 =
6727 vcltq_s16(input_val.val[0], neg_range_radius_dup);
6728 const uint16x8_t mask_leftclamp_1 =
6729 vcltq_s16(input_val.val[1], neg_range_radius_dup);
6730
6731 uint8x16x2_t result;
6732 result.val[0] = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
6733 vshrn_n_u16(mask_leftclamp_1, 8));
6734 result.val[1] = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
6735 vshrn_n_u16(mask_rightclamp_1, 8));
6736 return result;
6737 }
6738
6739 inline void ClampWithRangeAndStore(uint8_t* output_dst, uint8x16_t input_val,
6740 uint8x16x2_t masks_clamp) {
6741 // Store back to memory
6742 vst1q_u8(output_dst, vandq_u8(vorrq_u8(input_val, masks_clamp.val[1]),
6743 masks_clamp.val[0]));
6744 }
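
// How the unsigned clamp works: CalculateUnsignedClampingWithRangeBitMasks
// sets masks_clamp.val[0] to 0xFF in lanes where the centered input is
// >= -input_range_radius and masks_clamp.val[1] to 0xFF in lanes where it is
// > input_range_radius. OR-ing with val[1] forces saturated-high lanes to 255
// and AND-ing with val[0] forces saturated-low lanes to 0, mirroring the
// branches of the scalar fallback code.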
6745
6746 inline void ClampWithRangeAndStore(int8_t* output_dst, int8x16_t input_val,
6747 uint8x16x2_t masks_clamp) {
6748 static const int8x16_t max_dup = vdupq_n_s8(127);
6749 static const int8x16_t min_dup = vdupq_n_s8(-128);
6750 // Store back to memory
6751 vst1q_s8(output_dst,
6752 vbslq_s8(masks_clamp.val[1], max_dup,
6753 vbslq_s8(masks_clamp.val[0], min_dup, input_val)));
6754 }
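
// The signed variant uses bit-selects instead: lanes flagged in
// masks_clamp.val[0] (centered input below -input_range_radius) are replaced
// with -128, and lanes flagged in masks_clamp.val[1] (above
// input_range_radius) with 127, again mirroring the scalar branches.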
6755
6756 #endif // GEMMLOWP_NEON
6757
6758 inline void Tanh16bitPrecision(const TanhParams& params,
6759 const RuntimeShape& input_shape,
6760 const uint8* input_data,
6761 const RuntimeShape& output_shape,
6762 uint8* output_data) {
6763 // Note that this is almost the exact same code as in Logistic().
6764 ruy::profiler::ScopeLabel label("Tanh/Uint8");
6765 const int32 input_zero_point = params.input_zero_point;
6766 const int32 input_range_radius = params.input_range_radius;
6767 const int16 input_multiplier = static_cast<int16>(params.input_multiplier);
6768 const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
6769 const int size = MatchingFlatSize(input_shape, output_shape);
6770
6771 int c = 0;
6772 int16_t output_zero_point = 128;
6773
6774 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6775 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6776 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6777 #ifdef GEMMLOWP_NEON
6778 const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
6779 const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
6780 const int16x8_t output_zero_point_s16 = vdupq_n_s16(output_zero_point);
6781
6782 // Handle 32 values at a time
6783 for (; c <= size - 32; c += 32) {
6784 // Read input uint8 values, cast to int16 and subtract input_zero_point
6785 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6786 const int16x8x2_t input_val_centered_0_1 =
6787 Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
6788 const int16x8x2_t input_val_centered_2_3 =
6789 Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
6790
6791 // Prepare the bit masks that we will use at the end to implement the logic
6792 // that was expressed in the scalar code with branching:
6793 // if (input_val_centered < -input_range_radius) {
6794 // output_val = 0;
6795 // } else if (input_val_centered > input_range_radius) {
6796 // output_val = 255;
6797 // } else {
6798 // ...
6799 uint8x16x2_t masks_clamp_0_1 = CalculateUnsignedClampingWithRangeBitMasks(
6800 input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
6801 uint8x16x2_t masks_clamp_2_3 = CalculateUnsignedClampingWithRangeBitMasks(
6802 input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
6803
6804 int16x8x4_t input_val_rescaled = SaturatingRounding(
6805 input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
6806 input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
6807 input_left_shift, input_multiplier);
6808
6809 int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled);
6810
6811 // Add the output zero point
6812 output_val_s16.val[0] =
6813 vaddq_s16(output_val_s16.val[0], output_zero_point_s16);
6814 output_val_s16.val[1] =
6815 vaddq_s16(output_val_s16.val[1], output_zero_point_s16);
6816 output_val_s16.val[2] =
6817 vaddq_s16(output_val_s16.val[2], output_zero_point_s16);
6818 output_val_s16.val[3] =
6819 vaddq_s16(output_val_s16.val[3], output_zero_point_s16);
6820
6821 // Cast output values to uint8, saturating
6822 uint8x16_t output_val_u8_0_1 = vcombine_u8(
6823 vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1]));
6824 uint8x16_t output_val_u8_2_3 = vcombine_u8(
6825 vqmovun_s16(output_val_s16.val[2]), vqmovun_s16(output_val_s16.val[3]));
6826
6827 ClampWithRangeAndStore(output_data + c, output_val_u8_0_1, masks_clamp_0_1);
6828 ClampWithRangeAndStore(output_data + c + 16, output_val_u8_2_3,
6829 masks_clamp_2_3);
6830 }
6831 #endif // GEMMLOWP_NEON
6832 // Leftover loop: handle one value at a time with scalar code.
6833 for (; c < size; ++c) {
6834 const uint8 input_val_u8 = input_data[c];
6835 const int16 input_val_centered =
6836 static_cast<int16>(input_val_u8) - input_zero_point;
6837 uint8 output_val;
6838 if (input_val_centered < -input_range_radius) {
6839 output_val = 0;
6840 } else if (input_val_centered > input_range_radius) {
6841 output_val = 255;
6842 } else {
6843 using gemmlowp::SaturatingRoundingDoublingHighMul;
6844 const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6845 static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6846 static_cast<int16>(input_multiplier));
6847 using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
6848 using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
6849 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
6850 const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
6851 using gemmlowp::RoundingDivideByPOT;
6852 int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8);
6853 output_val_s16 += output_zero_point;
6854 if (output_val_s16 == 256) {
6855 output_val_s16 = 255;
6856 }
6857 TFLITE_DCHECK_GE(output_val_s16, 0);
6858 TFLITE_DCHECK_LE(output_val_s16, 255);
6859 output_val = static_cast<uint8>(output_val_s16);
6860 }
6861 output_data[c] = output_val;
6862 }
6863 }
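
// In the uint8 variant above, the Q0.15 tanh output (the range [-1, 1) mapped
// onto [-32768, 32767]) is rounding-divided by 2^8 to land in [-128, 128] and
// then offset by output_zero_point = 128; the single overflow case 256 is
// clamped back to 255 in the scalar path, while the NEON path relies on the
// saturating cast to uint8.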
6864
6865 inline void Tanh16bitPrecision(const TanhParams& params,
6866 const RuntimeShape& input_shape,
6867 const int8* input_data,
6868 const RuntimeShape& output_shape,
6869 int8* output_data) {
6870 // Note that this is almost the exact same code as in Logistic().
6871 ruy::profiler::ScopeLabel label("Tanh/Int8");
6872 const int32 input_zero_point = params.input_zero_point;
6873 const int32 input_range_radius = params.input_range_radius;
6874 const int16 input_multiplier = static_cast<int16>(params.input_multiplier);
6875 const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
6876 const int size = MatchingFlatSize(input_shape, output_shape);
6877
6878 int c = 0;
6879 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6880 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6881 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6882 #ifdef GEMMLOWP_NEON
6883 const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
6884 const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
6885
6886 // Handle 32 values at a time
6887 for (; c <= size - 32; c += 32) {
6888 // Read input int8 values, cast to int16 and subtract input_zero_point
6889 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6890 const int16x8x2_t input_val_centered_0_1 =
6891 Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
6892 const int16x8x2_t input_val_centered_2_3 =
6893 Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
6894
6895 // Prepare the bit masks that we will use at the end to implement the logic
6896 // that was expressed in the scalar code with branching:
6897 // if (input_val_centered < -input_range_radius) {
6898 // output_val = -128;
6899 // } else if (input_val_centered > input_range_radius) {
6900 // output_val = 127;
6901 // } else {
6902 // ...
6903 uint8x16x2_t masks_clamp_0_1 = CalculateSignedClampingWithRangeBitMasks(
6904 input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
6905 uint8x16x2_t masks_clamp_2_3 = CalculateSignedClampingWithRangeBitMasks(
6906 input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
6907
6908 int16x8x4_t input_val_rescaled = SaturatingRounding(
6909 input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
6910 input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
6911 input_left_shift, input_multiplier);
6912
6913 int16x8x4_t output_val_s16 = FixedPoint4Tanh(input_val_rescaled);
6914
6915 // Cast output values to int8, saturating
6916 int8x16_t output_val_s8_0_1 = vcombine_s8(
6917 vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1]));
6918 int8x16_t output_val_s8_2_3 = vcombine_s8(
6919 vqmovn_s16(output_val_s16.val[2]), vqmovn_s16(output_val_s16.val[3]));
6920
6921 ClampWithRangeAndStore(output_data + c, output_val_s8_0_1, masks_clamp_0_1);
6922 ClampWithRangeAndStore(output_data + c + 16, output_val_s8_2_3,
6923 masks_clamp_2_3);
6924 }
6925 #endif // GEMMLOWP_NEON
6926 // Leftover loop: handle one value at a time with scalar code.
6927 for (; c < size; ++c) {
6928 const int8 input_val_s8 = input_data[c];
6929 const int16 input_val_centered =
6930 static_cast<int16>(input_val_s8) - input_zero_point;
6931 int8 output_val;
6932 if (input_val_centered <= -input_range_radius) {
6933 output_val = -128;
6934 } else if (input_val_centered >= input_range_radius) {
6935 output_val = 127;
6936 } else {
6937 using gemmlowp::SaturatingRoundingDoublingHighMul;
6938 const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
6939 static_cast<int16>(input_val_centered * (1 << input_left_shift)),
6940 static_cast<int16>(input_multiplier));
6941 using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
6942 using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
6943 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
6944 const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
6945 using gemmlowp::RoundingDivideByPOT;
6946 int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 8);
6947 if (output_val_s16 == 128) {
6948 output_val_s16 = 127;
6949 }
6950 TFLITE_DCHECK_GE(output_val_s16, -128);
6951 TFLITE_DCHECK_LE(output_val_s16, 127);
6952 output_val = static_cast<int8>(output_val_s16);
6953 }
6954 output_data[c] = output_val;
6955 }
6956 }
6957
6958 inline void Logistic16bitPrecision(const LogisticParams& params,
6959 const RuntimeShape& input_shape,
6960 const uint8* input_data,
6961 const RuntimeShape& output_shape,
6962 uint8* output_data) {
6963 ruy::profiler::ScopeLabel label("Logistic/Uint8");
6964 const int32 input_zero_point = params.input_zero_point;
6965 const int32 input_range_radius = params.input_range_radius;
6966 const int32 input_multiplier = params.input_multiplier;
6967 const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
6968 const int size = MatchingFlatSize(input_shape, output_shape);
6969
6970 int c = 0;
6971 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
6972 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
6973 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
6974 #ifdef GEMMLOWP_NEON
6975 const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
6976 const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
6977
6978 // Handle 32 values at a time
6979 for (; c <= size - 32; c += 32) {
6980 // Read input uint8 values, cast to int16 and subtract input_zero_point
6981 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
6982 const int16x8x2_t input_val_centered_0_1 =
6983 Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
6984 const int16x8x2_t input_val_centered_2_3 =
6985 Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
6986
6987 // Prepare the bit masks that we will use at the end to implement the logic
6988 // that was expressed in the scalar code with branching:
6989 // if (input_val_centered < -input_range_radius) {
6990 // output_val = 0;
6991 // } else if (input_val_centered > input_range_radius) {
6992 // output_val = 255;
6993 // } else {
6994 // ...
6995 uint8x16x2_t masks_clamp_0_1 = CalculateUnsignedClampingWithRangeBitMasks(
6996 input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
6997 uint8x16x2_t masks_clamp_2_3 = CalculateUnsignedClampingWithRangeBitMasks(
6998 input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
6999
7000 int16x8x4_t input_val_rescaled = SaturatingRounding(
7001 input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7002 input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7003 input_left_shift, input_multiplier);
7004
7005 int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled);
7006
7007 // Cast output values to uint8, saturating
7008 uint8x16_t output_val_u8_0_1 = vcombine_u8(
7009 vqmovun_s16(output_val_s16.val[0]), vqmovun_s16(output_val_s16.val[1]));
7010 uint8x16_t output_val_u8_2_3 = vcombine_u8(
7011 vqmovun_s16(output_val_s16.val[2]), vqmovun_s16(output_val_s16.val[3]));
7012
7013 ClampWithRangeAndStore(output_data + c, output_val_u8_0_1, masks_clamp_0_1);
7014 ClampWithRangeAndStore(output_data + c + 16, output_val_u8_2_3,
7015 masks_clamp_2_3);
7016 }
7017 #endif // GEMMLOWP_NEON
7018 // Leftover loop: handle one value at a time with scalar code.
7019 for (; c < size; ++c) {
7020 const uint8 input_val_u8 = input_data[c];
7021 const int16 input_val_centered =
7022 static_cast<int16>(input_val_u8) - input_zero_point;
7023 uint8 output_val;
7024 if (input_val_centered < -input_range_radius) {
7025 output_val = 0;
7026 } else if (input_val_centered > input_range_radius) {
7027 output_val = 255;
7028 } else {
7029 using gemmlowp::SaturatingRoundingDoublingHighMul;
7030 const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7031 static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7032 static_cast<int16>(input_multiplier));
7033 using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7034 using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7035 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7036 const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
7037 using gemmlowp::RoundingDivideByPOT;
7038 int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7);
7039 if (output_val_s16 == 256) {
7040 output_val_s16 = 255;
7041 }
7042 TFLITE_DCHECK_GE(output_val_s16, 0);
7043 TFLITE_DCHECK_LE(output_val_s16, 255);
7044 output_val = static_cast<uint8>(output_val_s16);
7045 }
7046 output_data[c] = output_val;
7047 }
7048 }
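
// For the uint8 logistic above, gemmlowp::logistic returns a non-negative
// Q0.15 value in [0, 1); rounding-dividing it by 2^7 yields an integer in
// [0, 256], with the boundary case 256 folded back to 255. This corresponds
// to the usual uint8 logistic output encoding of scale 1/256 with
// zero_point 0 (stated here as a summary, not taken from this file).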
7049
7050 inline void Logistic16bitPrecision(const LogisticParams& params,
7051 const RuntimeShape& input_shape,
7052 const int8* input_data,
7053 const RuntimeShape& output_shape,
7054 int8* output_data) {
7055 ruy::profiler::ScopeLabel label("Logistic/Int8");
7056 const int32 input_zero_point = params.input_zero_point;
7057 const int32 input_range_radius = params.input_range_radius;
7058 const int32 input_multiplier = params.input_multiplier;
7059 const int16 input_left_shift = static_cast<int16>(params.input_left_shift);
7060 const int size = MatchingFlatSize(input_shape, output_shape);
7061
7062 int c = 0;
7063 const int16 output_zero_point = 128;
7064 // TODO(b/139252020): Replace GEMMLOWP_NEON with USE_NEON when the bug is fixed.
7065 // The converted versions of gemmlowp::tanh and gemmlowp::logistic, done by
7066 // arm_sse_2_neon.h, produce incorrect results with int16x8_t data types.
7067 #ifdef GEMMLOWP_NEON
7068 const int16x8_t range_radius_dup = vdupq_n_s16(input_range_radius);
7069 const int16x8_t neg_range_radius_dup = vdupq_n_s16(-input_range_radius);
7070 const int16x8_t output_zero_point_dup = vdupq_n_s16(output_zero_point);
7071
7072 // Handle 32 values at a time
7073 for (; c <= size - 32; c += 32) {
7074 // Read input int8 values, cast to int16 and subtract input_zero_point
7075 using cpu_backend_gemm::detail::Load16AndSubtractZeroPoint;
7076 const int16x8x2_t input_val_centered_0_1 =
7077 Load16AndSubtractZeroPoint(input_data + c, input_zero_point);
7078 const int16x8x2_t input_val_centered_2_3 =
7079 Load16AndSubtractZeroPoint(input_data + c + 16, input_zero_point);
7080
7081 // Prepare the bit masks that we will use at the end to implement the logic
7082 // that was expressed in the scalar code with branching:
7083 // if (input_val_centered < -input_range_radius) {
7084 // output_val = -128;
7085 // } else if (input_val_centered > input_range_radius) {
7086 // output_val = 127;
7087 // } else {
7088 // ...
7089 uint8x16x2_t masks_clamp_0_1 = CalculateSignedClampingWithRangeBitMasks(
7090 input_val_centered_0_1, range_radius_dup, neg_range_radius_dup);
7091 uint8x16x2_t masks_clamp_2_3 = CalculateSignedClampingWithRangeBitMasks(
7092 input_val_centered_2_3, range_radius_dup, neg_range_radius_dup);
7093
7094 int16x8x4_t input_val_rescaled = SaturatingRounding(
7095 input_val_centered_0_1.val[0], input_val_centered_0_1.val[1],
7096 input_val_centered_2_3.val[0], input_val_centered_2_3.val[1],
7097 input_left_shift, input_multiplier);
7098
7099 int16x8x4_t output_val_s16 = FixedPoint4Logistic(input_val_rescaled);
7100
7101 // Subtract the output zero point.
7102 output_val_s16.val[0] =
7103 vsubq_s16(output_val_s16.val[0], output_zero_point_dup);
7104 output_val_s16.val[1] =
7105 vsubq_s16(output_val_s16.val[1], output_zero_point_dup);
7106 output_val_s16.val[2] =
7107 vsubq_s16(output_val_s16.val[2], output_zero_point_dup);
7108 output_val_s16.val[3] =
7109 vsubq_s16(output_val_s16.val[3], output_zero_point_dup);
7110
7111 // Cast output values to int8, saturating
7112 int8x16_t output_val_s8_0_1 = vcombine_s8(
7113 vqmovn_s16(output_val_s16.val[0]), vqmovn_s16(output_val_s16.val[1]));
7114 int8x16_t output_val_s8_2_3 = vcombine_s8(
7115 vqmovn_s16(output_val_s16.val[2]), vqmovn_s16(output_val_s16.val[3]));
7116
7117 ClampWithRangeAndStore(output_data + c, output_val_s8_0_1, masks_clamp_0_1);
7118 ClampWithRangeAndStore(output_data + c + 16, output_val_s8_2_3,
7119 masks_clamp_2_3);
7120 }
7121 #endif // GEMMLOWP_NEON
7122 // Leftover loop: handle one value at a time with scalar code.
7123 for (; c < size; ++c) {
7124 const int8 input_val_s8 = input_data[c];
7125 const int16 input_val_centered =
7126 static_cast<int16>(input_val_s8) - input_zero_point;
7127 int8 output_val;
7128 if (input_val_centered < -input_range_radius) {
7129 output_val = -128;
7130 } else if (input_val_centered > input_range_radius) {
7131 output_val = 127;
7132 } else {
7133 using gemmlowp::SaturatingRoundingDoublingHighMul;
7134 const int16 input_val_rescaled = SaturatingRoundingDoublingHighMul(
7135 static_cast<int16>(input_val_centered * (1 << input_left_shift)),
7136 static_cast<int16>(input_multiplier));
7137 using FixedPoint4 = gemmlowp::FixedPoint<int16, 4>;
7138 using FixedPoint0 = gemmlowp::FixedPoint<int16, 0>;
7139 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
7140 const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
7141 using gemmlowp::RoundingDivideByPOT;
7142 int16 output_val_s16 = RoundingDivideByPOT(output_val_f0.raw(), 7);
7143 output_val_s16 -= output_zero_point;
7144 if (output_val_s16 == 128) {
7145 output_val_s16 = 127;
7146 }
7147 TFLITE_DCHECK_GE(output_val_s16, -128);
7148 TFLITE_DCHECK_LE(output_val_s16, 127);
7149 output_val = static_cast<int8>(output_val_s16);
7150 }
7151 output_data[c] = output_val;
7152 }
7153 }
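
// Example usage (illustrative sketch only; the quantization parameters below
// are placeholders, not values taken from a real model, and input_int8 /
// output_int8 are caller-provided int8 buffers of 16 elements):
//
//   LogisticParams op_params;
//   op_params.input_zero_point = 0;
//   op_params.input_range_radius = 127;
//   op_params.input_multiplier = 1 << 14;  // hypothetical rescale multiplier
//   op_params.input_left_shift = 0;
//   const RuntimeShape shape({1, 16});
//   Logistic16bitPrecision(op_params, shape, input_int8, shape, output_int8);
//
// The result is int8 data quantized with zero point -128 (the fixed
// output_zero_point of 128 is subtracted above).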
7154
7155 // Transpose2D handles only plain 2-D matrix transposes.
7156 // It transposes the input in 4x4 blocks, proceeding left to right across each
7157 // band of four rows, then from top to bottom.
7158 template <typename T>
7159 inline void Transpose2D(const RuntimeShape& input_shape, const T* input_data,
7160 const RuntimeShape& output_shape, T* output_data) {
7161 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
7162 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
7163
7164 const int d0 = input_shape.DimsData()[0];
7165 const int d1 = input_shape.DimsData()[1];
7166 const int kLines = 4;
7167 const int kSkipSize = (kLines - 1) * d1;
7168
7169 const T* input = input_data;
7170
7171 int i = 0;
7172 for (; i <= d0 - kLines; i += kLines) {
7173 T* output = output_data + i;
7174
7175 const T* input_ptr = input;
7176 optimized_ops_preload_l1_keep(input_ptr);
7177 input_ptr += d1;
7178 optimized_ops_preload_l1_keep(input_ptr);
7179 input_ptr += d1;
7180 optimized_ops_preload_l1_keep(input_ptr);
7181 input_ptr += d1;
7182 optimized_ops_preload_l1_keep(input_ptr);
7183
7184 int j = 0;
7185 for (; j <= d1 - kLines; j += kLines) {
7186 input_ptr = input;
7187 const T a00 = input_ptr[0];
7188 const T a01 = input_ptr[1];
7189 const T a02 = input_ptr[2];
7190 const T a03 = input_ptr[3];
7191 input_ptr += d1;
7192 const T a10 = input_ptr[0];
7193 const T a11 = input_ptr[1];
7194 const T a12 = input_ptr[2];
7195 const T a13 = input_ptr[3];
7196 input_ptr += d1;
7197 const T a20 = input_ptr[0];
7198 const T a21 = input_ptr[1];
7199 const T a22 = input_ptr[2];
7200 const T a23 = input_ptr[3];
7201 input_ptr += d1;
7202 const T a30 = input_ptr[0];
7203 const T a31 = input_ptr[1];
7204 const T a32 = input_ptr[2];
7205 const T a33 = input_ptr[3];
7206
7207 output[0] = a00;
7208 output[1] = a10;
7209 output[2] = a20;
7210 output[3] = a30;
7211 output += d0;
7212
7213 output[0] = a01;
7214 output[1] = a11;
7215 output[2] = a21;
7216 output[3] = a31;
7217 output += d0;
7218
7219 output[0] = a02;
7220 output[1] = a12;
7221 output[2] = a22;
7222 output[3] = a32;
7223 output += d0;
7224
7225 output[0] = a03;
7226 output[1] = a13;
7227 output[2] = a23;
7228 output[3] = a33;
7229 output += d0;
7230
7231 input += kLines;
7232 }
7233 if (j == d1) {
7234 input += kSkipSize;
7235 } else {
7236 for (int p = 0; p < kLines; ++p) {
7237 for (int q = 0; q < d1 - j; ++q) {
7238 *(output + q * d0 + p) = *(input + p * d1 + q);
7239 }
7240 }
7241 input += (d1 - j) + kSkipSize;
7242 }
7243 }
7244 for (; i < d0; ++i) {
7245 T* output = output_data + i;
7246 for (int j = 0; j < d1; ++j) {
7247 *output = *input;
7248 output += d0;
7249 ++input;
7250 }
7251 }
7252 }
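
// Example (sketch): transposing a row-major 2x3 float matrix into its 3x2
// transpose.
//
//   const float in[6] = {1, 2, 3, 4, 5, 6};
//   float out[6];
//   Transpose2D(RuntimeShape({2, 3}), in, RuntimeShape({3, 2}), out);
//   // out == {1, 4, 2, 5, 3, 6}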
7253
7254 template <>
7255 inline void Transpose2D(const RuntimeShape& input_shape,
7256 const int32_t* input_data,
7257 const RuntimeShape& output_shape,
7258 int32_t* output_data) {
7259 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
7260 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
7261
7262 const int d0 = input_shape.DimsData()[0];
7263 const int d1 = input_shape.DimsData()[1];
7264 #ifdef USE_NEON
7265 const int kLines = 4;
7266 const int kSkipSize = (kLines - 1) * d1;
7267 #endif
7268
7269 const int32_t* input = input_data;
7270
7271 int i = 0;
7272 #ifdef USE_NEON
7273 for (; i <= d0 - kLines; i += kLines) {
7274 int32_t* output = output_data + i;
7275
7276 const int32_t* input_ptr = input;
7277 optimized_ops_preload_l1_keep(input_ptr);
7278 input_ptr += d1;
7279 optimized_ops_preload_l1_keep(input_ptr);
7280 input_ptr += d1;
7281 optimized_ops_preload_l1_keep(input_ptr);
7282 input_ptr += d1;
7283 optimized_ops_preload_l1_keep(input_ptr);
7284
7285 int j = 0;
7286 for (; j <= d1 - kLines; j += kLines) {
7287 input_ptr = input;
7288 int32x4_t a0 = vld1q_s32(input);
7289 input_ptr += d1;
7290 int32x4_t a1 = vld1q_s32(input_ptr);
7291 input_ptr += d1;
7292 int32x4_t a2 = vld1q_s32(input_ptr);
7293 input_ptr += d1;
7294 int32x4_t a3 = vld1q_s32(input_ptr);
7295
7296 int32x4x2_t tmp1 = vuzpq_s32(a0, a2);
7297 int32x4x2_t tmp2 = vuzpq_s32(a1, a3);
7298 int32x4x2_t tmp3 = vtrnq_s32(tmp1.val[0], tmp2.val[0]);
7299 int32x4x2_t tmp4 = vtrnq_s32(tmp1.val[1], tmp2.val[1]);
7300
7301 vst1q_s32(output, tmp3.val[0]);
7302 output += d0;
7303 vst1q_s32(output, tmp4.val[0]);
7304 output += d0;
7305 vst1q_s32(output, tmp3.val[1]);
7306 output += d0;
7307 vst1q_s32(output, tmp4.val[1]);
7308 output += d0;
7309 input += kLines;
7310 }
7311 if (j == d1) {
7312 input += kSkipSize;
7313 } else {
7314 for (int p = 0; p < kLines; ++p) {
7315 for (int q = 0; q < d1 - j; ++q) {
7316 *(output + q * d0 + p) = *(input + p * d1 + q);
7317 }
7318 }
7319 input += (d1 - j) + kSkipSize;
7320 }
7321 }
7322 #endif
7323 for (; i < d0; ++i) {
7324 int32_t* output = output_data + i;
7325 for (int j = 0; j < d1; ++j) {
7326 *output = *input;
7327 output += d0;
7328 ++input;
7329 }
7330 }
7331 }
7332
7333 // TODO(b/173718660): see if we can reduce the number
7334 // of lines of code in branching without affecting latency.
7335 template <typename T>
7336 inline void Transpose3D(const TransposeParams& params,
7337 const RuntimeShape& input_shape, const T* input_data,
7338 const RuntimeShape& output_shape, T* output_data) {
7339 int s1, s2, s3;
7340 s1 = input_shape.Dims(0);
7341 s2 = input_shape.Dims(1);
7342 s3 = input_shape.Dims(2);
7343
7344 int p1, p2, p3;
7345 if (params.perm[0] == 2) {
7346 p1 = 1;
7347 } else if (params.perm[1] == 2) {
7348 p2 = 1;
7349 } else {
7350 p3 = 1;
7351 }
7352
7353 if (params.perm[0] == 1) {
7354 p1 = s3;
7355 } else if (params.perm[1] == 1) {
7356 p2 = s3;
7357 } else {
7358 p3 = s3;
7359 }
7360
7361 if (params.perm[0] == 0) {
7362 p1 = s2 * s3;
7363 } else if (params.perm[1] == 0) {
7364 p2 = s2 * s3;
7365 } else {
7366 p3 = s2 * s3;
7367 }
7368
7369 int o_s[3];
7370 o_s[0] = input_shape.Dims(params.perm[0]);
7371 o_s[1] = input_shape.Dims(params.perm[1]);
7372 o_s[2] = input_shape.Dims(params.perm[2]);
7373
7374 for (int i1 = 0; i1 < o_s[0]; ++i1) {
7375 for (int i2 = 0; i2 < o_s[1]; ++i2) {
7376 for (int i3 = 0; i3 < o_s[2]; ++i3) {
7377 const int i = i1 * p1 + i2 * p2 + i3 * p3;
7378 const int o = i1 * o_s[1] * o_s[2] + i2 * o_s[2] + i3;
7379 output_data[o] = input_data[i];
7380 }
7381 }
7382 }
7383 }
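
// For example, with input shape (s1, s2, s3) and params.perm = {2, 0, 1}, the
// branches above resolve to p1 = 1, p2 = s2 * s3 and p3 = s3, and the output
// shape is (s3, s1, s2). Output element (i1, i2, i3) then reads input flat
// index i1 + i2 * s2 * s3 + i3 * s3, i.e. input element (i2, i3, i1), which is
// the expected result for that permutation.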
7384
7385 template <typename T, int N>
7386 void TransposeImpl(const TransposeParams& params,
7387 const RuntimeShape& input_shape, const T* input_data,
7388 const RuntimeShape& output_shape, T* output_data) {
7389 const int dims_cnt = input_shape.DimensionsCount();
7390
7391 int dim0, dim1;
7392 if (transpose_utils::IsTranspose2DApplicable(params, input_shape, &dim0,
7393 &dim1)) {
7394 Transpose2D(RuntimeShape({dim0, dim1}), input_data,
7395 RuntimeShape({dim1, dim0}), output_data);
7396 return;
7397 }
7398
7399 // TODO(b/141217325): notably Eigen is better suited for
7400 // larger inputs whereas Transpose3D is generally
7401 // better for smaller ones.
7402 //
7403 // E.g. on Nexus 5, Eigen is better for size 96^3 and up
7404 // and Transpose3D is better for 72^3 and down.
7405 //
7406 // 96^3 is not mobile-friendly for certain use cases
7407 // (e.g. model used in beam search for seq2seq) but is in others.
7408 // Consider tradeoffs.
7409 if (dims_cnt == 3) {
7410 Transpose3D(params, input_shape, input_data, output_shape, output_data);
7411 return;
7412 }
7413
7414 // Reroute to the reference version if an optimized method for the given data
7415 // is not available.
7416 reference_ops::Transpose<T, N>(params, input_shape, input_data, output_shape,
7417 output_data);
7418 }
7419
7420 template <typename T, int N = 5>
7421 void Transpose(const TransposeParams& unshrinked_params,
7422 const RuntimeShape& unshrinked_input_shape, const T* input_data,
7423 const RuntimeShape& unshrinked_output_shape, T* output_data) {
7424 ruy::profiler::ScopeLabel label("Transpose");
7425
7426 const int output_size = unshrinked_output_shape.DimensionsCount();
7427 TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), N);
7428 TFLITE_DCHECK_LE(output_size, N);
7429 TFLITE_DCHECK_EQ(output_size, unshrinked_params.perm_count);
7430
7431 RuntimeShape shrinked_input_shape = RuntimeShape(unshrinked_input_shape);
7432 RuntimeShape shrinked_output_shape = RuntimeShape(unshrinked_output_shape);
7433 TransposeParams shrinked_params = unshrinked_params;
7434
7435 // Remove dimensions of size one. A lower-rank transpose usually performs
7436 // better since memory access patterns are improved.
7437 transpose_utils::RemoveOneSizeDimensions(
7438 &shrinked_input_shape, &shrinked_output_shape, &shrinked_params);
7439
7440 // Handle identity cases.
7441 // TODO(b/140779653): Add an optimization pass in the conversion process to
7442 // remove no-op transpose nodes like the one handled below.
7443 bool identical = true;
7444 for (int i = 0; i < shrinked_params.perm_count; ++i) {
7445 if (shrinked_params.perm[i] != i) {
7446 identical = false;
7447 break;
7448 }
7449 }
7450 if (identical) {
7451 memcpy(output_data, input_data,
7452 unshrinked_input_shape.FlatSize() * sizeof(T));
7453 return;
7454 }
7455
7456 // Reduce dimensions by flattening.
7457 if (shrinked_params.perm[0] == 0 && output_size >= 3) {
7458 RuntimeShape non_flatten_input_shape;
7459 RuntimeShape non_flatten_output_shape;
7460 TransposeParams non_flatten_params;
7461 const int total_size = shrinked_input_shape.FlatSize();
7462 const int non_flatten_size = transpose_utils::Flatten(
7463 shrinked_input_shape, shrinked_output_shape, shrinked_params,
7464 &non_flatten_input_shape, &non_flatten_output_shape,
7465 &non_flatten_params);
7466 TFLITE_DCHECK_NE(non_flatten_params.perm[0], 0);
7467
7468 for (int i = 0; i < total_size; i += non_flatten_size) {
7469 TransposeImpl<T, N>(non_flatten_params, non_flatten_input_shape,
7470 input_data + i, non_flatten_output_shape,
7471 output_data + i);
7472 }
7473 return;
7474 }
7475
7476 // Call non-flattened case.
7477 TransposeImpl<T, N>(shrinked_params, shrinked_input_shape, input_data,
7478 shrinked_output_shape, output_data);
7479 }
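
// Example usage (sketch): permuting a 4-D activation from NHWC to NCHW.
// input_ptr / output_ptr are caller-provided float buffers of 384 elements.
//
//   TransposeParams tparams;
//   tparams.perm_count = 4;
//   tparams.perm[0] = 0;  // N
//   tparams.perm[1] = 3;  // C
//   tparams.perm[2] = 1;  // H
//   tparams.perm[3] = 2;  // W
//   Transpose<float>(tparams, RuntimeShape({2, 8, 8, 3}), input_ptr,
//                    RuntimeShape({2, 3, 8, 8}), output_ptr);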
7480
7481 // Assume input1 & input2 have the same scale & zero point.
7482 inline void MaximumElementwise(int size, const ArithmeticParams& params,
7483 const int8* input1_data, const int8* input2_data,
7484 int8* output_data) {
7485 ruy::profiler::ScopeLabel label("MaximumElementwiseInt8/8bit");
7486 int i = 0;
7487 #ifdef USE_NEON
7488 for (; i <= size - 16; i += 16) {
7489 const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
7490 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7491 const int8x16_t max_data =
7492 vmaxq_s8(input1_val_original, input2_val_original);
7493 vst1q_s8(output_data + i, max_data);
7494 }
7495 #endif // USE_NEON
7496 for (; i < size; ++i) {
7497 const int8 input1_val = input1_data[i];
7498 const int8 input2_val = input2_data[i];
7499 output_data[i] = std::max(input1_val, input2_val);
7500 }
7501 }
7502
7503 inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params,
7504 int8 input1_data, const int8* input2_data,
7505 int8* output_data) {
7506 ruy::profiler::ScopeLabel label("MaximumScalarBroadcastInt8/8bit");
7507 int i = 0;
7508
7509 #ifdef USE_NEON
7510 const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
7511 for (; i <= size - 16; i += 16) {
7512 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7513 const int8x16_t max_data =
7514 vmaxq_s8(input1_val_original, input2_val_original);
7515 vst1q_s8(output_data + i, max_data);
7516 }
7517 #endif // USE_NEON
7518 for (; i < size; ++i) {
7519 const int8 input2_val = input2_data[i];
7520 output_data[i] = std::max(input1_data, input2_val);
7521 }
7522 }
7523
7524 // Assume input1 & input2 have the same scale & zero point.
7525 inline void MinimumElementwise(int size, const ArithmeticParams& params,
7526 const int8* input1_data, const int8* input2_data,
7527 int8* output_data) {
7528 ruy::profiler::ScopeLabel label("MinimumElementwiseInt8/8bit");
7529 int i = 0;
7530 #ifdef USE_NEON
7531 for (; i <= size - 16; i += 16) {
7532 const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
7533 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7534 const int8x16_t min_data =
7535 vminq_s8(input1_val_original, input2_val_original);
7536 vst1q_s8(output_data + i, min_data);
7537 }
7538 #endif // USE_NEON
7539 for (; i < size; ++i) {
7540 const int8 input1_val = input1_data[i];
7541 const int8 input2_val = input2_data[i];
7542 output_data[i] = std::min(input1_val, input2_val);
7543 }
7544 }
7545
7546 inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params,
7547 int8 input1_data, const int8* input2_data,
7548 int8* output_data) {
7549 ruy::profiler::ScopeLabel label("MinimumScalarBroadcastInt8/8bit");
7550 int i = 0;
7551
7552 #ifdef USE_NEON
7553 const int8x16_t input1_val_original = vdupq_n_s8(input1_data);
7554 for (; i <= size - 16; i += 16) {
7555 const int8x16_t input2_val_original = vld1q_s8(input2_data + i);
7556 const int8x16_t min_data =
7557 vminq_s8(input1_val_original, input2_val_original);
7558 vst1q_s8(output_data + i, min_data);
7559 }
7560 #endif // USE_NEON
7561 for (; i < size; ++i) {
7562 const int8 input2_val = input2_data[i];
7563 output_data[i] = std::min(input1_data, input2_val);
7564 }
7565 }
7566
7567 template <typename Op>
7568 inline void BroadcastMaximumDispatch(const ArithmeticParams& params,
7569 const RuntimeShape& input1_shape,
7570 const int8* input1_data,
7571 const RuntimeShape& input2_shape,
7572 const int8* input2_data,
7573 const RuntimeShape& output_shape,
7574 int8* output_data, Op op) {
7575 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
7576 return reference_ops::MaximumMinimumBroadcastSlow(
7577 input1_shape, input1_data, input2_shape, input2_data, output_shape,
7578 output_data, op);
7579 }
7580
7581 BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
7582 input2_data, output_shape, output_data,
7583 MaximumElementwise, MaximumScalarBroadcast);
7584 }
7585
7586 template <typename Op>
7587 inline void BroadcastMinimumDispatch(const ArithmeticParams& params,
7588 const RuntimeShape& input1_shape,
7589 const int8* input1_data,
7590 const RuntimeShape& input2_shape,
7591 const int8* input2_data,
7592 const RuntimeShape& output_shape,
7593 int8* output_data, Op op) {
7594 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
7595 return reference_ops::MaximumMinimumBroadcastSlow(
7596 input1_shape, input1_data, input2_shape, input2_data, output_shape,
7597 output_data, op);
7598 }
7599
7600 BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape,
7601 input2_data, output_shape, output_data,
7602 MinimumElementwise, MinimumScalarBroadcast);
7603 }
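
// Example usage (sketch): int8 broadcast maximum against a per-tensor scalar;
// BroadcastMinimumDispatch is used analogously. `op_params` is assumed to
// already carry the broadcast category and shape information prepared by the
// caller; the op argument is only used on the generic-broadcast fallback path
// and can be any callable taking and returning int8.
//
//   auto max_op = [](int8 a, int8 b) { return std::max(a, b); };
//   BroadcastMaximumDispatch(op_params, RuntimeShape({1, 4, 4, 3}), input1,
//                            RuntimeShape({1, 1, 1, 1}), input2,
//                            RuntimeShape({1, 4, 4, 3}), output, max_op);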
7604
7605 template <typename T>
7606 void CumsumImpl(const T* input_data, const RuntimeShape& shape, int axis,
7607 bool exclusive, bool reverse, T* output_data) {
7608 Eigen::array<Eigen::DenseIndex, 3> dims = {1, 1, 1};
7609
7610 for (int i = 0; i < axis; ++i) {
7611 dims[0] *= shape.Dims(i);
7612 }
7613 dims[1] = shape.Dims(axis);
7614 for (int i = axis + 1; i < shape.DimensionsCount(); ++i) {
7615 dims[2] *= shape.Dims(i);
7616 }
7617
7618 typedef Eigen::TensorMap<
7619 Eigen::Tensor<const T, 3, Eigen::RowMajor, Eigen::DenseIndex>,
7620 Eigen::Aligned>
7621 ConstTensor;
7622 typedef Eigen::TensorMap<
7623 Eigen::Tensor<T, 3, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>
7624 Tensor;
7625 ConstTensor input(input_data, dims);
7626 Tensor output(output_data, dims);
7627
7628 if (reverse) {
7629 Eigen::array<bool, 3> reverse_idx = {false, true, false};
7630 output =
7631 input.reverse(reverse_idx).cumsum(1, exclusive).reverse(reverse_idx);
7632 } else {
7633 output = input.cumsum(1, exclusive);
7634 }
7635 }
7636
7637 template <typename T>
7638 void CumSum(const T* input_data, const RuntimeShape& shape, int axis,
7639 bool exclusive, bool reverse, T* output_data) {
7640 const int dim = shape.DimensionsCount();
7641 TFLITE_DCHECK_GE(dim, 1);
7642 CumsumImpl<T>(input_data, shape, axis, exclusive, reverse, output_data);
7643 }
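
// Example (sketch): inclusive cumulative sum along the last axis of a 2x3
// float tensor.
//
//   const float in[6] = {1, 2, 3, 4, 5, 6};
//   float out[6];
//   CumSum(in, RuntimeShape({2, 3}), /*axis=*/1, /*exclusive=*/false,
//          /*reverse=*/false, out);
//   // out == {1, 3, 6, 4, 9, 15}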
7644
7645 inline void PReluScalarBroadcast(int size, const ArithmeticParams& params,
7646 float alpha, const float* input_data,
7647 float* output_data) {
7648 ruy::profiler::ScopeLabel label("PreluScalarBroadcast/float");
7649 int i = 0;
7650
7651 #ifdef USE_NEON
7652 const float32x4_t zero_dup = vdupq_n_f32(0.0f);
7653 const float32x4_t alpha_dup = vdupq_n_f32(alpha);
7654 for (; i <= size - 16; i += 16) {
7655 const float32x4_t input1 = vld1q_f32(input_data + i);
7656 const float32x4_t input2 = vld1q_f32(input_data + i + 4);
7657 const float32x4_t input3 = vld1q_f32(input_data + i + 8);
7658 const float32x4_t input4 = vld1q_f32(input_data + i + 12);
7659
7660 const float32x4_t temp1 = vmulq_f32(input1, alpha_dup);
7661 const float32x4_t temp2 = vmulq_f32(input2, alpha_dup);
7662 const float32x4_t temp3 = vmulq_f32(input3, alpha_dup);
7663 const float32x4_t temp4 = vmulq_f32(input4, alpha_dup);
7664
7665 const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
7666 const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
7667 const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
7668 const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);
7669
7670 const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
7671 vst1q_f32(output_data + i, result1);
7672 const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
7673 vst1q_f32(output_data + i + 4, result2);
7674 const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
7675 vst1q_f32(output_data + i + 8, result3);
7676 const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
7677 vst1q_f32(output_data + i + 12, result4);
7678 }
7679
7680 for (; i <= size - 4; i += 4) {
7681 const float32x4_t input = vld1q_f32(input_data + i);
7682 const float32x4_t temp = vmulq_f32(input, alpha_dup);
7683 const uint32x4_t mask = vcgeq_f32(input, zero_dup);
7684 const float32x4_t result = vbslq_f32(mask, input, temp);
7685 vst1q_f32(output_data + i, result);
7686 }
7687 #endif // USE_NEON
7688 for (; i < size; ++i) {
7689 const float input = input_data[i];
7690 output_data[i] = input >= 0.f ? input : input * alpha;
7691 }
7692 }
7693
7694 inline void PReluElementWise(int flat_size, const ArithmeticParams& params,
7695 const float* alpha_data, const float* input_data,
7696 float* output_data) {
7697 ruy::profiler::ScopeLabel label("PreluElementWise/float");
7698
7699 int i = 0;
7700 #ifdef USE_NEON
7701 const float32x4_t zero_dup = vdupq_n_f32(0.0f);
7702 for (; i <= flat_size - 16; i += 16) {
7703 const float32x4_t input1 = vld1q_f32(input_data + i);
7704 const float32x4_t alpha1 = vld1q_f32(alpha_data + i);
7705 const float32x4_t input2 = vld1q_f32(input_data + i + 4);
7706 const float32x4_t alpha2 = vld1q_f32(alpha_data + i + 4);
7707 const float32x4_t input3 = vld1q_f32(input_data + i + 8);
7708 const float32x4_t alpha3 = vld1q_f32(alpha_data + i + 8);
7709 const float32x4_t input4 = vld1q_f32(input_data + i + 12);
7710 const float32x4_t alpha4 = vld1q_f32(alpha_data + i + 12);
7711
7712 const float32x4_t temp1 = vmulq_f32(input1, alpha1);
7713 const float32x4_t temp2 = vmulq_f32(input2, alpha2);
7714 const float32x4_t temp3 = vmulq_f32(input3, alpha3);
7715 const float32x4_t temp4 = vmulq_f32(input4, alpha4);
7716
7717 const uint32x4_t mask1 = vcgeq_f32(input1, zero_dup);
7718 const uint32x4_t mask2 = vcgeq_f32(input2, zero_dup);
7719 const uint32x4_t mask3 = vcgeq_f32(input3, zero_dup);
7720 const uint32x4_t mask4 = vcgeq_f32(input4, zero_dup);
7721
7722 const float32x4_t result1 = vbslq_f32(mask1, input1, temp1);
7723 vst1q_f32(output_data + i, result1);
7724 const float32x4_t result2 = vbslq_f32(mask2, input2, temp2);
7725 vst1q_f32(output_data + i + 4, result2);
7726 const float32x4_t result3 = vbslq_f32(mask3, input3, temp3);
7727 vst1q_f32(output_data + i + 8, result3);
7728 const float32x4_t result4 = vbslq_f32(mask4, input4, temp4);
7729 vst1q_f32(output_data + i + 12, result4);
7730 }
7731
7732 for (; i <= flat_size - 4; i += 4) {
7733 const float32x4_t input = vld1q_f32(input_data + i);
7734 const float32x4_t alpha = vld1q_f32(alpha_data + i);
7735
7736 const float32x4_t temp = vmulq_f32(input, alpha);
7737 const uint32x4_t mask = vcgeq_f32(input, zero_dup);
7738 const float32x4_t result = vbslq_f32(mask, input, temp);
7739 vst1q_f32(output_data + i, result);
7740 }
7741 #endif // USE_NEON
7742 for (; i < flat_size; ++i) {
7743 const float input = input_data[i];
7744 const float alpha = alpha_data[i];
7745 output_data[i] = input >= 0.f ? input : input * alpha;
7746 }
7747 }
7748
7749 inline void BroadcastPReluDispatch(
7750 const ArithmeticParams& params, const RuntimeShape& input_shape,
7751 const float* input_data, const RuntimeShape& alpha_shape,
7752 const float* alpha_data, const RuntimeShape& output_shape,
7753 float* output_data, float (*func)(float, float)) {
7754 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
7755 return reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
7756 input_shape, input_data, alpha_shape, alpha_data, output_shape,
7757 output_data, func);
7758 }
7759
7760 BinaryBroadcastFiveFold(params, input_shape, input_data, alpha_shape,
7761 alpha_data, output_shape, output_data,
7762 PReluElementWise, PReluScalarBroadcast);
7763 }
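
// Example usage (sketch): float PRelu with a per-channel alpha broadcast over
// the spatial dimensions. `op_params` is assumed to carry the broadcast shape
// information prepared by the caller; the function pointer is only used on
// the generic-broadcast fallback path (the unary + converts the capture-less
// lambda to a plain function pointer).
//
//   auto prelu = [](float x, float alpha) {
//     return x >= 0.f ? x : x * alpha;
//   };
//   BroadcastPReluDispatch(op_params, RuntimeShape({1, 8, 8, 16}), input_ptr,
//                          RuntimeShape({1, 1, 1, 16}), alpha_ptr,
//                          RuntimeShape({1, 8, 8, 16}), output_ptr, +prelu);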
7764
7765 // Returns the index with minimum value within `input_data`.
7766 // If there is a tie, returns the smaller index.
7767 template <typename T>
7768 inline int ArgMinVector(const T* input_data, int size) {
7769 T min_value = input_data[0];
7770 int min_index = 0;
7771 for (int i = 1; i < size; ++i) {
7772 const T curr_value = input_data[i];
7773 if (curr_value < min_value) {
7774 min_value = curr_value;
7775 min_index = i;
7776 }
7777 }
7778 return min_index;
7779 }
7780
7781 // Returns the index with maximum value within `input_data`.
7782 // If there is a tie, returns the smaller index.
7783 template <typename T>
7784 inline int ArgMaxVector(const T* input_data, int size) {
7785 T max_value = input_data[0];
7786 int max_index = 0;
7787 for (int i = 1; i < size; ++i) {
7788 const T curr_value = input_data[i];
7789 if (curr_value > max_value) {
7790 max_value = curr_value;
7791 max_index = i;
7792 }
7793 }
7794 return max_index;
7795 }
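
// Example: for input {0.f, 3.f, 3.f, 1.f} and size 4, ArgMaxVector returns 1
// and ArgMinVector returns 0 (ties resolve to the smaller index).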
7796
7797 template <>
7798 inline int ArgMinVector(const float* input_data, int size) {
7799 int32_t min_index = 0;
7800 float min_value = input_data[0];
7801 int32_t i = 1;
7802 #ifdef USE_NEON
7803 if (size >= 4) {
7804 float32x4_t min_value_f32x4 = vld1q_f32(input_data);
7805 const int32_t index_init[4] = {0, 1, 2, 3};
7806 int32x4_t min_index_s32x4 = vld1q_s32(index_init);
7807 int32x4_t index_s32x4 = min_index_s32x4;
7808 int32x4_t inc = vdupq_n_s32(4);
7809 for (i = 4; i <= size - 4; i += 4) {
7810 // Increase indices by 4.
7811 index_s32x4 = vaddq_s32(index_s32x4, inc);
7812 float32x4_t v = vld1q_f32(&input_data[i]);
7813 uint32x4_t mask = vcltq_f32(v, min_value_f32x4);
7814 min_value_f32x4 = vminq_f32(min_value_f32x4, v);
7815 min_index_s32x4 = vbslq_s32(mask, index_s32x4, min_index_s32x4);
7816 }
7817 // Find min element within float32x4_t.
7818 #ifdef __aarch64__
7819 min_value = vminvq_f32(min_value_f32x4);
7820 #else
7821 float32x2_t min_value_f32x2 = vpmin_f32(vget_low_f32(min_value_f32x4),
7822 vget_high_f32(min_value_f32x4));
7823 min_value_f32x2 = vpmin_f32(min_value_f32x2, min_value_f32x2);
7824 min_value = vget_lane_f32(min_value_f32x2, 0);
7825 #endif // __aarch64__
7826 // Mask indices of non-min values with max int32_t.
7827 float32x4_t fill_min_value_f32x4 = vdupq_n_f32(min_value);
7828 uint32x4_t mask = vceqq_f32(min_value_f32x4, fill_min_value_f32x4);
7829 int32x4_t all_set = vdupq_n_s32(std::numeric_limits<int>::max());
7830 min_index_s32x4 = vbslq_s32(mask, min_index_s32x4, all_set);
7831 // Find min index of min values.
7832 #ifdef __aarch64__
7833 min_index = vminvq_s32(min_index_s32x4);
7834 #else
7835 int32x2_t min_index_s32x2 = vpmin_s32(vget_low_s32(min_index_s32x4),
7836 vget_high_s32(min_index_s32x4));
7837 min_index_s32x2 = vpmin_s32(min_index_s32x2, min_index_s32x2);
7838 min_index = vget_lane_s32(min_index_s32x2, 0);
7839 #endif // __aarch64__
7840 }
7841 #endif // USE_NEON
7842 // Leftover loop.
7843 for (; i < size; ++i) {
7844 const float curr_value = input_data[i];
7845 if (curr_value < min_value) {
7846 min_value = curr_value;
7847 min_index = i;
7848 }
7849 }
7850 return min_index;
7851 }
7852
7853 template <>
7854 inline int ArgMaxVector(const float* input_data, int size) {
7855 int32_t max_index = 0;
7856 float max_value = input_data[0];
7857 int32_t i = 1;
7858 #ifdef USE_NEON
7859 if (size >= 4) {
7860 float32x4_t max_value_f32x4 = vld1q_f32(input_data);
7861 const int32_t index_init[4] = {0, 1, 2, 3};
7862 int32x4_t max_index_s32x4 = vld1q_s32(index_init);
7863 int32x4_t index_s32x4 = max_index_s32x4;
7864 int32x4_t inc = vdupq_n_s32(4);
7865 for (i = 4; i <= size - 4; i += 4) {
7866 // Increase indices by 4.
7867 index_s32x4 = vaddq_s32(index_s32x4, inc);
7868 float32x4_t v = vld1q_f32(&input_data[i]);
7869 uint32x4_t mask = vcgtq_f32(v, max_value_f32x4);
7870 max_value_f32x4 = vmaxq_f32(max_value_f32x4, v);
7871 max_index_s32x4 = vbslq_s32(mask, index_s32x4, max_index_s32x4);
7872 }
7873 // Find max element within float32x4_t.
7874 #ifdef __aarch64__
7875 max_value = vmaxvq_f32(max_value_f32x4);
7876 #else
7877 float32x2_t max_value_f32x2 = vpmax_f32(vget_low_f32(max_value_f32x4),
7878 vget_high_f32(max_value_f32x4));
7879 max_value_f32x2 = vpmax_f32(max_value_f32x2, max_value_f32x2);
7880 max_value = vget_lane_f32(max_value_f32x2, 0);
7881 #endif // __aarch64__
7882 // Mask indices of non-max values with max int32_t.
7883 float32x4_t fill_max_value_f32x4 = vdupq_n_f32(max_value);
7884 uint32x4_t mask = vceqq_f32(max_value_f32x4, fill_max_value_f32x4);
7885 int32x4_t all_set = vdupq_n_s32(std::numeric_limits<int>::max());
7886 max_index_s32x4 = vbslq_s32(mask, max_index_s32x4, all_set);
7887 // Find min index of max values.
7888 #ifdef __aarch64__
7889 max_index = vminvq_s32(max_index_s32x4);
7890 #else
7891 int32x2_t max_index_s32x2 = vpmin_s32(vget_low_s32(max_index_s32x4),
7892 vget_high_s32(max_index_s32x4));
7893 max_index_s32x2 = vpmin_s32(max_index_s32x2, max_index_s32x2);
7894 max_index = vget_lane_s32(max_index_s32x2, 0);
7895 #endif // __aarch64__
7896 }
7897 #endif // USE_NEON
7898 // Leftover loop.
7899 for (; i < size; ++i) {
7900 const float curr_value = input_data[i];
7901 if (curr_value > max_value) {
7902 max_value = curr_value;
7903 max_index = i;
7904 }
7905 }
7906 return max_index;
7907 }
7908
7909 template <>
7910 inline int ArgMaxVector(const int8_t* input_data, int size) {
7911 int32_t max_index = 0;
7912 int8_t max_value = input_data[0];
7913 int32_t i = 0;
7914 #ifdef USE_NEON
7915 constexpr int VECTOR_SIZE = 16;
7916 if (size >= VECTOR_SIZE) {
7917 int8x16_t max_value_s8x16;
7918 for (; i <= size - VECTOR_SIZE; i += VECTOR_SIZE) {
7919 max_value_s8x16 = vld1q_s8(input_data + i);
7920 int8_t max_from_vec;
7921 #ifdef __aarch64__
7922 max_from_vec = vmaxvq_s8(max_value_s8x16);
7923 #else // 32 bit
7924 int8x8_t max_val_s8x8 =
7925 vpmax_s8(vget_low_s8(max_value_s8x16), vget_high_s8(max_value_s8x16));
7926 max_val_s8x8 = vpmax_s8(max_val_s8x8, max_val_s8x8);
7927 max_val_s8x8 = vpmax_s8(max_val_s8x8, max_val_s8x8);
7928 max_val_s8x8 = vpmax_s8(max_val_s8x8, max_val_s8x8);
7929 max_from_vec = vget_lane_s8(max_val_s8x8, 0);
7930 #endif // __aarch64__
7931 if (max_from_vec > max_value) {
7932 max_value = max_from_vec;
7933 max_index = i;
7934 }
7935 }
7936 }
7937 for (int start_idx = max_index; start_idx < max_index + VECTOR_SIZE;
7938 start_idx++) {
7939 if (input_data[start_idx] == max_value) {
7940 max_index = start_idx;
7941 break;
7942 }
7943 }
7944
7945 #endif // USE_NEON
7946 // Leftover loop.
7947 for (; i < size; ++i) {
7948 const int8_t curr_value = input_data[i];
7949 if (curr_value > max_value) {
7950 max_value = curr_value;
7951 max_index = i;
7952 }
7953 }
7954
7955 return max_index;
7956 }
7957
7958 template <>
7959 inline int ArgMaxVector(const uint8_t* input_data, int size) {
7960 int32_t max_index = 0;
7961 uint8_t max_value = input_data[0];
7962 int32_t i = 0;
7963 #ifdef USE_NEON
7964 constexpr int VECTOR_SIZE = 16;
7965 if (size >= VECTOR_SIZE) {
7966 uint8x16_t max_value_u8x16;
7967 for (; i <= size - VECTOR_SIZE; i += VECTOR_SIZE) {
7968 max_value_u8x16 = vld1q_u8(input_data + i);
7969 uint8_t max_from_vec;
7970 #ifdef __aarch64__
7971 max_from_vec = vmaxvq_u8(max_value_u8x16);
7972 #else // 32 bit
7973 uint8x8_t max_val_u8x8 =
7974 vpmax_u8(vget_low_u8(max_value_u8x16), vget_high_u8(max_value_u8x16));
7975 max_val_u8x8 = vpmax_u8(max_val_u8x8, max_val_u8x8);
7976 max_val_u8x8 = vpmax_u8(max_val_u8x8, max_val_u8x8);
7977 max_val_u8x8 = vpmax_u8(max_val_u8x8, max_val_u8x8);
7978 max_from_vec = vget_lane_u8(max_val_u8x8, 0);
7979 #endif // __aarch64__
7980 if (max_from_vec > max_value) {
7981 max_value = max_from_vec;
7982 max_index = i;
7983 }
7984 }
7985 }
7986 for (int start_idx = max_index; start_idx < max_index + VECTOR_SIZE;
7987 start_idx++) {
7988 if (input_data[start_idx] == max_value) {
7989 max_index = start_idx;
7990 break;
7991 }
7992 }
7993
7994 #endif // USE_NEON
7995 // Leftover loop.
7996 for (; i < size; ++i) {
7997 const uint8_t curr_value = input_data[i];
7998 if (curr_value > max_value) {
7999 max_value = curr_value;
8000 max_index = i;
8001 }
8002 }
8003
8004 return max_index;
8005 }
8006
8007 // Specializes ArgMinMax function with axis=dims-1.
8008 // In this case, ArgMinMax reduction is applied on contiguous memory.
8009 template <typename T1, typename T2, bool is_arg_max>
8010 inline void ArgMinMaxLastAxis(const RuntimeShape& input_shape,
8011 const T1* input_data,
8012 const RuntimeShape& output_shape,
8013 T2* output_data) {
8014 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
8015 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 1);
8016 TFLITE_DCHECK_EQ(input_shape.Dims(0), output_shape.Dims(0));
8017
8018 int outer_size = input_shape.Dims(0);
8019 int axis_size = input_shape.Dims(1);
8020 for (int outer = 0; outer < outer_size; ++outer) {
8021 if (is_arg_max) {
8022 output_data[outer] = static_cast<T2>(
8023 ArgMaxVector<T1>(input_data + outer * axis_size, axis_size));
8024 } else {
8025 output_data[outer] = static_cast<T2>(
8026 ArgMinVector<T1>(input_data + outer * axis_size, axis_size));
8027 }
8028 }
8029 }
8030
8031 template <typename T1, typename T2, typename T3>
8032 inline void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
8033 const T3* input2_data, const RuntimeShape& output_shape,
8034 T2* output_data, const bool is_arg_max) {
8035 ruy::profiler::ScopeLabel label("ArgMinMax");
8036
8037 TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0);
8038 TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1,
8039 output_shape.DimensionsCount());
8040 int axis = input2_data[0];
8041 if (axis < 0) {
8042 axis += input1_shape.DimensionsCount();
8043 }
8044 const int axis_size = input1_shape.Dims(axis);
8045
8046 int outer_size = 1;
8047 for (int i = 0; i < axis; ++i) {
8048 TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i));
8049 outer_size *= input1_shape.Dims(i);
8050 }
8051
8052 int inner_size = 1;
8053 const int dims_count = input1_shape.DimensionsCount();
8054 for (int i = axis + 1; i < dims_count; ++i) {
8055 TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1));
8056 inner_size *= input1_shape.Dims(i);
8057 }
8058
8059 // Call the specialized function when axis = dims - 1. Only float32, int8 and
8060 // uint8 have optimized vector kernels, so reroute only for those types.
8061 if (inner_size == 1 &&
8062 (std::is_same<T1, float>::value || std::is_same<T1, int8_t>::value ||
8063 std::is_same<T1, uint8_t>::value)) {
8064 if (is_arg_max) {
8065 ArgMinMaxLastAxis<T1, T2, /*is_arg_max=*/true>(
8066 {outer_size, axis_size}, input1_data, {outer_size}, output_data);
8067 } else {
8068 ArgMinMaxLastAxis<T1, T2, /*is_arg_max=*/false>(
8069 {outer_size, axis_size}, input1_data, {outer_size}, output_data);
8070 }
8071 return;
8072 }
8073
8074 reference_ops::ArgMinMax(input1_shape, input1_data, input2_data, output_shape,
8075 output_data, is_arg_max);
8076 }
8077
8078 template <typename T1, typename T2, typename T3>
8079 void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
8080 const T3* input2_data, const RuntimeShape& output_shape,
8081 T2* output_data) {
8082 ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
8083 /*is_arg_max=*/true);
8084 }
8085
8086 // Convenience version that allows, for example, generated-code calls to be
8087 // the same as other binary ops.
8088 // For backward compatibility, reference_ops also provides an ArgMax function.
8089 template <typename T1, typename T2, typename T3>
8090 inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
8091 const RuntimeShape& input2_shape, const T3* input2_data,
8092 const RuntimeShape& output_shape, T2* output_data) {
8093 // Drop shape of second input: not needed.
8094 ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
8095 }
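
// Example (sketch): ArgMax over the last axis of a 2x4 float tensor, writing
// int32 indices.
//
//   const float in[8] = {0.1f, 0.9f, 0.3f, 0.2f,
//                        0.5f, 0.4f, 0.8f, 0.7f};
//   const int32_t axis = -1;  // last dimension
//   int32_t out[2];
//   ArgMax(RuntimeShape({2, 4}), in, &axis, RuntimeShape({2}), out);
//   // out == {1, 2}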
8096
8097 inline void Conv3D(const Conv3DParams& params, const RuntimeShape& input_shape,
8098 const float* input_data, const RuntimeShape& filter_shape,
8099 const float* filter_data, const RuntimeShape& bias_shape,
8100 const float* bias_data, const RuntimeShape& output_shape,
8101 float* output_data, const RuntimeShape& im2col_shape,
8102 float* im2col_data,
8103 const RuntimeShape& transposed_filter_shape,
8104 float* transposed_filter_data,
8105 CpuBackendContext* cpu_backend_context) {
8106 const int stride_depth = params.stride_depth;
8107 const int stride_height = params.stride_height;
8108 const int stride_width = params.stride_width;
8109 const int dilation_depth_factor = params.dilation_depth;
8110 const int dilation_height_factor = params.dilation_height;
8111 const int dilation_width_factor = params.dilation_width;
8112 const float output_activation_min = params.float_activation_min;
8113 const float output_activation_max = params.float_activation_max;
8114 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
8115 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 5);
8116 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 5);
8117
8118 ruy::profiler::ScopeLabel label("Conv3D");
8119
8120 // NB: the float 0.0f value is represented by all zero bytes.
8121 const uint8 float_zero_byte = 0x00;
8122 const float* gemm_input_data = nullptr;
8123 const RuntimeShape* gemm_input_shape = nullptr;
8124 const int filter_width = filter_shape.Dims(2);
8125 const int filter_height = filter_shape.Dims(1);
8126 const int filter_depth = filter_shape.Dims(0);
8127 const bool need_dilated_im2col = dilation_width_factor != 1 ||
8128 dilation_height_factor != 1 ||
8129 dilation_depth_factor != 1;
8130 const bool need_im2col = stride_depth != 1 || stride_height != 1 ||
8131 stride_width != 1 || filter_depth != 1 ||
8132 filter_height != 1 || filter_width != 1;
8133
8134 if (need_dilated_im2col) {
8135 DilatedIm2col3D(params, filter_depth, filter_height, filter_width,
8136 float_zero_byte, input_shape, input_data, im2col_shape,
8137 im2col_data);
8138 gemm_input_data = im2col_data;
8139 gemm_input_shape = &im2col_shape;
8140 } else if (need_im2col) {
8141 TFLITE_DCHECK(im2col_data);
8142 Im2col3D(params, filter_depth, filter_height, filter_width, float_zero_byte,
8143 input_shape, input_data, im2col_shape, im2col_data);
8144 gemm_input_data = im2col_data;
8145 gemm_input_shape = &im2col_shape;
8146 } else {
8147 TFLITE_DCHECK(!im2col_data);
8148 gemm_input_data = input_data;
8149 gemm_input_shape = &input_shape;
8150 }
8151
8152 // Transpose the filter tensor.
8153 TransposeParams transpose_params;
8154 transpose_params.perm_count = 5;
8155 transpose_params.perm[0] = 4;
8156 transpose_params.perm[1] = 0;
8157 transpose_params.perm[2] = 1;
8158 transpose_params.perm[3] = 2;
8159 transpose_params.perm[4] = 3;
8160 Transpose<float, 5>(transpose_params, filter_shape, filter_data,
8161 transposed_filter_shape, transposed_filter_data);
8162
8163 const int gemm_input_dims = gemm_input_shape->DimensionsCount();
8164 int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
8165 int n = output_shape.Dims(4);
8166 int k = gemm_input_shape->Dims(gemm_input_dims - 1);
8167
8168 cpu_backend_gemm::MatrixParams<float> lhs_params;
8169 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
8170 lhs_params.rows = n;
8171 lhs_params.cols = k;
8172 cpu_backend_gemm::MatrixParams<float> rhs_params;
8173 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
8174 rhs_params.rows = k;
8175 rhs_params.cols = m;
8176 cpu_backend_gemm::MatrixParams<float> dst_params;
8177 dst_params.order = cpu_backend_gemm::Order::kColMajor;
8178 dst_params.rows = n;
8179 dst_params.cols = m;
8180 cpu_backend_gemm::GemmParams<float, float> gemm_params;
8181 gemm_params.bias = bias_data;
8182 gemm_params.clamp_min = output_activation_min;
8183 gemm_params.clamp_max = output_activation_max;
8184 cpu_backend_gemm::Gemm(lhs_params, transposed_filter_data, rhs_params,
8185 gemm_input_data, dst_params, output_data, gemm_params,
8186 cpu_backend_context);
8187 }
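
// Example usage (sketch): a 1x1x1 filter with unit strides and dilations takes
// the no-im2col path, so only the transposed-filter scratch buffer is needed.
// The buffer names below (filter_scratch, context, ...) are placeholders for
// caller-managed allocations, not part of this header.
//
//   Conv3DParams cparams;
//   cparams.stride_depth = cparams.stride_height = cparams.stride_width = 1;
//   cparams.dilation_depth = cparams.dilation_height = 1;
//   cparams.dilation_width = 1;
//   cparams.float_activation_min = std::numeric_limits<float>::lowest();
//   cparams.float_activation_max = std::numeric_limits<float>::max();
//   Conv3D(cparams, input_shape, input_ptr, filter_shape, filter_ptr,
//          bias_shape, bias_ptr, output_shape, output_ptr,
//          RuntimeShape(), /*im2col_data=*/nullptr,
//          transposed_filter_shape, filter_scratch, context);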
8188
8189 // Writes into 'im_data' (assumed to be zero-initialized) the image in storage
8190 // order (planes, height, width, channel), reconstructed from patches in
8191 // 'col_data', which is required to be in storage order (out_planes * out_height
8192 // * out_width, filter_planes, filter_height, filter_width, in_channel).
8193 //
8194 // This function is copied from tensorflow/core/kernels/conv_grad_ops_3d.cc
8195 // authored by Eugene Zhulenev (ezhulenev).
8196 template <typename T>
8197 void Col2im(const T* col_data, const int channel, const int planes,
8198 const int height, const int width, const int filter_p,
8199 const int filter_h, const int filter_w, const int pad_pt,
8200 const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
8201 const int pad_r, const int stride_p, const int stride_h,
8202 const int stride_w, T* im_data) {
8203 const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
8204 const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
8205 const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
8206 int p_pad = -pad_pt;
8207 for (int p = 0; p < planes_col; ++p) {
8208 int h_pad = -pad_t;
8209 for (int h = 0; h < height_col; ++h) {
8210 int w_pad = -pad_l;
8211 for (int w = 0; w < width_col; ++w) {
8212 T* im_patch_data =
8213 im_data +
8214 (p_pad * height * width + h_pad * width + w_pad) * channel;
8215 for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
8216 for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
8217 for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
8218 if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
8219 iw < width) {
8220 for (int i = 0; i < channel; ++i) {
8221 im_patch_data[i] += col_data[i];
8222 }
8223 }
8224 im_patch_data += channel;
8225 col_data += channel;
8226 }
8227 // Jump over the remaining (width - filter_w) * channel values in this row.
8228 im_patch_data += channel * (width - filter_w);
8229 }
8230 // Jump over the remaining (height - filter_h) rows of (channel * width) values.
8231 im_patch_data += (channel * width) * (height - filter_h);
8232 }
8233 w_pad += stride_w;
8234 }
8235 h_pad += stride_h;
8236 }
8237 p_pad += stride_p;
8238 }
8239 }
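
// Worked size example: with planes = height = width = 4, a 2x2x2 filter, unit
// strides and no padding, planes_col = height_col = width_col = 3, so col_data
// holds 27 patches of 2 * 2 * 2 * channel values, and each patch is
// accumulated back into its 2x2x2 window of im_data.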
8240
8241 template <typename T>
8242 void BiasAdd3D(T* im_data, const T* bias_data, const RuntimeShape& input_shape,
8243 float float_activation_min, float float_activation_max) {
8244 if (bias_data) {
8245 const int outer_size = input_shape.Dims(0) * input_shape.Dims(1) *
8246 input_shape.Dims(2) * input_shape.Dims(3);
8247 const int num_channels = input_shape.Dims(4);
8248 for (int n = 0; n < outer_size; ++n) {
8249 for (int c = 0; c < num_channels; ++c) {
8250 im_data[c] = ActivationFunctionWithMinMax(im_data[c] + bias_data[c],
8251 float_activation_min,
8252 float_activation_max);
8253 }
8254 im_data += num_channels;
8255 }
8256 } else {
8257 const int flat_size = input_shape.FlatSize();
8258 for (int i = 0; i < flat_size; ++i) {
8259 im_data[i] = ActivationFunctionWithMinMax(
8260 im_data[i], float_activation_min, float_activation_max);
8261 }
8262 }
8263 }
8264
8265 inline void Conv3DTranspose(
8266 const Conv3DTransposeParams& params, const RuntimeShape& input_shape,
8267 const float* input_data, const RuntimeShape& filter_shape,
8268 const float* filter_data, const RuntimeShape& bias_shape,
8269 const float* bias_data, const RuntimeShape& output_shape,
8270 float* const output_data, const RuntimeShape& col2im_shape,
8271 float* col2im_data, CpuBackendContext* cpu_backend_context) {
8272 ruy::profiler::ScopeLabel label("Conv3DTranspose/float");
8273 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
8274 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 5);
8275 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 5);
8276 TFLITE_DCHECK(col2im_data);
8277
8278 const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
8279 const int input_channel = MatchingDim(input_shape, 4, filter_shape, 4);
8280 const int output_channel = MatchingDim(output_shape, 4, filter_shape, 3);
8281 const int input_spatial_size =
8282 input_shape.Dims(1) * input_shape.Dims(2) * input_shape.Dims(3);
8283 const int output_spatial_size =
8284 output_shape.Dims(1) * output_shape.Dims(2) * output_shape.Dims(3);
8285
8286 const int output_spatial_dim_1 = output_shape.Dims(1);
8287 const int output_spatial_dim_2 = output_shape.Dims(2);
8288 const int output_spatial_dim_3 = output_shape.Dims(3);
8289 const int input_offset = input_spatial_size * input_channel;
8290 const int output_offset = output_spatial_size * output_channel;
8291
8292 const int filter_spatial_dim_1 = filter_shape.Dims(0);
8293 const int filter_spatial_dim_2 = filter_shape.Dims(1);
8294 const int filter_spatial_dim_3 = filter_shape.Dims(2);
8295
8296 const int spatial_dim_1_padding_before = params.padding_values.depth;
8297 const int spatial_dim_1_padding_after =
8298 params.padding_values.height + params.padding_values.depth_offset;
8299 const int spatial_dim_2_padding_before = params.padding_values.height;
8300 const int spatial_dim_2_padding_after =
8301 params.padding_values.height + params.padding_values.height_offset;
8302 const int spatial_dim_3_padding_before = params.padding_values.width;
8303 const int spatial_dim_3_padding_after =
8304 params.padding_values.width + params.padding_values.width_offset;
8305 const int spatial_dim_1_stride = params.stride_depth;
8306 const int spatial_dim_2_stride = params.stride_height;
8307 const int spatial_dim_3_stride = params.stride_width;
8308 const int filter_total_size = filter_spatial_dim_1 * filter_spatial_dim_2 *
8309 filter_spatial_dim_3 * output_channel;
8310
8311 cpu_backend_gemm::MatrixParams<float> lhs_params;
8312 lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
8313 lhs_params.rows = filter_total_size;
8314 lhs_params.cols = input_channel;
8315 float* output_data_p = output_data;
8316 std::fill_n(output_data, output_offset * batch_size, 0.0f);
8317 for (int i = 0; i < batch_size; ++i) {
8318 cpu_backend_gemm::MatrixParams<float> rhs_params;
8319 rhs_params.order = cpu_backend_gemm::Order::kColMajor;
8320 rhs_params.rows = input_channel;
8321 rhs_params.cols = input_spatial_size;
8322 cpu_backend_gemm::MatrixParams<float> dst_params;
8323 dst_params.order = cpu_backend_gemm::Order::kColMajor;
8324 dst_params.rows = filter_total_size;
8325 dst_params.cols = input_spatial_size;
8326 cpu_backend_gemm::GemmParams<float, float> gemm_params;
8327 cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params,
8328 input_data + input_offset * i, dst_params,
8329 col2im_data, gemm_params, cpu_backend_context);
8330
8331 Col2im(col2im_data, output_channel, output_spatial_dim_1,
8332 output_spatial_dim_2, output_spatial_dim_3, filter_spatial_dim_1,
8333 filter_spatial_dim_2, filter_spatial_dim_3,
8334 spatial_dim_1_padding_before, spatial_dim_2_padding_before,
8335 spatial_dim_3_padding_before, spatial_dim_1_padding_after,
8336 spatial_dim_2_padding_after, spatial_dim_3_padding_after,
8337 spatial_dim_1_stride, spatial_dim_2_stride, spatial_dim_3_stride,
8338 output_data_p);
8339 output_data_p += output_offset;
8340 }
8341 output_data_p = output_data;
8342 BiasAdd3D(output_data_p, bias_data, output_shape, params.float_activation_min,
8343 params.float_activation_max);
8344 }
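
// Summary of the per-batch data flow above (descriptive note, no new
// behavior): the GEMM multiplies the filter, viewed as a
// (filter_total_size x input_channel) matrix, by the batch's activations
// viewed as (input_channel x input_spatial_size), producing column patches in
// col2im_data; Col2im then scatter-adds those patches into the zero-filled
// output, and BiasAdd3D finally applies the bias and activation clamp in
// place.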
8345
8346 } // namespace optimized_ops
8347 } // namespace tflite
8348
8349 #if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
8350 #undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
8351 #pragma GCC diagnostic pop
8352 #endif
8353
8354 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
8355