1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
17
18 #include <assert.h>
19 #include <stdint.h>
20 #include <sys/types.h>
21 #include <algorithm>
22 #include <cmath>
23 #include <cstdint>
24 #include <limits>
25 #include <memory>
26 #include <tuple>
27 #include <type_traits>
28
29 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
30 #include <Accelerate/Accelerate.h>
31 #endif
32
33 #include "Eigen/Core"
34 #include "fixedpoint/fixedpoint.h"
35 #include "public/gemmlowp.h"
36 #include "tensorflow/lite/kernels/internal/common.h"
37 #include "tensorflow/lite/kernels/internal/quantization_util.h"
38 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
39 #include "tensorflow/lite/kernels/internal/round.h"
40 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
41 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
42 #include "tensorflow/lite/kernels/internal/types.h"
43 #include "unsupported/Eigen/CXX11/Tensor"
44
45 namespace tflite {
46 namespace optimized_ops {
47
48 // Unoptimized reference ops:
49 using reference_ops::ArgMax;
50 using reference_ops::ArgMinMax;
51 using reference_ops::Broadcast4DSlowGreater;
52 using reference_ops::Broadcast4DSlowGreaterEqual;
53 using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
54 using reference_ops::Broadcast4DSlowGreaterWithScaling;
55 using reference_ops::Broadcast4DSlowLess;
56 using reference_ops::Broadcast4DSlowLessEqual;
57 using reference_ops::Broadcast4DSlowLessEqualWithScaling;
58 using reference_ops::Broadcast4DSlowLessWithScaling;
59 using reference_ops::BroadcastAdd4DSlow;
60 using reference_ops::BroadcastMul4DSlow;
61 using reference_ops::BroadcastSub4DSlow;
62 using reference_ops::Concatenation;
63 using reference_ops::ConcatenationWithScaling;
64 using reference_ops::DepthConcatenation;
65 using reference_ops::Dequantize;
66 using reference_ops::Div;
67 using reference_ops::Elu;
68 using reference_ops::FakeQuant;
69 using reference_ops::Fill;
70 using reference_ops::Gather;
71 using reference_ops::Greater;
72 using reference_ops::GreaterEqual;
73 using reference_ops::GreaterEqualWithScaling;
74 using reference_ops::GreaterWithScaling;
75 using reference_ops::LeakyRelu;
76 using reference_ops::Less;
77 using reference_ops::LessEqual;
78 using reference_ops::LessEqualWithScaling;
79 using reference_ops::LessWithScaling;
80 using reference_ops::Mean;
81 using reference_ops::ProcessBroadcastShapes;
82 using reference_ops::RankOneSelect;
83 using reference_ops::Relu1;
84 using reference_ops::Relu6;
85 using reference_ops::ReluX;
86 using reference_ops::Select;
87 using reference_ops::SpaceToBatchND;
88 using reference_ops::Split;
89 using reference_ops::StridedSlice;
90 using reference_ops::Sub16;
91 using reference_ops::Transpose;
92
93 // TODO(b/80247582) Remove this constant.
94 // This will be phased out as the shifts are revised with more thought. Use of a
95 // constant enables us to track progress on this work.
96 //
97 // Used to convert from old-style shifts (right) to new-style (left).
98 static constexpr int kReverseShift = -1;
99
100 // Make a local VectorMap typedef that allows mapping a float array
101 // as an Eigen vector expression. The std::conditional here is used to
102 // construct the suitable Eigen type for the constness of the
103 // data. Indeed, for const data, we need to produce
104 // Eigen::Map<const Eigen::Matrix<float, ...>>
105 // and not the more straightforward
106 // Eigen::Map<Eigen::Matrix<const float, ...>>
107 template <typename Scalar>
108 using VectorMap = typename std::conditional<
109 std::is_const<Scalar>::value,
110 Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
111 Eigen::Dynamic, 1>>,
112 Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, 1>>>::type;
113
114 template <typename Scalar>
115 VectorMap<Scalar> MapAsVector(Scalar* data, const RuntimeShape& shape) {
116 const int size = shape.FlatSize();
117 return VectorMap<Scalar>(data, size, 1);
118 }
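// Usage sketch (illustrative only, not part of this header's API; the shape
// and buffer below are hypothetical):
//   RuntimeShape shape({2, 3});                 // 6 elements in total.
//   float buffer[6] = {1, -2, 3, -4, 5, -6};
//   auto vec = MapAsVector(buffer, shape);      // VectorMap<float> of size 6.
//   vec = vec.cwiseMax(0.0f);                   // In-place ReLU via Eigen.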
119
120 // Make a local MatrixMap typedef that allows mapping a float array
121 // as an Eigen matrix expression. The same explanation as for VectorMap
122 // above also applies here.
123 template <typename Scalar>
124 using MatrixMap = typename std::conditional<
125 std::is_const<Scalar>::value,
126 Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
127 Eigen::Dynamic, Eigen::Dynamic>>,
128 Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
129
130 template <typename Scalar>
131 MatrixMap<Scalar> MapAsMatrixWithLastDimAsRows(Scalar* data,
132 const RuntimeShape& shape) {
133 const int dims_count = shape.DimensionsCount();
134 const int rows = shape.Dims(dims_count - 1);
135 const int cols = FlatSizeSkipDim(shape, dims_count - 1);
136 return MatrixMap<Scalar>(data, rows, cols);
137 }
138
139 template <typename Scalar>
140 MatrixMap<Scalar> MapAsMatrixWithFirstDimAsCols(Scalar* data,
141 const RuntimeShape& shape) {
142 const int cols = shape.Dims(0);
143 const int rows = FlatSizeSkipDim(shape, 0);
144 return MatrixMap<Scalar>(data, rows, cols);
145 }
146
147 template <typename Scalar>
148 using ArrayMap = typename std::conditional<
149 std::is_const<Scalar>::value,
150 Eigen::Map<const Eigen::Array<typename std::remove_const<Scalar>::type,
151 Eigen::Dynamic, Eigen::Dynamic>>,
152 Eigen::Map<Eigen::Array<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
153
154 template <typename Scalar>
155 ArrayMap<Scalar> MapAsArrayWithLastDimAsRows(Scalar* data,
156 const RuntimeShape& shape) {
157 const int dims_count = shape.DimensionsCount();
158 const int rows = shape.Dims(dims_count - 1);
159 const int cols = FlatSizeSkipDim(shape, dims_count - 1);
160 return ArrayMap<Scalar>(data, rows, cols);
161 }
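// Note on the mapping convention used above: TFLite buffers store the last
// dimension contiguously (row-major), while Eigen matrices default to
// column-major storage. Mapping "last dim as rows" therefore makes each
// innermost slice of the tensor one contiguous Eigen column. Worked example
// (shape is hypothetical): for a {2, 3, 8} tensor, MapAsMatrixWithLastDimAsRows
// returns an 8x6 map, i.e. 6 columns of depth 8.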
162
163 // Copied from tensorflow/core/framework/tensor_types.h
164 template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
165 struct TTypes {
166 // Rank-1 tensor (vector) of scalar type T.
167 typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
168 Eigen::Aligned>
169 Flat;
170 typedef Eigen::TensorMap<
171 Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>>
172 UnalignedConstMatrix;
173 };
174
175 // TODO(b/62193649): this function is only needed as long
176 // as we have the --variable_batch hack.
177 template <typename Scalar>
178 MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
179 const RuntimeShape& shape,
180 int rows) {
181 const int flatsize = shape.FlatSize();
182 TFLITE_DCHECK_EQ(flatsize % rows, 0);
183 const int cols = flatsize / rows;
184 return MatrixMap<Scalar>(data, rows, cols);
185 }
186
187 inline void AddBiasAndEvalActivationFunction(float output_activation_min,
188 float output_activation_max,
189 const RuntimeShape& bias_shape,
190 const float* bias_data,
191 const RuntimeShape& array_shape,
192 float* array_data) {
193 #ifdef USE_NEON
194 gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
195 const int bias_size = bias_shape.FlatSize();
196 const int array_size = array_shape.FlatSize();
197 TFLITE_DCHECK_EQ((array_size % bias_size), 0);
198 float* array_ptr = array_data;
199 float* array_end_ptr = array_ptr + array_size;
200 const auto activation_min = vdupq_n_f32(output_activation_min);
201 const auto activation_max = vdupq_n_f32(output_activation_max);
202 for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
203 int i = 0;
204 for (; i <= bias_size - 16; i += 16) {
205 auto b0 = vld1q_f32(bias_data + i);
206 auto b1 = vld1q_f32(bias_data + i + 4);
207 auto b2 = vld1q_f32(bias_data + i + 8);
208 auto b3 = vld1q_f32(bias_data + i + 12);
209 auto a0 = vld1q_f32(array_ptr + i);
210 auto a1 = vld1q_f32(array_ptr + i + 4);
211 auto a2 = vld1q_f32(array_ptr + i + 8);
212 auto a3 = vld1q_f32(array_ptr + i + 12);
213 auto x0 = vaddq_f32(a0, b0);
214 auto x1 = vaddq_f32(a1, b1);
215 auto x2 = vaddq_f32(a2, b2);
216 auto x3 = vaddq_f32(a3, b3);
217 x0 = vmaxq_f32(activation_min, x0);
218 x1 = vmaxq_f32(activation_min, x1);
219 x2 = vmaxq_f32(activation_min, x2);
220 x3 = vmaxq_f32(activation_min, x3);
221 x0 = vminq_f32(activation_max, x0);
222 x1 = vminq_f32(activation_max, x1);
223 x2 = vminq_f32(activation_max, x2);
224 x3 = vminq_f32(activation_max, x3);
225 vst1q_f32(array_ptr + i, x0);
226 vst1q_f32(array_ptr + i + 4, x1);
227 vst1q_f32(array_ptr + i + 8, x2);
228 vst1q_f32(array_ptr + i + 12, x3);
229 }
230 for (; i <= bias_size - 4; i += 4) {
231 auto b = vld1q_f32(bias_data + i);
232 auto a = vld1q_f32(array_ptr + i);
233 auto x = vaddq_f32(a, b);
234 x = vmaxq_f32(activation_min, x);
235 x = vminq_f32(activation_max, x);
236 vst1q_f32(array_ptr + i, x);
237 }
238 for (; i < bias_size; i++) {
239 array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
240 output_activation_min,
241 output_activation_max);
242 }
243 }
244 #else // not NEON
245 gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
246 const int bias_size = bias_shape.FlatSize();
247 const int array_size = array_shape.FlatSize();
248 TFLITE_DCHECK_EQ((array_size % bias_size), 0);
249 for (int array_offset = 0; array_offset < array_size;
250 array_offset += bias_size) {
251 for (int i = 0; i < bias_size; i++) {
252 array_data[array_offset + i] = ActivationFunctionWithMinMax(
253 array_data[array_offset + i] + bias_data[i], output_activation_min,
254 output_activation_max);
255 }
256 }
257 #endif
258 }
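// In scalar terms, both paths above compute, for every bias-sized block of
// the array:
//   array_data[offset + i] = std::min(output_activation_max,
//       std::max(output_activation_min, array_data[offset + i] + bias_data[i]));
// i.e. a broadcast bias-add followed by clamping to the activation range.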
259
260 template <typename Lhs, typename Rhs, typename Result>
261 void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
262 Eigen::MatrixBase<Result>* result) {
263 if (rhs.cols() == 1) {
264 gemmlowp::ScopedProfilingLabel label("GEMV");
265 result->col(0).noalias() = lhs * rhs.col(0);
266 } else {
267 gemmlowp::ScopedProfilingLabel label("GEMM");
268 result->noalias() = lhs * rhs;
269 }
270 }
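// Usage sketch (illustrative only; the matrices below are hypothetical):
//   Eigen::MatrixXf weights(16, 64), input(64, 1), output(16, 1);
//   Gemm(weights, input, &output);  // rhs has 1 column -> takes the GEMV path.
// With a multi-column rhs the same call dispatches to the general GEMM branch.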
271
272 inline void optimized_ops_preload_l1_stream(const uint8* ptr) {
273 #ifdef GEMMLOWP_ARM_64
274 asm volatile("prfm pldl1strm, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
275 #else
276 gemmlowp::Prefetch(ptr);
277 #endif
278 }
279
280 inline void optimized_ops_preload_l1_keep(const uint8* ptr) {
281 #ifdef GEMMLOWP_ARM_64
282 asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
283 #else
284 gemmlowp::Prefetch(ptr);
285 #endif
286 }
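// Note on the two prefetch helpers above: on ARM64, "pldl1strm" hints that the
// data is streaming (read once, low cache-retention priority), while
// "pldl1keep" requests normal temporal allocation in L1.
// GEMVForLstmCellWithSymmetricRange below, for example, streams the weights,
// which are traversed a single time, and keeps the input vector, which is
// re-read across output rows. On other targets both helpers fall back to
// gemmlowp::Prefetch.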
287
288 #ifdef GEMMLOWP_NEON
289 // In the common case of batch size 1, a fully-connected node degenerates
290 // to a matrix*vector product. LSTM cells contain a fully-connected node;
291 // when quantized, this becomes a special type of GEMV operation in which
292 // the output is 16-bit quantized and thus needs its own special path.
293 inline void GEMVForLstmCell(const RuntimeShape& input_shape,
294 const uint8* input_data,
295 const RuntimeShape& weights_shape,
296 const uint8* weights_data, uint8 weights_zero_point,
297 const RuntimeShape& bias_shape,
298 const int32* bias_data, int32 accum_multiplier,
299 int accum_shift, const RuntimeShape& output_shape,
300 int16* output_data) {
301 gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell");
302 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
303 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
304 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
305 const int output_dim_count = output_shape.DimensionsCount();
306 const int weights_dim_count = weights_shape.DimensionsCount();
307 TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
308 const int input_size = FlatSizeSkipDim(input_shape, 0);
309 const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
310 output_shape, output_dim_count - 1);
311 // This special fast path for quantized LSTM cells does not try to support
312 // odd sizes that we haven't encountered in any LSTM cell; those would
313 // require special code (which would go untested until an LSTM cell
314 // exercises it). We just guard our assumptions about size evenness with
315 // the following assertions.
316 TFLITE_DCHECK(!(output_size % 4));
317 TFLITE_DCHECK(!(input_size % 8));
318 const int32* bias_ptr = bias_data;
319 int16* output_ptr = output_data;
320 for (int out = 0; out < output_size; out += 4) {
321 int32x4_t acc_0 = vdupq_n_s32(0);
322 int32x4_t acc_1 = vdupq_n_s32(0);
323 int32x4_t acc_2 = vdupq_n_s32(0);
324 int32x4_t acc_3 = vdupq_n_s32(0);
325 const int16x8_t input_offset_vec = vdupq_n_s16(-128);
326 const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point);
327 int in = 0;
328 // Handle 16 levels of depth at a time.
329 for (; in <= input_size - 16; in += 16) {
330 const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
331 const uint8* weights_ptr = weights_data + in + out * input_size;
332 uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size);
333 uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size);
334 uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size);
335 uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size);
336 int16x8_t input_val_0, input_val_1;
337 const uint8x8_t low = vget_low_u8(input_val_u8);
338 const uint8x8_t high = vget_high_u8(input_val_u8);
339 input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
340 input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
341 input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
342 input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
343 int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0,
344 weights_val_3_0;
345 int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1,
346 weights_val_3_1;
347 weights_val_0_0 = vaddq_s16(
348 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))),
349 weights_offset_vec);
350 weights_val_0_1 = vaddq_s16(
351 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))),
352 weights_offset_vec);
353 weights_val_1_0 = vaddq_s16(
354 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))),
355 weights_offset_vec);
356 weights_val_1_1 = vaddq_s16(
357 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))),
358 weights_offset_vec);
359 weights_val_2_0 = vaddq_s16(
360 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))),
361 weights_offset_vec);
362 weights_val_2_1 = vaddq_s16(
363 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))),
364 weights_offset_vec);
365 weights_val_3_0 = vaddq_s16(
366 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))),
367 weights_offset_vec);
368 weights_val_3_1 = vaddq_s16(
369 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))),
370 weights_offset_vec);
371 acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0),
372 vget_low_s16(input_val_0));
373 acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0),
374 vget_low_s16(input_val_0));
375 acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0),
376 vget_low_s16(input_val_0));
377 acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0),
378 vget_low_s16(input_val_0));
379 acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0),
380 vget_high_s16(input_val_0));
381 acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0),
382 vget_high_s16(input_val_0));
383 acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0),
384 vget_high_s16(input_val_0));
385 acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0),
386 vget_high_s16(input_val_0));
387 acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1),
388 vget_low_s16(input_val_1));
389 acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1),
390 vget_low_s16(input_val_1));
391 acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1),
392 vget_low_s16(input_val_1));
393 acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1),
394 vget_low_s16(input_val_1));
395 acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1),
396 vget_high_s16(input_val_1));
397 acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1),
398 vget_high_s16(input_val_1));
399 acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1),
400 vget_high_s16(input_val_1));
401 acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1),
402 vget_high_s16(input_val_1));
403 }
404 // Handle 8 levels of depth at a time.
405 for (; in < input_size; in += 8) {
406 const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
407 const uint8* weights_ptr = weights_data + in + out * input_size;
408 uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size);
409 uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size);
410 uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size);
411 uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size);
412 int16x8_t input_val;
413 input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
414 input_val = vaddq_s16(input_val, input_offset_vec);
415 int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3;
416 weights_val_0 =
417 vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)),
418 weights_offset_vec);
419 weights_val_1 =
420 vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)),
421 weights_offset_vec);
422 weights_val_2 =
423 vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)),
424 weights_offset_vec);
425 weights_val_3 =
426 vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)),
427 weights_offset_vec);
428 acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0),
429 vget_low_s16(input_val));
430 acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1),
431 vget_low_s16(input_val));
432 acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2),
433 vget_low_s16(input_val));
434 acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3),
435 vget_low_s16(input_val));
436 acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0),
437 vget_high_s16(input_val));
438 acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1),
439 vget_high_s16(input_val));
440 acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2),
441 vget_high_s16(input_val));
442 acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3),
443 vget_high_s16(input_val));
444 }
445 // Horizontally reduce accumulators
446 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
447 pairwise_reduced_acc_2, pairwise_reduced_acc_3;
448 pairwise_reduced_acc_0 =
449 vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
450 pairwise_reduced_acc_1 =
451 vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
452 pairwise_reduced_acc_2 =
453 vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
454 pairwise_reduced_acc_3 =
455 vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
456 const int32x2_t reduced_lo =
457 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
458 const int32x2_t reduced_hi =
459 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
460 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
461 // Add bias values.
462 int32x4_t bias_vec = vld1q_s32(bias_ptr);
463 bias_ptr += 4;
464 reduced = vaddq_s32(reduced, bias_vec);
465 int left_shift = accum_shift > 0 ? accum_shift : 0;
466 int right_shift = accum_shift > 0 ? 0 : -accum_shift;
467 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
468 // Multiply by the fixed-point multiplier.
469 reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
470 // Rounding-shift-right.
471 using gemmlowp::RoundingDivideByPOT;
472 reduced = RoundingDivideByPOT(reduced, right_shift);
473 // Narrow values down to 16 bit signed.
474 const int16x4_t res16 = vqmovn_s32(reduced);
475 vst1_s16(output_ptr, res16);
476 output_ptr += 4;
477 }
478 }
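// In scalar form, GEMVForLstmCell above computes, for each output row r:
//   acc = bias_data[r] +
//         sum_i (input_data[i] - 128) *
//               (weights_data[r * input_size + i] - weights_zero_point);
//   output_data[r] = saturate_to_int16(RoundingDivideByPOT(
//       SaturatingRoundingDoublingHighMul(acc << left_shift, accum_multiplier),
//       right_shift));
// where left_shift/right_shift are derived from accum_shift as in the code.
// (A rough description for readers of the NEON code, not a drop-in reference.)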
479 #endif
480
481 #ifdef GEMMLOWP_NEON
482 inline void GEMVForLstmCellWithSymmetricRange(
483 const RuntimeShape& input_shape, const uint8* input_data,
484 const RuntimeShape& weights_shape, const uint8* weights_data,
485 const RuntimeShape& bias_shape, const int32* bias_data,
486 int32 accum_multiplier, int accum_shift, const RuntimeShape& output_shape,
487 int16* output_data) {
488 gemmlowp::ScopedProfilingLabel label("GEMVForLstmCellWithSymmetricRange");
489 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
490 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
491 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
492 const int output_dim_count = output_shape.DimensionsCount();
493 const int weights_dim_count = weights_shape.DimensionsCount();
494 TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
495 const int input_size = FlatSizeSkipDim(input_shape, 0);
496 const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
497 output_shape, output_dim_count - 1);
498 // This special fast path for quantized LSTM cells does not try to support
499 // odd sizes that we haven't encountered in any LSTM cell; those would
500 // require special code (which would go untested until an LSTM cell
501 // exercises it). We just guard our assumptions about size evenness with
502 // the following assertions.
503 TFLITE_DCHECK(!(output_size % 4));
504 TFLITE_DCHECK(!(input_size % 64));
505 const int32* bias_ptr = bias_data;
506 int16* output_ptr = output_data;
507 const uint8x16_t signbit = vdupq_n_u8(0x80);
508 for (int in = 0; in < input_size; in += 32) {
509 optimized_ops_preload_l1_keep(input_data + in);
510 }
511 const int left_shift = accum_shift > 0 ? accum_shift : 0;
512 const int right_shift = accum_shift > 0 ? 0 : -accum_shift;
513 for (int out = 0; out < output_size; out += 4) {
514 // Load the bias values
515 int32x4_t bias_vec = vld1q_s32(bias_ptr);
516 bias_ptr += 4;
517
518 // Clear accumulators. We use 2 accumulator registers per row,
519 // for 4 rows. row_accumRN is the N-th accumulator for row R.
520 int32x4_t row_accum00 = vdupq_n_s32(0);
521 int32x4_t row_accum01 = vdupq_n_s32(0);
522 int32x4_t row_accum10 = vdupq_n_s32(0);
523 int32x4_t row_accum11 = vdupq_n_s32(0);
524 int32x4_t row_accum20 = vdupq_n_s32(0);
525 int32x4_t row_accum21 = vdupq_n_s32(0);
526 int32x4_t row_accum30 = vdupq_n_s32(0);
527 int32x4_t row_accum31 = vdupq_n_s32(0);
528
529 // kReadAhead parametrizes how far ahead we prefetch weights into L1 cache.
530 const int kReadAhead = 512;
531 // Prefetch the first weights values.
532 for (int k = 0; k < kReadAhead; k += 64) {
533 optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
534 k);
535 optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
536 k);
537 optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
538 k);
539 optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
540 k);
541 }
542 // Loop along the rows, handling 64 bytes per iteration because that's
543 // the cache line size on most current ARM-architecture CPUs.
544 for (int in = 0; in < input_size; in += 64) {
545 // Prefetch some future weights values.
546 optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
547 in + kReadAhead);
548 optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
549 in + kReadAhead);
550 optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
551 in + kReadAhead);
552 optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
553 in + kReadAhead);
554
555 // We will use 2 local 16-bit accumulators per row, for 2 rows.
556 // See below (*) for the rationale of processing only 2 rows at a time.
557 // local_accumRN is the N-th local accumulator for row R.
558 int16x8_t local_accum00;
559 int16x8_t local_accum01;
560 int16x8_t local_accum10;
561 int16x8_t local_accum11;
562
563 // Load 64 bytes of input activations values. Convert to signed int8
564 // by flipping the sign bit (i.e. subtracting 128, the required
565 // zero_point value).
566 int8x16_t input0 = vreinterpretq_s8_u8(
567 veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 0)));
568 int8x16_t input1 = vreinterpretq_s8_u8(
569 veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 1)));
570 int8x16_t input2 = vreinterpretq_s8_u8(
571 veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 2)));
572 int8x16_t input3 = vreinterpretq_s8_u8(
573 veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 3)));
574
575 // Beginning of the core accumulation. Notice how while we have 4
576 // rows to process, this code is taking care of only 2 rows at a time,
577 // thus being divided into two parts looking similar ("Rows 0 and 1" and
578 // "Rows 2 and 3").
579 //
580 // (*) The rationale for handling only 2 rows at a time is to avoid
581 // cache aliasing issues on 4-way set-associative L1-cache CPUs, such
582 // as Cortex-A53. With sufficiently large, power-of-two matrix dimensions,
583 // we may find ourselves in a situation where rows alias each other in
584 // the L1 cache, and moreover may also mutually alias with the input
585 // activations. If we try to load 4 rows at a time, together with the
586 // input activations, that may be 5 mutually-aliasing vectors, resulting
587 // in constant mutual eviction from L1 cache. Handling 2 rows at a time
588 // here largely mitigates these issues, and seems at least to be very
589 // effective on Cortex-A53:
590 // Before After
591 // big (Cortex-A73) 2.85 ms 2.85 ms
592 // little (Cortex-A53) 11.0 ms 5.16 ms
593
594 // Rows 0 and 1:
595 // Load 64 bytes of weights values from each row. Convert to signed int8
596 // by flipping the sign bit (i.e. subtracting 128, the required
597 // zero_point value).
598 int8x16_t weights00 = vreinterpretq_s8_u8(veorq_u8(
599 signbit,
600 vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 0)));
601 int8x16_t weights01 = vreinterpretq_s8_u8(veorq_u8(
602 signbit,
603 vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 1)));
604 int8x16_t weights02 = vreinterpretq_s8_u8(veorq_u8(
605 signbit,
606 vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 2)));
607 int8x16_t weights03 = vreinterpretq_s8_u8(veorq_u8(
608 signbit,
609 vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 3)));
610 int8x16_t weights10 = vreinterpretq_s8_u8(veorq_u8(
611 signbit,
612 vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 0)));
613 int8x16_t weights11 = vreinterpretq_s8_u8(veorq_u8(
614 signbit,
615 vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 1)));
616 int8x16_t weights12 = vreinterpretq_s8_u8(veorq_u8(
617 signbit,
618 vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 2)));
619 int8x16_t weights13 = vreinterpretq_s8_u8(veorq_u8(
620 signbit,
621 vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 3)));
622 // Multiply-accumulate into local 16-bit accumulators.
623 // We can accumulate two products without overflow because weights are
624 // required to never be -128, so each product is at most 127^2 in absolute
625 // value.
626 local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
627 local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
628 local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
629 local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
630 local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
631 vget_high_s8(input0));
632 local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
633 vget_high_s8(input1));
634 local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
635 vget_high_s8(input0));
636 local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
637 vget_high_s8(input1));
638 // Pairwise add and accumulate into 32-bit accumulators
639 row_accum00 = vpadalq_s16(row_accum00, local_accum00);
640 row_accum01 = vpadalq_s16(row_accum01, local_accum01);
641 row_accum10 = vpadalq_s16(row_accum10, local_accum10);
642 row_accum11 = vpadalq_s16(row_accum11, local_accum11);
643 // Multiply-accumulate into local 16-bit accumulators.
644 // We can accumulate two products without overflow because weights are
645 // required to never be -128, so each product is at most 127^2 in absolute
646 // value.
647 local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
648 local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
649 local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
650 local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
651 local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
652 vget_high_s8(input2));
653 local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
654 vget_high_s8(input3));
655 local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
656 vget_high_s8(input2));
657 local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
658 vget_high_s8(input3));
659 // Pairwise add and accumulate into 32-bit accumulators
660 row_accum00 = vpadalq_s16(row_accum00, local_accum00);
661 row_accum01 = vpadalq_s16(row_accum01, local_accum01);
662 row_accum10 = vpadalq_s16(row_accum10, local_accum10);
663 row_accum11 = vpadalq_s16(row_accum11, local_accum11);
664
665 // Rows 2 and 3:
666 // Load 64 bytes of weights values from each row. Convert to signed int8
667 // by flipping the sign bit (i.e. subtracting 128, the required
668 // zero_point value).
669 weights00 = vreinterpretq_s8_u8(veorq_u8(
670 signbit,
671 vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 0)));
672 weights01 = vreinterpretq_s8_u8(veorq_u8(
673 signbit,
674 vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 1)));
675 weights02 = vreinterpretq_s8_u8(veorq_u8(
676 signbit,
677 vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 2)));
678 weights03 = vreinterpretq_s8_u8(veorq_u8(
679 signbit,
680 vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 3)));
681 weights10 = vreinterpretq_s8_u8(veorq_u8(
682 signbit,
683 vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 0)));
684 weights11 = vreinterpretq_s8_u8(veorq_u8(
685 signbit,
686 vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 1)));
687 weights12 = vreinterpretq_s8_u8(veorq_u8(
688 signbit,
689 vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 2)));
690 weights13 = vreinterpretq_s8_u8(veorq_u8(
691 signbit,
692 vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 3)));
693 // Multiply-accumulate into local 16-bit accumulators.
694 // We can accumulate two products without overflow because weights are
695 // required to never be -128, so each product is at most 127^2 in absolute
696 // value.
697 local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
698 local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
699 local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
700 local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
701 local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
702 vget_high_s8(input0));
703 local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
704 vget_high_s8(input1));
705 local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
706 vget_high_s8(input0));
707 local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
708 vget_high_s8(input1));
709 // Pairwise add and accumulate into 32-bit accumulators
710 row_accum20 = vpadalq_s16(row_accum20, local_accum00);
711 row_accum21 = vpadalq_s16(row_accum21, local_accum01);
712 row_accum30 = vpadalq_s16(row_accum30, local_accum10);
713 row_accum31 = vpadalq_s16(row_accum31, local_accum11);
714 // Multiply-accumulate into local 16-bit accumulators.
715 // We can accumulate two products without overflow because weights are
716 // required to never be -128, so each product is at most 127^2 in absolute
717 // value.
718 local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
719 local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
720 local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
721 local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
722 local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
723 vget_high_s8(input2));
724 local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
725 vget_high_s8(input3));
726 local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
727 vget_high_s8(input2));
728 local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
729 vget_high_s8(input3));
730 // Pairwise add and accumulate into 32-bit accumulators
731 row_accum20 = vpadalq_s16(row_accum20, local_accum00);
732 row_accum21 = vpadalq_s16(row_accum21, local_accum01);
733 row_accum30 = vpadalq_s16(row_accum30, local_accum10);
734 row_accum31 = vpadalq_s16(row_accum31, local_accum11);
735 }
736
737 row_accum00 = vaddq_s32(row_accum00, row_accum01);
738 row_accum10 = vaddq_s32(row_accum10, row_accum11);
739 row_accum20 = vaddq_s32(row_accum20, row_accum21);
740 row_accum30 = vaddq_s32(row_accum30, row_accum31);
741 // Horizontally reduce accumulators
742 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
743 pairwise_reduced_acc_2, pairwise_reduced_acc_3;
744 pairwise_reduced_acc_0 =
745 vpadd_s32(vget_low_s32(row_accum00), vget_high_s32(row_accum00));
746 pairwise_reduced_acc_1 =
747 vpadd_s32(vget_low_s32(row_accum10), vget_high_s32(row_accum10));
748 pairwise_reduced_acc_2 =
749 vpadd_s32(vget_low_s32(row_accum20), vget_high_s32(row_accum20));
750 pairwise_reduced_acc_3 =
751 vpadd_s32(vget_low_s32(row_accum30), vget_high_s32(row_accum30));
752 const int32x2_t reduced_lo =
753 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
754 const int32x2_t reduced_hi =
755 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
756 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
757 // Add bias values.
758 reduced = vaddq_s32(reduced, bias_vec);
759 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
760 // Multiply by the fixed-point multiplier.
761 reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
762 // Rounding-shift-right.
763 using gemmlowp::RoundingDivideByPOT;
764 reduced = RoundingDivideByPOT(reduced, right_shift);
765 // Narrow values down to 16 bit signed.
766 const int16x4_t res16 = vqmovn_s32(reduced);
767 vst1_s16(output_ptr, res16);
768 output_ptr += 4;
769 }
770 }
771 #endif
772
773 inline void FullyConnected(
774 const FullyConnectedParams& params, const RuntimeShape& input_shape,
775 const float* input_data, const RuntimeShape& weights_shape,
776 const float* weights_data, const RuntimeShape& bias_shape,
777 const float* optional_bias_data, const RuntimeShape& output_shape,
778 float* output_data) {
779 gemmlowp::ScopedProfilingLabel label("FullyConnected");
780 const float output_activation_min = params.float_activation_min;
781 const float output_activation_max = params.float_activation_max;
782
783 // TODO(b/62193649): this convoluted shape computation (determining
784 // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
785 // is because the current --variable_batch hack consists in overwriting the
786 // 3rd dimension with the runtime batch size, as we don't keep track for each
787 // array of which dimension is the batch dimension in it.
788 // When that is fixed, this should become:
789 // const auto input_matrix_map =
790 // MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
791 const int dims_count = weights_shape.DimensionsCount();
792 const int input_rows = weights_shape.Dims(dims_count - 1);
793 const auto input_matrix_map =
794 MapAsMatrixWithGivenNumberOfRows(input_data, input_shape, input_rows);
795 const auto filter_matrix_map =
796 MapAsMatrixWithLastDimAsRows(weights_data, weights_shape);
797 auto output_matrix_map =
798 MapAsMatrixWithLastDimAsRows(output_data, output_shape);
799
800 Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
801
802 if (optional_bias_data != nullptr) {
803 AddBiasAndEvalActivationFunction(
804 output_activation_min, output_activation_max, bias_shape,
805 optional_bias_data, output_shape, output_data);
806 } else {
807 const int flat_size = output_shape.FlatSize();
808 for (int i = 0; i < flat_size; ++i) {
809 output_data[i] = ActivationFunctionWithMinMax(
810 output_data[i], output_activation_min, output_activation_max);
811 }
812 }
813 }
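// Shape note for the float FullyConnected above (sizes are illustrative):
// with weights_shape = {output_depth, input_depth} = {16, 64} and
// input_shape = {batches, input_depth} = {8, 64}, input_matrix_map is 64x8,
// filter_matrix_map is 64x16, and Gemm computes filter^T * input = 16x8,
// which output_matrix_map writes into the {8, 16} output tensor.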
814
815 #ifdef USE_NEON
816 inline void FullyConnectedAsGEMVWorkerImpl(
817 const RuntimeShape& input_shape, const uint8* input_data,
818 int32 input_offset, const RuntimeShape& filter_shape,
819 const uint8* filter_data, int32 filter_offset,
820 const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
821 int32 output_multiplier, int output_shift, int32 output_activation_min,
822 int32 output_activation_max, const RuntimeShape& output_shape,
823 uint8* output_data, int row_start, int row_end) {
824 gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit");
825 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
826 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
827 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
828 const int output_dim_count = output_shape.DimensionsCount();
829 TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
830 const int input_size = FlatSizeSkipDim(input_shape, 0);
831 static constexpr int kPeel = 4;
832 const bool shift_left = (output_shift > 0);
833 for (int k = 0; k < input_size; k += 64) {
834 optimized_ops_preload_l1_stream(input_data + k);
835 }
836 for (int k = 0; k < kPeel * input_size; k += 64) {
837 optimized_ops_preload_l1_stream(filter_data + k);
838 }
839
840 TFLITE_DCHECK_GE(row_end - row_start, kPeel);
841
842 for (int out = row_start; out < row_end; out += kPeel) {
843 out = std::min(out, row_end - kPeel);
844 int32x4_t acc0 = vdupq_n_s32(0);
845 int32x4_t acc1 = acc0;
846 int32x4_t acc2 = acc0;
847 int32x4_t acc3 = acc0;
848 const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
849 const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
850 int in = 0;
851 for (; in <= input_size - 16; in += 16) {
852 const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
853 const uint8* filter_ptr = filter_data + in + out * input_size;
854 uint8x16_t filter_val_u8_0 = vld1q_u8(filter_ptr);
855 optimized_ops_preload_l1_stream(filter_ptr + 64);
856 filter_ptr += input_size;
857 uint8x16_t filter_val_u8_1 = vld1q_u8(filter_ptr);
858 optimized_ops_preload_l1_stream(filter_ptr + 64);
859 filter_ptr += input_size;
860 uint8x16_t filter_val_u8_2 = vld1q_u8(filter_ptr);
861 optimized_ops_preload_l1_stream(filter_ptr + 64);
862 filter_ptr += input_size;
863 uint8x16_t filter_val_u8_3 = vld1q_u8(filter_ptr);
864 optimized_ops_preload_l1_stream(filter_ptr + 64);
865 int16x8_t input_val_0, input_val_1;
866 uint8x8_t low = vget_low_u8(input_val_u8);
867 uint8x8_t high = vget_high_u8(input_val_u8);
868 input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
869 input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
870 input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
871 input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
872 low = vget_low_u8(filter_val_u8_0);
873 high = vget_high_u8(filter_val_u8_0);
874 int16x8_t filter_val_0_0 = vreinterpretq_s16_u16(vmovl_u8(low));
875 int16x8_t filter_val_0_1 = vreinterpretq_s16_u16(vmovl_u8(high));
876 filter_val_0_0 = vaddq_s16(filter_val_0_0, filter_offset_vec);
877 filter_val_0_1 = vaddq_s16(filter_val_0_1, filter_offset_vec);
878 low = vget_low_u8(filter_val_u8_1);
879 high = vget_high_u8(filter_val_u8_1);
880 int16x8_t filter_val_1_0 = vreinterpretq_s16_u16(vmovl_u8(low));
881 int16x8_t filter_val_1_1 = vreinterpretq_s16_u16(vmovl_u8(high));
882 filter_val_1_0 = vaddq_s16(filter_val_1_0, filter_offset_vec);
883 filter_val_1_1 = vaddq_s16(filter_val_1_1, filter_offset_vec);
884 low = vget_low_u8(filter_val_u8_2);
885 high = vget_high_u8(filter_val_u8_2);
886 int16x8_t filter_val_2_0 = vreinterpretq_s16_u16(vmovl_u8(low));
887 int16x8_t filter_val_2_1 = vreinterpretq_s16_u16(vmovl_u8(high));
888 filter_val_2_0 = vaddq_s16(filter_val_2_0, filter_offset_vec);
889 filter_val_2_1 = vaddq_s16(filter_val_2_1, filter_offset_vec);
890 low = vget_low_u8(filter_val_u8_3);
891 high = vget_high_u8(filter_val_u8_3);
892 int16x8_t filter_val_3_0 = vreinterpretq_s16_u16(vmovl_u8(low));
893 int16x8_t filter_val_3_1 = vreinterpretq_s16_u16(vmovl_u8(high));
894 filter_val_3_0 = vaddq_s16(filter_val_3_0, filter_offset_vec);
895 filter_val_3_1 = vaddq_s16(filter_val_3_1, filter_offset_vec);
896 acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_0),
897 vget_low_s16(input_val_0));
898 acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_0),
899 vget_low_s16(input_val_0));
900 acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_0),
901 vget_low_s16(input_val_0));
902 acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_0),
903 vget_low_s16(input_val_0));
904 acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0_1),
905 vget_low_s16(input_val_1));
906 acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1_1),
907 vget_low_s16(input_val_1));
908 acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2_1),
909 vget_low_s16(input_val_1));
910 acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3_1),
911 vget_low_s16(input_val_1));
912 acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_0),
913 vget_high_s16(input_val_0));
914 acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_0),
915 vget_high_s16(input_val_0));
916 acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_0),
917 vget_high_s16(input_val_0));
918 acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_0),
919 vget_high_s16(input_val_0));
920 acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0_1),
921 vget_high_s16(input_val_1));
922 acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1_1),
923 vget_high_s16(input_val_1));
924 acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2_1),
925 vget_high_s16(input_val_1));
926 acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3_1),
927 vget_high_s16(input_val_1));
928 }
929 for (; in <= input_size - 8; in += 8) {
930 const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
931 const uint8* filter_ptr = filter_data + in + out * input_size;
932 uint8x8_t filter_val_u8_0 = vld1_u8(filter_ptr);
933 filter_ptr += input_size;
934 uint8x8_t filter_val_u8_1 = vld1_u8(filter_ptr);
935 filter_ptr += input_size;
936 uint8x8_t filter_val_u8_2 = vld1_u8(filter_ptr);
937 filter_ptr += input_size;
938 uint8x8_t filter_val_u8_3 = vld1_u8(filter_ptr);
939 int16x8_t input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
940 input_val = vaddq_s16(input_val, input_offset_vec);
941 int16x8_t filter_val_0 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_0));
942 filter_val_0 = vaddq_s16(filter_val_0, filter_offset_vec);
943 int16x8_t filter_val_1 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_1));
944 filter_val_1 = vaddq_s16(filter_val_1, filter_offset_vec);
945 int16x8_t filter_val_2 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_2));
946 filter_val_2 = vaddq_s16(filter_val_2, filter_offset_vec);
947 int16x8_t filter_val_3 = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8_3));
948 filter_val_3 = vaddq_s16(filter_val_3, filter_offset_vec);
949 acc0 =
950 vmlal_s16(acc0, vget_low_s16(filter_val_0), vget_low_s16(input_val));
951 acc1 =
952 vmlal_s16(acc1, vget_low_s16(filter_val_1), vget_low_s16(input_val));
953 acc2 =
954 vmlal_s16(acc2, vget_low_s16(filter_val_2), vget_low_s16(input_val));
955 acc3 =
956 vmlal_s16(acc3, vget_low_s16(filter_val_3), vget_low_s16(input_val));
957 acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
958 vget_high_s16(input_val));
959 acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
960 vget_high_s16(input_val));
961 acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
962 vget_high_s16(input_val));
963 acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
964 vget_high_s16(input_val));
965 }
966 if (in < input_size) {
967 int32 buf[16];
968 vst1q_s32(buf + 0, acc0);
969 vst1q_s32(buf + 4, acc1);
970 vst1q_s32(buf + 8, acc2);
971 vst1q_s32(buf + 12, acc3);
972 for (; in < input_size; in++) {
973 int lane = (in + 8 - input_size) % 4;
974 const int32 input_val = input_data[in] + input_offset;
975 for (int k = 0; k < kPeel; k++) {
976 int32 filter_val =
977 filter_data[in + (out + k) * input_size] + filter_offset;
978 buf[lane + 4 * k] += filter_val * input_val;
979 }
980 }
981 acc0 = vld1q_s32(buf + 0);
982 acc1 = vld1q_s32(buf + 4);
983 acc2 = vld1q_s32(buf + 8);
984 acc3 = vld1q_s32(buf + 12);
985 }
986
987 // Horizontally reduce accumulators
988 int32x2_t pairwise_reduced_acc_0 =
989 vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
990 int32x2_t pairwise_reduced_acc_1 =
991 vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
992 int32x2_t pairwise_reduced_acc_2 =
993 vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
994 int32x2_t pairwise_reduced_acc_3 =
995 vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
996 const int32x2_t reduced_lo =
997 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
998 const int32x2_t reduced_hi =
999 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1000 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1001 // Add bias values.
1002 int32x4_t bias_vec = vld1q_s32(bias_data + out);
1003 reduced = vaddq_s32(reduced, bias_vec);
1004 if (shift_left) {
1005 const int32 multiplier_power_of_two = 1 << output_shift;
1006 reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
1007 reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1008 } else {
1009 // Multiply by the fixed-point multiplier.
1010 reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1011 // Rounding-shift-right.
1012 using gemmlowp::RoundingDivideByPOT;
1013 reduced = RoundingDivideByPOT(reduced, -output_shift);
1014 }
1015 // Add the output offset.
1016 const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
1017 reduced = vaddq_s32(reduced, output_offset_vec);
1018 // Narrow values down to 16 bit signed.
1019 const int16x4_t res16 = vqmovn_s32(reduced);
1020 // Narrow values down to 8 bit unsigned, saturating.
1021 uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16));
1022 // Apply the clamping from the activation function
1023 res8 = vmax_u8(res8, vdup_n_u8(output_activation_min));
1024 res8 = vmin_u8(res8, vdup_n_u8(output_activation_max));
1025 // Store results to destination.
1026 vst1_lane_u8(output_data + out + 0, res8, 0);
1027 vst1_lane_u8(output_data + out + 1, res8, 1);
1028 vst1_lane_u8(output_data + out + 2, res8, 2);
1029 vst1_lane_u8(output_data + out + 3, res8, 3);
1030 }
1031 }
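// Roughly, for each output row r in [row_start, row_end), the worker above
// computes
//   acc = bias_data[r] +
//         sum_i (input_data[i] + input_offset) *
//               (filter_data[r * input_size + i] + filter_offset);
//   out = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift)
//         + output_offset;
//   output_data[r] = clamp_to_uint8(out, output_activation_min,
//                                   output_activation_max);
// (Scalar description only; the fixed-point rescaling is handled above with
// vqrdmulhq_n_s32 and RoundingDivideByPOT.)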
1032
1033 struct FullyConnectedAsGEMVWorkerTask : public gemmlowp::Task {
1034 FullyConnectedAsGEMVWorkerTask(const RuntimeShape& input_shape,
1035 const uint8* input_data, int32 input_offset,
1036 const RuntimeShape& filter_shape,
1037 const uint8* filter_data, int32 filter_offset,
1038 const RuntimeShape& bias_shape,
1039 const int32* bias_data, int32 output_offset,
1040 int32 output_multiplier, int output_shift,
1041 int32 output_activation_min,
1042 int32 output_activation_max,
1043 const RuntimeShape& output_shape,
1044 uint8* output_data, int row_start, int row_end)
1045 : input_shape_(input_shape),
1046 input_data_(input_data),
1047 input_offset_(input_offset),
1048 filter_shape_(filter_shape),
1049 filter_data_(filter_data),
1050 filter_offset_(filter_offset),
1051 bias_shape_(bias_shape),
1052 bias_data_(bias_data),
1053 output_offset_(output_offset),
1054 output_multiplier_(output_multiplier),
1055 output_shift_(output_shift),
1056 output_activation_min_(output_activation_min),
1057 output_activation_max_(output_activation_max),
1058 output_shape_(output_shape),
1059 output_data_(output_data),
1060 row_start_(row_start),
1061 row_end_(row_end) {}
1062
1063 void Run() override {
1064 FullyConnectedAsGEMVWorkerImpl(
1065 input_shape_, input_data_, input_offset_, filter_shape_, filter_data_,
1066 filter_offset_, bias_shape_, bias_data_, output_offset_,
1067 output_multiplier_, output_shift_, output_activation_min_,
1068 output_activation_max_, output_shape_, output_data_, row_start_,
1069 row_end_);
1070 }
1071
1072 const RuntimeShape& input_shape_;
1073 const uint8* input_data_;
1074 int32 input_offset_;
1075 const RuntimeShape& filter_shape_;
1076 const uint8* filter_data_;
1077 int32 filter_offset_;
1078 const RuntimeShape& bias_shape_;
1079 const int32* bias_data_;
1080 int32 output_offset_;
1081 int32 output_multiplier_;
1082 int output_shift_;
1083 int32 output_activation_min_;
1084 int32 output_activation_max_;
1085 const RuntimeShape& output_shape_;
1086 uint8* output_data_;
1087 gemmlowp::GemmContext* gemm_context_;
1088 int row_start_;
1089 int row_end_;
1090 };
1091
1092 inline void FullyConnectedAsGEMV(
1093 const RuntimeShape& input_shape, const uint8* input_data,
1094 int32 input_offset, const RuntimeShape& filter_shape,
1095 const uint8* filter_data, int32 filter_offset,
1096 const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
1097 int32 output_multiplier, int output_shift, int32 output_activation_min,
1098 int32 output_activation_max, const RuntimeShape& output_shape,
1099 uint8* output_data, gemmlowp::GemmContext* gemm_context) {
1100 const int output_dim_count = output_shape.DimensionsCount();
1101 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1102 const int output_rows = output_shape.Dims(output_dim_count - 1);
1103 const int input_size = FlatSizeSkipDim(input_shape, 0);
1104 static constexpr int kKernelRows = 4;
1105 const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
1106 gemm_context->max_num_threads(), output_rows, batches, input_size);
1107 if (thread_count == 1) {
1108 // Single-thread case: do the computation on the current thread, don't
1109 // use a threadpool
1110 FullyConnectedAsGEMVWorkerImpl(
1111 input_shape, input_data, input_offset, filter_shape, filter_data,
1112 filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
1113 output_shift, output_activation_min, output_activation_max,
1114 output_shape, output_data, 0, output_rows);
1115 return;
1116 }
1117
1118 // Multi-threaded case: use the gemmlowp context's threadpool.
1119 TFLITE_DCHECK_GT(thread_count, 1);
1120 std::vector<gemmlowp::Task*> tasks(thread_count);
1121 const int kRowsPerWorker =
1122 gemmlowp::RoundUp<kKernelRows>(output_rows / thread_count);
1123 int row_start = 0;
1124 for (int i = 0; i < thread_count; ++i) {
1125 int row_end = std::min(output_rows, row_start + kRowsPerWorker);
1126 tasks[i] = new FullyConnectedAsGEMVWorkerTask(
1127 input_shape, input_data, input_offset, filter_shape, filter_data,
1128 filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
1129 output_shift, output_activation_min, output_activation_max,
1130 output_shape, output_data, row_start, row_end);
1131 row_start = row_end;
1132 }
1133 TFLITE_DCHECK_EQ(row_start, output_rows);
1134 gemm_context->workers_pool()->Execute(tasks);
1135 }
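// Partitioning example (numbers are illustrative): with output_rows = 1000 and
// thread_count = 4, kRowsPerWorker = RoundUp<4>(250) = 252, so the tasks cover
// rows [0, 252), [252, 504), [504, 756) and [756, 1000); the last task simply
// gets the remainder. Each task runs FullyConnectedAsGEMVWorkerImpl on its row
// range via the gemmlowp thread pool.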
1136 #endif // USE_NEON
1137
1138 struct GemmlowpOutputPipeline {
1139 typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
1140 ColVectorMap;
1141 typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
1142 gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
1143 gemmlowp::OutputStageClamp,
1144 gemmlowp::OutputStageSaturatingCastToUint8>
1145 Pipeline;
1146 static Pipeline MakeExp(const int32* bias_data, int output_rows,
1147 int32 output_offset, int32 output_multiplier,
1148 int output_left_shift, int32 output_activation_min,
1149 int32 output_activation_max) {
1150 ColVectorMap bias_vector(bias_data, output_rows);
1151 gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
1152 bias_addition_stage.bias_vector = bias_vector;
1153 gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
1154 quantize_down_stage.result_offset_after_shift = output_offset;
1155 quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
1156 quantize_down_stage.result_exponent = output_left_shift;
1157 gemmlowp::OutputStageClamp clamp_stage;
1158 clamp_stage.min = output_activation_min;
1159 clamp_stage.max = output_activation_max;
1160 gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
1161 return std::make_tuple(bias_addition_stage, quantize_down_stage,
1162 clamp_stage, saturating_cast_stage);
1163 }
1164 };
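// In scalar terms, the pipeline built by MakeExp maps each int32 accumulator
// acc of output row r to roughly
//   tmp = acc + bias_data[r];
//   tmp = rescale(tmp) + output_offset;  // fixed-point multiply by
//                                        // output_multiplier plus a shift by
//                                        // output_left_shift (left if > 0,
//                                        // rounding-right otherwise).
//   out = saturating_cast<uint8>(clamp(tmp, output_activation_min,
//                                      output_activation_max));
// The stages run in the order listed in the Pipeline tuple above.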
1165
1166 inline void FullyConnected(
1167 const FullyConnectedParams& params, const RuntimeShape& input_shape,
1168 const uint8* input_data, const RuntimeShape& filter_shape,
1169 const uint8* filter_data, const RuntimeShape& bias_shape,
1170 const int32* bias_data, const RuntimeShape& output_shape,
1171 uint8* output_data, gemmlowp::GemmContext* gemm_context) {
1172 gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
1173 const int32 input_offset = params.input_offset;
1174 const int32 filter_offset = params.weights_offset;
1175 const int32 output_offset = params.output_offset;
1176 const int32 output_multiplier = params.output_multiplier;
1177 const int output_shift = params.output_shift;
1178 const int32 output_activation_min = params.quantized_activation_min;
1179 const int32 output_activation_max = params.quantized_activation_max;
1180 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1181 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1182 // TODO(benoitjacob): This really should be:
1183 // const int batches = ArraySize(output_dims, 1);
1184 // but the current --variable_batch hack consists in overwriting the 3rd
1185 // dimension with the runtime batch size, as we don't keep track for each
1186 // array of which dimension is the batch dimension in it.
1187 const int output_dim_count = output_shape.DimensionsCount();
1188 const int filter_dim_count = filter_shape.DimensionsCount();
1189 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1190 #ifdef USE_NEON
1191 if (batches == 1) {
1192 const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
1193 output_shape, output_dim_count - 1);
1194 if (output_size >= 4) {
1195 return FullyConnectedAsGEMV(
1196 input_shape, input_data, input_offset, filter_shape, filter_data,
1197 filter_offset, bias_shape, bias_data, output_offset,
1198 output_multiplier, output_shift, output_activation_min,
1199 output_activation_max, output_shape, output_data, gemm_context);
1200 }
1201 }
1202 #endif // USE_NEON
1203 const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
1204 const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
1205 TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
1206 const int output_rows = output_shape.Dims(output_dim_count - 1);
1207 TFLITE_DCHECK_EQ(output_rows, filter_rows);
1208 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
1209
1210 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
1211 filter_data, output_rows, filter_cols, filter_cols);
1212 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
1213 input_data, filter_cols, batches, filter_cols);
1214 gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
1215 output_data, output_rows, batches, output_rows);
1216 const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
1217 bias_data, output_rows, output_offset, output_multiplier, output_shift,
1218 output_activation_min, output_activation_max);
1219 gemmlowp::GemmWithOutputPipeline<uint8, uint8,
1220 gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
1221 gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
1222 input_offset, output_pipeline);
1223 }
1224
1225 inline void FullyConnected(
1226 const FullyConnectedParams& params, const RuntimeShape& input_shape,
1227 const uint8* input_data, const RuntimeShape& filter_shape,
1228 const uint8* filter_data, const RuntimeShape& bias_shape,
1229 const int32* bias_data_int32, const RuntimeShape& output_shape,
1230 int16* output_data, gemmlowp::GemmContext* gemm_context) {
1231 gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16");
1232 const int32 input_offset = params.input_offset;
1233 const int32 filter_offset = params.weights_offset;
1234 const int32 output_offset = params.output_offset;
1235 const int32 output_multiplier = params.output_multiplier;
1236 const int output_shift = params.output_shift;
1237 const int32 output_activation_min = params.quantized_activation_min;
1238 const int32 output_activation_max = params.quantized_activation_max;
1239 // Note: depending on the fast paths taken below, gemm_context may or may not
1240 // end up being used, so keep it formally referenced either way.
1241 (void)gemm_context;
1242 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1243 TFLITE_DCHECK_EQ(output_offset, 0);
1244 TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
1245 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1246
1247 // TODO(benoitjacob): This really should be:
1248 // const int batches = ArraySize(output_dims, 1);
1249 // but the current --variable_batch hack consists in overwriting the 3rd
1250 // dimension with the runtime batch size, as we don't keep track for each
1251 // array of which dimension is the batch dimension in it.
1252 const int output_dim_count = output_shape.DimensionsCount();
1253 const int filter_dim_count = filter_shape.DimensionsCount();
1254 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1255 const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
1256 output_shape, output_dim_count - 1);
1257 const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
1258
1259 // Implementation of the fully connected node suited to the inside of an LSTM
1260 // cell. The operands are 8-bit integers, the accumulators are internally
1261 // 32bit integers, and the output is 16-bit fixed-point with 3 integer bits so
1262 // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
1263 // is explained in the function comment above.
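// In other words, the output is in signed Q3.12 format: the int16 value 4096
// represents 1.0, one LSB is 2^-12, and the representable range is [-8, 8).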
1264 #ifdef GEMMLOWP_NEON
1265 if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
1266 output_activation_max == 32767) {
1267 if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
1268 GEMVForLstmCellWithSymmetricRange(
1269 input_shape, input_data, filter_shape, filter_data, bias_shape,
1270 bias_data_int32, output_multiplier, output_shift, output_shape,
1271 output_data);
1272 return;
1273 }
1274 if (!(output_depth % 4) && !(accum_depth % 8)) {
1275 GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
1276 filter_offset, bias_shape, bias_data_int32,
1277 output_multiplier, output_shift, output_shape,
1278 output_data);
1279 return;
1280 }
1281 }
1282 #endif
1283 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> weights_matrix(
1284 filter_data, output_depth, accum_depth);
1285 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
1286 input_data, accum_depth, batches);
1287 gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
1288 output_data, output_depth, batches);
1289 typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
1290 ColVectorMap;
1291 ColVectorMap bias_vector(bias_data_int32, output_depth);
1292 gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
1293 bias_addition_stage.bias_vector = bias_vector;
1294 gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
1295 scale_stage.result_offset_after_shift = 0;
1296 scale_stage.result_fixedpoint_multiplier = output_multiplier;
1297 // Note that this shift is negated wrt ordinary FC.
1298 scale_stage.result_exponent = output_shift;
1299 gemmlowp::OutputStageClamp clamp_stage;
1300 clamp_stage.min = output_activation_min;
1301 clamp_stage.max = output_activation_max;
1302 gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
1303 auto output_pipeline =
1304 std::make_tuple(bias_addition_stage, scale_stage, clamp_stage,
1305 saturating_cast_int16_stage);
1306 gemmlowp::GemmWithOutputPipeline<uint8, int16,
1307 gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
1308 gemm_context, weights_matrix, input_matrix, &output_matrix, filter_offset,
1309 input_offset, output_pipeline);
1310 }
1311
1312 // Internal function doing the actual arithmetic work for
1313 // ShuffledFullyConnected.
1314 // May be called either directly by it (single-threaded case) or may be used
1315 // as the 'task' for worker threads to run (multi-threaded case, see
1316 // ShuffledFullyConnectedWorkerTask below).
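// The kernels below assume the shuffled storage layout prepared ahead of time:
// weights are read in consecutive 64-byte blocks, each holding 16 consecutive
// depth values for 4 consecutive output rows, already XOR'd with 0x80 so they
// can be consumed directly as int8. For batches == 4, the input workspace is
// interleaved the same way: 16 depth values for each of the 4 batches in turn.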
1317 inline void ShuffledFullyConnectedWorkerImpl(
1318 const uint8* shuffled_input_workspace_data,
1319 const int8* shuffled_weights_data, int batches, int output_depth,
1320 int output_stride, int accum_depth, const int32* bias_data,
1321 int32 output_multiplier, int output_shift, int16* output_data) {
1322 #if defined USE_NEON
1323 const int8* shuffled_weights_ptr = shuffled_weights_data;
1324 if (batches == 1) {
1325 const int right_shift = output_shift > 0 ? 0 : -output_shift;
1326 const int left_shift = output_shift > 0 ? output_shift : 0;
1327 for (int c = 0; c < output_depth; c += 4) {
1328 // Accumulation loop.
1329 int32x4_t row_accum0 = vdupq_n_s32(0);
1330 int32x4_t row_accum1 = vdupq_n_s32(0);
1331 int32x4_t row_accum2 = vdupq_n_s32(0);
1332 int32x4_t row_accum3 = vdupq_n_s32(0);
1333 for (int d = 0; d < accum_depth; d += 16) {
1334 int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
1335 int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
1336 int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
1337 int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
1338 shuffled_weights_ptr += 64;
1339 int8x16_t input =
1340 vreinterpretq_s8_u8(vld1q_u8(shuffled_input_workspace_data + d));
1341 int16x8_t local_accum0 =
1342 vmull_s8(vget_low_s8(weights0), vget_low_s8(input));
1343 int16x8_t local_accum1 =
1344 vmull_s8(vget_low_s8(weights1), vget_low_s8(input));
1345 int16x8_t local_accum2 =
1346 vmull_s8(vget_low_s8(weights2), vget_low_s8(input));
1347 int16x8_t local_accum3 =
1348 vmull_s8(vget_low_s8(weights3), vget_low_s8(input));
1349 local_accum0 =
1350 vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input));
1351 local_accum1 =
1352 vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input));
1353 local_accum2 =
1354 vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input));
1355 local_accum3 =
1356 vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input));
1357 row_accum0 = vpadalq_s16(row_accum0, local_accum0);
1358 row_accum1 = vpadalq_s16(row_accum1, local_accum1);
1359 row_accum2 = vpadalq_s16(row_accum2, local_accum2);
1360 row_accum3 = vpadalq_s16(row_accum3, local_accum3);
1361 }
1362 // Horizontally reduce accumulators
1363 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
1364 pairwise_reduced_acc_2, pairwise_reduced_acc_3;
1365 pairwise_reduced_acc_0 =
1366 vpadd_s32(vget_low_s32(row_accum0), vget_high_s32(row_accum0));
1367 pairwise_reduced_acc_1 =
1368 vpadd_s32(vget_low_s32(row_accum1), vget_high_s32(row_accum1));
1369 pairwise_reduced_acc_2 =
1370 vpadd_s32(vget_low_s32(row_accum2), vget_high_s32(row_accum2));
1371 pairwise_reduced_acc_3 =
1372 vpadd_s32(vget_low_s32(row_accum3), vget_high_s32(row_accum3));
1373 const int32x2_t reduced_lo =
1374 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
1375 const int32x2_t reduced_hi =
1376 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
1377 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
1378 // Add bias values.
1379 int32x4_t bias_vec = vld1q_s32(bias_data + c);
1380 reduced = vaddq_s32(reduced, bias_vec);
1381 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
1382 // Multiply by the fixed-point multiplier.
1383 reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
1384 // Rounding-shift-right.
1385 using gemmlowp::RoundingDivideByPOT;
1386 reduced = RoundingDivideByPOT(reduced, right_shift);
1387 // Narrow values down to 16 bit signed.
1388 const int16x4_t res16 = vqmovn_s32(reduced);
1389 vst1_s16(output_data + c, res16);
1390 }
1391 } else if (batches == 4) {
1392 const int right_shift = output_shift > 0 ? 0 : -output_shift;
1393 const int left_shift = output_shift > 0 ? output_shift : 0;
1394 for (int c = 0; c < output_depth; c += 4) {
1395 const int8* shuffled_input_ptr =
1396 reinterpret_cast<const int8*>(shuffled_input_workspace_data);
1397 // Accumulation loop.
1398 int32x4_t row_accum00 = vdupq_n_s32(0);
1399 int32x4_t row_accum10 = vdupq_n_s32(0);
1400 int32x4_t row_accum20 = vdupq_n_s32(0);
1401 int32x4_t row_accum30 = vdupq_n_s32(0);
1402 int32x4_t row_accum01 = vdupq_n_s32(0);
1403 int32x4_t row_accum11 = vdupq_n_s32(0);
1404 int32x4_t row_accum21 = vdupq_n_s32(0);
1405 int32x4_t row_accum31 = vdupq_n_s32(0);
1406 int32x4_t row_accum02 = vdupq_n_s32(0);
1407 int32x4_t row_accum12 = vdupq_n_s32(0);
1408 int32x4_t row_accum22 = vdupq_n_s32(0);
1409 int32x4_t row_accum32 = vdupq_n_s32(0);
1410 int32x4_t row_accum03 = vdupq_n_s32(0);
1411 int32x4_t row_accum13 = vdupq_n_s32(0);
1412 int32x4_t row_accum23 = vdupq_n_s32(0);
1413 int32x4_t row_accum33 = vdupq_n_s32(0);
1414 for (int d = 0; d < accum_depth; d += 16) {
1415 int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
1416 int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
1417 int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
1418 int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
1419 shuffled_weights_ptr += 64;
1420 int8x16_t input0 = vld1q_s8(shuffled_input_ptr + 0);
1421 int8x16_t input1 = vld1q_s8(shuffled_input_ptr + 16);
1422 int8x16_t input2 = vld1q_s8(shuffled_input_ptr + 32);
1423 int8x16_t input3 = vld1q_s8(shuffled_input_ptr + 48);
1424 shuffled_input_ptr += 64;
1425 int16x8_t local_accum0, local_accum1, local_accum2, local_accum3;
1426 #define TFLITE_SHUFFLED_FC_ACCUM(B) \
1427 local_accum0 = vmull_s8(vget_low_s8(weights0), vget_low_s8(input##B)); \
1428 local_accum1 = vmull_s8(vget_low_s8(weights1), vget_low_s8(input##B)); \
1429 local_accum2 = vmull_s8(vget_low_s8(weights2), vget_low_s8(input##B)); \
1430 local_accum3 = vmull_s8(vget_low_s8(weights3), vget_low_s8(input##B)); \
1431 local_accum0 = \
1432 vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input##B)); \
1433 local_accum1 = \
1434 vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input##B)); \
1435 local_accum2 = \
1436 vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input##B)); \
1437 local_accum3 = \
1438 vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input##B)); \
1439 row_accum0##B = vpadalq_s16(row_accum0##B, local_accum0); \
1440 row_accum1##B = vpadalq_s16(row_accum1##B, local_accum1); \
1441 row_accum2##B = vpadalq_s16(row_accum2##B, local_accum2); \
1442 row_accum3##B = vpadalq_s16(row_accum3##B, local_accum3);
1443
1444 TFLITE_SHUFFLED_FC_ACCUM(0)
1445 TFLITE_SHUFFLED_FC_ACCUM(1)
1446 TFLITE_SHUFFLED_FC_ACCUM(2)
1447 TFLITE_SHUFFLED_FC_ACCUM(3)
1448
1449 #undef TFLITE_SHUFFLED_FC_ACCUM
1450 }
1451 // Horizontally reduce accumulators
1452
1453 #define TFLITE_SHUFFLED_FC_STORE(B) \
1454 { \
1455 int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1, \
1456 pairwise_reduced_acc_2, pairwise_reduced_acc_3; \
1457 pairwise_reduced_acc_0 = \
1458 vpadd_s32(vget_low_s32(row_accum0##B), vget_high_s32(row_accum0##B)); \
1459 pairwise_reduced_acc_1 = \
1460 vpadd_s32(vget_low_s32(row_accum1##B), vget_high_s32(row_accum1##B)); \
1461 pairwise_reduced_acc_2 = \
1462 vpadd_s32(vget_low_s32(row_accum2##B), vget_high_s32(row_accum2##B)); \
1463 pairwise_reduced_acc_3 = \
1464 vpadd_s32(vget_low_s32(row_accum3##B), vget_high_s32(row_accum3##B)); \
1465 const int32x2_t reduced_lo = \
1466 vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1); \
1467 const int32x2_t reduced_hi = \
1468 vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3); \
1469 int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); \
1470 int32x4_t bias_vec = vld1q_s32(bias_data + c); \
1471 reduced = vaddq_s32(reduced, bias_vec); \
1472 reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift)); \
1473 reduced = vqrdmulhq_n_s32(reduced, output_multiplier); \
1474 using gemmlowp::RoundingDivideByPOT; \
1475 reduced = RoundingDivideByPOT(reduced, right_shift); \
1476 const int16x4_t res16 = vqmovn_s32(reduced); \
1477 vst1_s16(output_data + c + B * output_stride, res16); \
1478 }
1479
1480 TFLITE_SHUFFLED_FC_STORE(0);
1481 TFLITE_SHUFFLED_FC_STORE(1);
1482 TFLITE_SHUFFLED_FC_STORE(2);
1483 TFLITE_SHUFFLED_FC_STORE(3);
1484
1485 #undef TFLITE_SHUFFLED_FC_STORE
1486 }
1487 } else {
1488 TFLITE_DCHECK(false);
1489 return;
1490 }
1491 #else
1492 if (batches == 1) {
1493 int16* output_ptr = output_data;
1494 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
1495 // so that just reinterpreting them as int8 values is equivalent to
1496 // subtracting 128 from them, thus implementing for free the subtraction of
1497 // the zero_point value 128.
1498 const int8* shuffled_weights_ptr =
1499 reinterpret_cast<const int8*>(shuffled_weights_data);
1500 // Likewise, we preshuffled and pre-xored the input data above.
1501 const int8* shuffled_input_data =
1502 reinterpret_cast<const int8*>(shuffled_input_workspace_data);
1503 for (int c = 0; c < output_depth; c += 4) {
1504       // Internal accumulation.
1505       // Accumulators start at zero; the bias is added after the accumulation loop.
1506       int32 accum[4] = {0};
1507 // Accumulation loop.
1508 for (int d = 0; d < accum_depth; d += 16) {
1509 for (int i = 0; i < 4; i++) {
1510 for (int j = 0; j < 16; j++) {
1511 int8 input_val = shuffled_input_data[d + j];
1512 int8 weights_val = *shuffled_weights_ptr++;
1513 accum[i] += weights_val * input_val;
1514 }
1515 }
1516 }
1517 for (int i = 0; i < 4; i++) {
1518 // Add bias value
1519 int acc = accum[i] + bias_data[c + i];
1520 // Down-scale the final int32 accumulator to the scale used by our
1521 // (16-bit, typically 3 integer bits) fixed-point format. The quantized
1522 // multiplier and shift here have been pre-computed offline
1523 // (e.g. by toco).
1524 acc =
1525 MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
1526 // Saturate, cast to int16, and store to output array.
1527 acc = std::max(acc, -32768);
1528 acc = std::min(acc, 32767);
1529 output_ptr[c + i] = acc;
1530 }
1531 }
1532 } else if (batches == 4) {
1533 int16* output_ptr = output_data;
1534 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
1535 // so that just reinterpreting them as int8 values is equivalent to
1536 // subtracting 128 from them, thus implementing for free the subtraction of
1537 // the zero_point value 128.
1538 const int8* shuffled_weights_ptr =
1539 reinterpret_cast<const int8*>(shuffled_weights_data);
1540 // Likewise, we preshuffled and pre-xored the input data above.
1541 const int8* shuffled_input_data =
1542 reinterpret_cast<const int8*>(shuffled_input_workspace_data);
1543 for (int c = 0; c < output_depth; c += 4) {
1544 const int8* shuffled_input_ptr = shuffled_input_data;
1545       // Internal accumulation.
1546       // Accumulators start at zero; the bias is added after the accumulation
1547       // loop below.
1548 int32 accum[4][4];
1549 for (int i = 0; i < 4; i++) {
1550 for (int b = 0; b < 4; b++) {
1551 accum[i][b] = 0;
1552 }
1553 }
1554 for (int d = 0; d < accum_depth; d += 16) {
1555 for (int i = 0; i < 4; i++) {
1556 for (int b = 0; b < 4; b++) {
1557 for (int j = 0; j < 16; j++) {
1558 int8 input_val = shuffled_input_ptr[16 * b + j];
1559 int8 weights_val = shuffled_weights_ptr[16 * i + j];
1560 accum[i][b] += weights_val * input_val;
1561 }
1562 }
1563 }
1564 shuffled_input_ptr += 64;
1565 shuffled_weights_ptr += 64;
1566 }
1567 for (int i = 0; i < 4; i++) {
1568 for (int b = 0; b < 4; b++) {
1569 // Add bias value
1570 int acc = accum[i][b] + bias_data[c + i];
1571 // Down-scale the final int32 accumulator to the scale used by our
1572 // (16-bit, typically 3 integer bits) fixed-point format. The
1573 // quantized multiplier and shift here have been pre-computed offline
1574 // (e.g. by toco).
1575 acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
1576 output_shift);
1577 // Saturate, cast to int16, and store to output array.
1578 acc = std::max(acc, -32768);
1579 acc = std::min(acc, 32767);
1580 output_ptr[b * output_stride + c + i] = acc;
1581 }
1582 }
1583 }
1584 } else {
1585 TFLITE_DCHECK(false);
1586 return;
1587 }
1588 #endif
1589 }
1590
1591 // Wraps ShuffledFullyConnectedWorkerImpl into a Task class
1592 // to allow using gemmlowp's threadpool.
1593 struct ShuffledFullyConnectedWorkerTask : gemmlowp::Task {
1594   ShuffledFullyConnectedWorkerTask(const uint8* input_data,
1595 const int8* shuffled_weights_data,
1596 int batches, int output_depth,
1597 int output_stride, int accum_depth,
1598 const int32* bias_data,
1599 int32 output_multiplier, int output_shift,
1600 int16* output_data)
1601 : input_data_(input_data),
1602 shuffled_weights_data_(shuffled_weights_data),
1603 batches_(batches),
1604 output_depth_(output_depth),
1605 output_stride_(output_stride),
1606 accum_depth_(accum_depth),
1607 bias_data_(bias_data),
1608 output_multiplier_(output_multiplier),
1609 output_shift_(output_shift),
1610 output_data_(output_data) {}
1611
1612   void Run() override {
1613 ShuffledFullyConnectedWorkerImpl(
1614 input_data_, shuffled_weights_data_, batches_, output_depth_,
1615 output_stride_, accum_depth_, bias_data_, output_multiplier_,
1616 output_shift_, output_data_);
1617 }
1618
1619 const uint8* input_data_;
1620 const int8* shuffled_weights_data_;
1621 int batches_;
1622 int output_depth_;
1623 int output_stride_;
1624 int accum_depth_;
1625 const int32* bias_data_;
1626 int32 output_multiplier_;
1627 int output_shift_;
1628 int16* output_data_;
1629 };
1630
1631 inline void ShuffledFullyConnected(
1632 const FullyConnectedParams& params, const RuntimeShape& input_shape,
1633 const uint8* input_data, const RuntimeShape& weights_shape,
1634 const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
1635 const int32* bias_data, const RuntimeShape& output_shape,
1636 int16* output_data, uint8* shuffled_input_workspace_data,
1637 gemmlowp::GemmContext* gemm_context) {
1638 gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit");
1639 const int32 output_multiplier = params.output_multiplier;
1640 const int output_shift = params.output_shift;
1641 const int32 output_activation_min = params.quantized_activation_min;
1642 const int32 output_activation_max = params.quantized_activation_max;
1643 (void)gemm_context; // only used in optimized code.
1644 TFLITE_DCHECK_EQ(output_activation_min, -32768);
1645 TFLITE_DCHECK_EQ(output_activation_max, 32767);
1646 TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
1647 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
1648 TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
1649 // TODO(benoitjacob): This really should be:
1650 // const int batches = ArraySize(output_dims, 1);
1651 // but the current --variable_batch hack consists in overwriting the 3rd
1652 // dimension with the runtime batch size, as we don't keep track for each
1653 // array of which dimension is the batch dimension in it.
1654 const int output_dim_count = output_shape.DimensionsCount();
1655 const int weights_dim_count = weights_shape.DimensionsCount();
1656 const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
1657 const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
1658 output_shape, output_dim_count - 1);
1659 const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
1660 TFLITE_DCHECK((accum_depth % 16) == 0);
1661 TFLITE_DCHECK((output_depth % 4) == 0);
1662 // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
1663 // so that just reinterpreting them as int8 values is equivalent to
1664 // subtracting 128 from them, thus implementing for free the subtraction of
1665 // the zero_point value 128.
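  // For example, the uint8 value 200 (0xC8), which represents 200 - 128 = 72
  // relative to the zero point, becomes 0x48 after the XOR, and 0x48 read as
  // int8 is also 72, so no subtraction is needed in the inner loops.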
1666 const int8* int8_shuffled_weights_data =
1667 reinterpret_cast<const int8*>(shuffled_weights_data);
1668
1669 // Shuffling and xoring of input activations into the workspace buffer
1670 if (batches == 1) {
1671 #ifdef USE_NEON
1672 const uint8x16_t signbit = vdupq_n_u8(0x80);
1673 for (int i = 0; i < accum_depth; i += 16) {
1674 uint8x16_t val = vld1q_u8(input_data + i);
1675 val = veorq_u8(val, signbit);
1676 vst1q_u8(shuffled_input_workspace_data + i, val);
1677 }
1678 #else
1679 for (int i = 0; i < accum_depth; i++) {
1680 shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
1681 }
1682 #endif
1683 } else if (batches == 4) {
1684 uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
1685 int c = 0;
1686 #ifdef USE_NEON
1687 const uint8x16_t signbit = vdupq_n_u8(0x80);
1688 for (c = 0; c < accum_depth; c += 16) {
1689 const uint8* src_data_ptr = input_data + c;
1690 uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
1691 uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
1692 uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
1693 uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
1694 val0 = veorq_u8(val0, signbit);
1695 val1 = veorq_u8(val1, signbit);
1696 val2 = veorq_u8(val2, signbit);
1697 val3 = veorq_u8(val3, signbit);
1698 vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
1699 vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
1700 vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
1701 vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
1702 shuffled_input_workspace_ptr += 64;
1703 }
1704 #else
1705 for (c = 0; c < accum_depth; c += 16) {
1706 for (int b = 0; b < 4; b++) {
1707 const uint8* src_data_ptr = input_data + b * accum_depth + c;
1708 for (int j = 0; j < 16; j++) {
1709 uint8 src_val = *src_data_ptr++;
1710 // Flip the sign bit, so that the kernel will only need to
1711 // reinterpret these uint8 values as int8, getting for free the
1712 // subtraction of the zero_point value 128.
1713 uint8 dst_val = src_val ^ 0x80;
1714 *shuffled_input_workspace_ptr++ = dst_val;
1715 }
1716 }
1717 }
1718 #endif
1719 } else {
1720 TFLITE_DCHECK(false);
1721 return;
1722 }
1723
1724 static constexpr int kKernelRows = 4;
1725 const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
1726 gemm_context->max_num_threads(), output_depth, batches, accum_depth);
1727 if (thread_count == 1) {
1728 // Single-thread case: do the computation on the current thread, don't
1729 // use a threadpool
1730 ShuffledFullyConnectedWorkerImpl(
1731 shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
1732 output_depth, output_depth, accum_depth, bias_data, output_multiplier,
1733 output_shift, output_data);
1734 return;
1735 }
1736
1737 // Multi-threaded case: use the gemmlowp context's threadpool.
1738 TFLITE_DCHECK_GT(thread_count, 1);
1739 std::vector<gemmlowp::Task*> tasks(thread_count);
1740 const int kRowsPerWorker =
1741 gemmlowp::RoundUp<kKernelRows>(output_depth / thread_count);
1742 int row_start = 0;
1743 for (int i = 0; i < thread_count; i++) {
1744 int row_end = std::min(output_depth, row_start + kRowsPerWorker);
1745 tasks[i] = new ShuffledFullyConnectedWorkerTask(
1746 shuffled_input_workspace_data,
1747 int8_shuffled_weights_data + row_start * accum_depth, batches,
1748 row_end - row_start, output_depth, accum_depth, bias_data + row_start,
1749 output_multiplier, output_shift, output_data + row_start);
1750 row_start = row_end;
1751 }
1752 TFLITE_DCHECK_EQ(row_start, output_depth);
1753 gemm_context->workers_pool()->Execute(tasks);
1754 }
1755
1756 inline void MeanImpl(const tflite::MeanParams& op_params,
1757 const RuntimeShape& input_shape, const uint8_t* input_data,
1758 int32 input_zero_point, float input_scale,
1759 const RuntimeShape& output_shape, uint8_t* output_data,
1760 int32 output_zero_point, float output_scale,
1761 int start_depth, int end_depth) {
1762 gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8/MeanImpl");
1763
1764   // The current implementation only supports 4-D tensors and simultaneous
1765   // reduction over width and height.
1766 const int output_batch = output_shape.Dims(0);
1767   const int output_height = output_shape.Dims(1);
1768   const int output_width = output_shape.Dims(2);
1769 const int input_height = input_shape.Dims(1);
1770 const int input_width = input_shape.Dims(2);
1771 const float num_elements_in_axis = input_width * input_height;
1772
1773 TFLITE_DCHECK_EQ(op_params.axis_count, 2);
1774 TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
1775 (op_params.axis[0] == 2 && op_params.axis[1] == 1));
1776 TFLITE_DCHECK_EQ(output_height, 1);
1777 TFLITE_DCHECK_EQ(output_width, 1);
1778
1779 const bool ordinary_mean =
1780 (input_zero_point == output_zero_point && input_scale == output_scale);
1781   float scale = 1.0f, bias = 0.0f;
1782 if (!ordinary_mean) {
1783 scale = input_scale / output_scale;
1784 bias = -input_zero_point * scale + 0.5;
1785 }
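  // Derivation (only relevant when the quantization parameters differ): the
  // real-valued mean is (mean_q - input_zero_point) * input_scale, and the
  // requantized output is that value / output_scale + output_zero_point, i.e.
  //   mean_q * (input_scale / output_scale)
  //     - input_zero_point * (input_scale / output_scale) + output_zero_point.
  // Hence scale = input_scale / output_scale, bias folds in the input zero
  // point plus 0.5 for rounding, and output_zero_point is added separately
  // below.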
1786
1787 #ifdef USE_NEON
1788 const float32x4_t num_elements_dup = vdupq_n_f32(num_elements_in_axis);
1789   // This is only an approximation as NEON does not offer a division instruction.
1790 const float32x4_t num_elements_reverse = vrecpeq_f32(num_elements_dup);
1791 const float32x4_t kRounding = vdupq_n_f32(0.5);
1792 float32x4_t bias_dup;
1793 float32x4_t output_zero_point_dup;
1794 if (!ordinary_mean) {
1795 bias_dup = vdupq_n_f32(bias);
1796 output_zero_point_dup = vdupq_n_f32(output_zero_point);
1797 }
1798 #endif
1799
1800 for (int out_b = 0; out_b < output_batch; ++out_b) {
1801 int out_d = start_depth;
1802 #ifdef USE_NEON
1803
1804 for (; out_d < end_depth - 8; out_d += 8) {
1805 float32x4_t temp_sum_1 = vdupq_n_f32(0);
1806 float32x4_t temp_sum_2 = vdupq_n_f32(0);
1807 for (int in_h = 0; in_h < input_height; ++in_h) {
1808 for (int in_w = 0; in_w < input_width; ++in_w) {
1809 const uint8_t* input_data_ptr =
1810 input_data + Offset(input_shape, out_b, in_h, in_w, out_d);
1811 uint8x8_t input_data_val = vld1_u8(input_data_ptr);
1812 int16x8_t input_data_val_shift =
1813 vreinterpretq_s16_u16(vmovl_u8(input_data_val));
1814 float32x4_t input_float_1 =
1815 vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_data_val_shift)));
1816 float32x4_t input_float_2 =
1817 vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_data_val_shift)));
1818 temp_sum_1 = vaddq_f32(temp_sum_1, input_float_1);
1819 temp_sum_2 = vaddq_f32(temp_sum_2, input_float_2);
1820 }
1821 }
1822
1823 float32x4_t mean_1 = vmulq_f32(temp_sum_1, num_elements_reverse);
1824 float32x4_t mean_2 = vmulq_f32(temp_sum_2, num_elements_reverse);
1825
1826 if (!ordinary_mean) {
1827         // A fused multiply-add with a scalar is not used here, so break this down into two ops.
1828 mean_1 = vmulq_n_f32(mean_1, scale);
1829 mean_1 = vaddq_f32(mean_1, bias_dup);
1830 mean_2 = vmulq_n_f32(mean_2, scale);
1831 mean_2 = vaddq_f32(mean_2, bias_dup);
1832 }
1833
1834 if (!ordinary_mean) {
1835 mean_1 = vaddq_f32(mean_1, output_zero_point_dup);
1836 mean_2 = vaddq_f32(mean_2, output_zero_point_dup);
1837 }
1838
1839 // Rounding.
1840 mean_1 = vaddq_f32(mean_1, kRounding);
1841 mean_2 = vaddq_f32(mean_2, kRounding);
1842 uint32x4_t casted_mean_1 = vcvtq_u32_f32(mean_1);
1843 uint16x4_t narrow_range_mean_1 = vmovn_u32(casted_mean_1);
1844 uint32x4_t casted_mean_2 = vcvtq_u32_f32(mean_2);
1845 uint16x4_t narrow_range_mean_2 = vmovn_u32(casted_mean_2);
1846 uint16x8_t combined_mean =
1847 vcombine_u16(narrow_range_mean_2, narrow_range_mean_1);
1848 uint8x8_t narrowed_combined_mean = vmovn_u16(combined_mean);
1849 uint8_t* output_data_ptr =
1850 output_data + Offset(output_shape, out_b, 0, 0, out_d);
1851 vst1_u8(output_data_ptr, narrowed_combined_mean);
1852 }
1853 #endif
1854
1855 for (; out_d < end_depth; ++out_d) {
1856 float temp_value = 0;
1857 for (int in_h = 0; in_h < input_height; ++in_h) {
1858 for (int in_w = 0; in_w < input_width; ++in_w) {
1859 temp_value +=
1860 input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
1861 }
1862 }
1863
1864 temp_value = temp_value / num_elements_in_axis;
1865 if (ordinary_mean) {
1866 output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
1867 static_cast<uint8_t>(round(temp_value));
1868 } else {
1869 output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
1870 static_cast<uint8_t>(round(temp_value * scale + bias)) +
1871 output_zero_point;
1872 }
1873 }
1874 }
1875 }
1876
1877 struct MeanWorkerTask : public gemmlowp::Task {
1878   MeanWorkerTask(const tflite::MeanParams& op_params,
1879 const RuntimeShape& input_shape, const uint8_t* input_data,
1880 int32 input_zero_point, float input_scale,
1881 const RuntimeShape& output_shape, uint8_t* output_data,
1882 int32 output_zero_point, float output_scale, int start_height,
1883 int end_height)
1884 : op_params_(op_params),
1885 input_shape_(input_shape),
1886 input_data_(input_data),
1887 input_zero_point_(input_zero_point),
1888 input_scale_(input_scale),
1889 output_shape_(output_shape),
1890 output_data_(output_data),
1891 output_zero_point_(output_zero_point),
1892 output_scale_(output_scale),
1893 start_height_(start_height),
1894 end_height_(end_height) {}
1895
1896   void Run() override {
1897 MeanImpl(op_params_, input_shape_, input_data_, input_zero_point_,
1898 input_scale_, output_shape_, output_data_, output_zero_point_,
1899 output_scale_, start_height_, end_height_);
1900 }
1901
1902 private:
1903 const tflite::MeanParams& op_params_;
1904 const RuntimeShape& input_shape_;
1905 const uint8_t* input_data_;
1906 int32 input_zero_point_;
1907 float input_scale_;
1908 const RuntimeShape& output_shape_;
1909 uint8_t* output_data_;
1910 int32 output_zero_point_;
1911 float output_scale_;
1912 int start_height_;
1913 int end_height_;
1914 gemmlowp::GemmContext* gemm_context_;
1915 };
1916
1917 inline void Mean(const tflite::MeanParams& op_params,
1918 const RuntimeShape& unextended_input_shape,
1919 const uint8_t* input_data, int32 input_zero_point,
1920 float input_scale, const RuntimeShape& unextended_output_shape,
1921 uint8_t* output_data, int32 output_zero_point,
1922 float output_scale, gemmlowp::GemmContext* gemm_context) {
1923 gemmlowp::ScopedProfilingLabel label("Mean4D/Uint8");
1924
1925   // The current implementation only supports 4-D tensors and simultaneous
1926   // reduction over width and height.
1927 TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
1928 TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1929 const RuntimeShape input_shape =
1930 RuntimeShape::ExtendedShape(4, unextended_input_shape);
1931 const RuntimeShape output_shape =
1932 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1933 const int output_height = output_shape.Dims(1);
1934 const int output_width = output_shape.Dims(2);
1935 const int output_depth = output_shape.Dims(3);
1936
1937 TFLITE_DCHECK_EQ(op_params.axis_count, 2);
1938 TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
1939 (op_params.axis[0] == 2 && op_params.axis[1] == 1));
1940 TFLITE_DCHECK_EQ(output_height, 1);
1941 TFLITE_DCHECK_EQ(output_width, 1);
1942
1943 constexpr int kMinDepthPerThread = 8;
1944 int thread_count = output_depth / kMinDepthPerThread;
1945 thread_count = thread_count > 0 ? thread_count : 1;
1946 const int capped_thread_count =
1947 std::min(thread_count, gemm_context->max_num_threads());
1948
1949   if (capped_thread_count == 1) {
1950 MeanImpl(op_params, input_shape, input_data, input_zero_point, input_scale,
1951 output_shape, output_data, output_zero_point, output_scale, 0,
1952 output_depth);
1953 } else {
1954     // Instead of parallelizing over the batch dimension, we split the work over
1955     // output_depth, since the batch size is typically 1.
1956 std::vector<gemmlowp::Task*> tasks(capped_thread_count);
1957 int depth_start = 0;
1958 for (int i = 0; i < capped_thread_count; ++i) {
1959       // Try to distribute the tasks as evenly as possible.
1960 int depth_end = depth_start +
1961 (output_depth - depth_start) / (capped_thread_count - i);
1962 tasks[i] = new MeanWorkerTask(op_params, input_shape, input_data,
1963 input_zero_point, input_scale, output_shape,
1964 output_data, output_zero_point,
1965 output_scale, depth_start, depth_end);
1966 depth_start = depth_end;
1967 }
1968 gemm_context->workers_pool()->Execute(tasks);
1969 }
1970 }
1971
1972 template <typename T>
1973 inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
1974 int h, int b, int kheight, int kwidth,
1975 int stride_width, int stride_height,
1976 int pad_width, int pad_height,
1977 int in_width, int in_height,
1978 int in_depth, int single_buffer_length,
1979 int buffer_id, const T* in_data,
1980 T* conv_buffer_data, uint8 zero_byte) {
1981 gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn");
1982 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1983 // This chunk of code reshapes all the inputs corresponding to
1984 // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
1985 const int kwidth_times_indepth = kwidth * in_depth;
1986 const int inwidth_times_indepth = in_width * in_depth;
1987 const int ih_ungated_start = h * stride_height - pad_height;
1988 const int ih_ungated_end = (ih_ungated_start + kheight);
1989 const int ih_end = std::min(ih_ungated_end, in_height);
1990 const int iw_ungated_start = w * stride_width - pad_width;
1991 const int iw_ungated_end = (iw_ungated_start + kwidth);
1992 const int iw_end = std::min(iw_ungated_end, in_width);
1993 // If the patch is off the edge of the input image, skip writing those rows
1994 // and columns from the patch into the output array.
1995 const int h_offset = std::max(0, -ih_ungated_start);
1996 const int w_offset = std::max(0, -iw_ungated_start);
1997 const int ih_start = std::max(0, ih_ungated_start);
1998 const int iw_start = std::max(0, iw_ungated_start);
1999 const int single_row_num =
2000 std::min(kwidth - w_offset, in_width - iw_start) * in_depth;
2001 const int output_row_offset = (buffer_id * single_buffer_length);
2002 int out_offset =
2003 output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
2004 int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
2005
2006 // Express all of the calculations as padding around the input patch.
2007 const int top_padding = h_offset;
2008 const int bottom_padding = (ih_ungated_end - ih_end);
2009 const int left_padding = w_offset;
2010 const int right_padding = (iw_ungated_end - iw_end);
2011 assert(single_row_num ==
2012 ((kwidth - (left_padding + right_padding)) * in_depth));
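  // As a hypothetical illustration: with a 3x3 kernel, stride 1 and padding 1,
  // the output pixel at (h, w) = (0, 0) maps to an input patch whose top-left
  // corner is at (-1, -1), so top_padding = left_padding = 1 and only a 2x2
  // window of real input values (times in_depth) is copied; the remaining
  // elements of the patch are filled with zero_byte below.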
2013
2014 // Write out zeroes to the elements representing the top rows of the input
2015 // patch that are off the edge of the input image.
2016 if (top_padding > 0) {
2017 const int top_row_elements = (top_padding * kwidth * in_depth);
2018 memset(conv_buffer_data + output_row_offset, zero_byte,
2019 (top_row_elements * sizeof(T)));
2020 }
2021
2022 // If the patch is on the interior of the input image horizontally, just copy
2023 // over the rows sequentially, otherwise add zero padding at the start or end.
2024 if ((left_padding == 0) && (right_padding == 0)) {
2025 for (int ih = ih_start; ih < ih_end; ++ih) {
2026 memcpy(conv_buffer_data + out_offset, in_data + in_offset,
2027 single_row_num * sizeof(T));
2028 out_offset += kwidth_times_indepth;
2029 in_offset += inwidth_times_indepth;
2030 }
2031 } else {
2032 for (int ih = ih_start; ih < ih_end; ++ih) {
2033 if (left_padding > 0) {
2034 const int left_start = (out_offset - (left_padding * in_depth));
2035 memset(conv_buffer_data + left_start, zero_byte,
2036 (left_padding * in_depth * sizeof(T)));
2037 }
2038 memcpy(conv_buffer_data + out_offset, in_data + in_offset,
2039 single_row_num * sizeof(T));
2040 if (right_padding > 0) {
2041 const int right_start = (out_offset + single_row_num);
2042 memset(conv_buffer_data + right_start, zero_byte,
2043 (right_padding * in_depth * sizeof(T)));
2044 }
2045 out_offset += kwidth_times_indepth;
2046 in_offset += inwidth_times_indepth;
2047 }
2048 }
2049
2050 // If the bottom of the patch falls off the input image, pad the values
2051 // representing those input rows with zeroes.
2052 if (bottom_padding > 0) {
2053 const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
2054 const int bottom_start =
2055 output_row_offset +
2056 ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
2057 memset(conv_buffer_data + bottom_start, zero_byte,
2058 (bottom_row_elements * sizeof(T)));
2059 }
2060 }
2061
2062 template <typename T>
2063 void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
2064 const RuntimeShape& input_shape, const T* input_data,
2065 const RuntimeShape& filter_shape,
2066 const RuntimeShape& output_shape, T* im2col_data) {
2067 const int stride_width = params.stride_width;
2068 const int stride_height = params.stride_height;
2069 const int dilation_width_factor = params.dilation_width_factor;
2070 const int dilation_height_factor = params.dilation_height_factor;
2071 const int pad_width = params.padding_values.width;
2072 const int pad_height = params.padding_values.height;
2073 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2074 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2075 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2076
2077   // For dilated convolution the input pixels are not contiguous, so we can't
2078   // use the same optimizations as Im2col(). Note that this code would also work
2079   // for the non-dilated case, though likely a bit slower.
2080 gemmlowp::ScopedProfilingLabel label("DilatedIm2col");
2081 TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1);
2082 TFLITE_DCHECK(im2col_data);
2083 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
2084 const int input_height = input_shape.Dims(1);
2085 const int input_width = input_shape.Dims(2);
2086 const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
2087 const int filter_height = filter_shape.Dims(1);
2088 const int filter_width = filter_shape.Dims(2);
2089 const int output_height = output_shape.Dims(1);
2090 const int output_width = output_shape.Dims(2);
2091 MatchingDim(output_shape, 3, filter_shape, 0);
2092
2093 // Construct the MxN sized im2col matrix.
2094 // The rows M, are sub-ordered B x H x W
2095 const RuntimeShape row_shape({1, batches, output_height, output_width});
2096 // The columns, N, are sub-ordered Kh x Kw x Din
2097 const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
2098 // Use dimensions M and N to construct dims for indexing directly into im2col
2099 const RuntimeShape im2col_shape(
2100 {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
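  // For example (hypothetical sizes): with batches = 1, a 10x10 output, a 3x3
  // filter and input_depth = 8, the im2col matrix has M = 1 * 10 * 10 = 100
  // rows and N = 3 * 3 * 8 = 72 columns, i.e. each row holds one fully
  // unrolled input patch.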
2101
2102 // Loop through the output rows (B x H x W)
2103 for (int batch = 0; batch < batches; ++batch) {
2104 for (int out_y = 0; out_y < output_height; ++out_y) {
2105 for (int out_x = 0; out_x < output_width; ++out_x) {
2106 // Each im2col row is an output pixel. Arrange the input data in this
2107 // row in an order we can conveniently multiply with the filter data.
2108 int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
2109 const int in_x_origin = (out_x * stride_width) - pad_width;
2110 const int in_y_origin = (out_y * stride_height) - pad_height;
2111 // Loop through all the pixels of the filter (Kh x Kw)
2112 for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
2113 const int in_y = in_y_origin + dilation_height_factor * filter_y;
2114 if ((in_y >= 0) && (in_y < input_height)) {
2115 // Filter row is within the input data.
2116 // Loop through all the filter pixels in this row.
2117 for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
2118 const int in_x = in_x_origin + dilation_width_factor * filter_x;
2119 int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
2120 T* dst = im2col_data +
2121 Offset(im2col_shape, 0, 0, row_offset, col_offset);
2122 if ((in_x >= 0) && (in_x < input_width)) {
2123 // Filter pixel is within the input, copy the input data.
2124 T const* src =
2125 input_data + Offset(input_shape, batch, in_y, in_x, 0);
2126 memcpy(dst, src, input_depth * sizeof(T));
2127 } else {
2128 // Filter pixel is outside the input, zero it out.
2129 memset(dst, zero_byte, input_depth * sizeof(T));
2130 }
2131 }
2132 } else {
2133 // Filter row is outside the input, zero out the entire filter row.
2134 int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
2135 T* dst = im2col_data +
2136 Offset(im2col_shape, 0, 0, row_offset, col_offset);
2137 memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
2138 }
2139 }
2140 }
2141 }
2142 }
2143 }
2144
2145 template <typename T>
2146 void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
2147 const RuntimeShape& input_shape, const T* input_data,
2148 const RuntimeShape& output_shape, T* output_data) {
2149 gemmlowp::ScopedProfilingLabel label("Im2col");
2150 const int stride_width = params.stride_width;
2151 const int stride_height = params.stride_height;
2152 const int pad_width = params.padding_values.width;
2153 const int pad_height = params.padding_values.height;
2154 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2155 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2156
2157 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
2158 const int input_depth = input_shape.Dims(3);
2159 const int input_width = input_shape.Dims(2);
2160 const int input_height = input_shape.Dims(1);
2161 const int output_depth = output_shape.Dims(3);
2162 const int output_width = output_shape.Dims(2);
2163 const int output_height = output_shape.Dims(1);
2164
2165 int buffer_id = 0;
2166 // Loop over the output nodes.
2167 for (int b = 0; b < batches; ++b) {
2168 for (int h = 0; h < output_height; ++h) {
2169 for (int w = 0; w < output_width; ++w) {
2170 ExtractPatchIntoBufferColumn(
2171 input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
2172 pad_width, pad_height, input_width, input_height, input_depth,
2173 output_depth, buffer_id, input_data, output_data, zero_byte);
2174 ++buffer_id;
2175 }
2176 }
2177 }
2178 }
2179
2180 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
2181 const float* input_data, const RuntimeShape& filter_shape,
2182 const float* filter_data, const RuntimeShape& bias_shape,
2183 const float* bias_data, const RuntimeShape& output_shape,
2184 float* output_data, const RuntimeShape& im2col_shape,
2185 float* im2col_data) {
2186 const int stride_width = params.stride_width;
2187 const int stride_height = params.stride_height;
2188 const int dilation_width_factor = params.dilation_width_factor;
2189 const int dilation_height_factor = params.dilation_height_factor;
2190 const float output_activation_min = params.float_activation_min;
2191 const float output_activation_max = params.float_activation_max;
2192 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2193 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2194 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2195
2196 (void)im2col_data;
2197 (void)im2col_shape;
2198 gemmlowp::ScopedProfilingLabel label("Conv");
2199
2200   // NB: an all-zero byte pattern (0x00000000) reinterpreted as a float is 0.0f.
2201 const uint8 float_zero_byte = 0x00;
2202 const float* gemm_input_data = nullptr;
2203 const RuntimeShape* gemm_input_shape = nullptr;
2204 const int filter_width = filter_shape.Dims(2);
2205 const int filter_height = filter_shape.Dims(1);
2206 const bool need_dilated_im2col =
2207 dilation_width_factor != 1 || dilation_height_factor != 1;
2208 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2209 filter_width != 1 || filter_height != 1;
2210 if (need_dilated_im2col) {
2211 DilatedIm2col(params, float_zero_byte, input_shape, input_data,
2212 filter_shape, output_shape, im2col_data);
2213 gemm_input_data = im2col_data;
2214 gemm_input_shape = &im2col_shape;
2215 } else if (need_im2col) {
2216 TFLITE_DCHECK(im2col_data);
2217 Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
2218 input_data, im2col_shape, im2col_data);
2219 gemm_input_data = im2col_data;
2220 gemm_input_shape = &im2col_shape;
2221 } else {
2222 // TODO(aselle): We need to make sure to not send im2col if it is not
2223 // needed.
2224 TFLITE_DCHECK(!im2col_data);
2225 gemm_input_data = input_data;
2226 gemm_input_shape = &input_shape;
2227 }
2228
2229   // The following code computes the matrix multiplication c = a * transpose(b)
2230 // with CBLAS, where:
2231 // * `a` is a matrix with dimensions (m, k).
2232 // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
2233 // * `c` is a matrix with dimensions (m, n).
2234   // The naming of the variables is aligned with the CBLAS specification here.
2235 const float* a = gemm_input_data;
2236 const float* b = filter_data;
2237 float* c = output_data;
2238 const int gemm_input_dims = gemm_input_shape->DimensionsCount();
2239 int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
2240 int n = output_shape.Dims(3);
2241 int k = gemm_input_shape->Dims(gemm_input_dims - 1);
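  // Here m is the number of output pixels (batch * output height * width), k is
  // the unrolled patch size (filter height * width * input depth when im2col is
  // used, otherwise just the input depth), and n is the number of output
  // channels.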
2242
2243 #if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
2244 // The stride of matrix a, b and c respectively.
2245 int stride_a = k;
2246 int stride_b = k;
2247 int stride_c = n;
2248
2249 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
2250 stride_a, b, stride_b, 0.0f, c, stride_c);
2251 #else
2252 // When an optimized CBLAS implementation is not available, fall back
2253 // to using Eigen.
2254 typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
2255 Matrix;
2256 typedef Eigen::Map<Matrix> MatrixRef;
2257 typedef Eigen::Map<const Matrix> ConstMatrixRef;
2258
2259 MatrixRef matrix_c(c, m, n);
2260 ConstMatrixRef matrix_a(a, m, k);
2261 ConstMatrixRef matrix_b(b, n, k);
2262
2263 // The following special casing for when a or b is a vector is required
2264   // as Eigen seems to fail to make this optimization on its own.
2265 if (n == 1) {
2266 gemmlowp::ScopedProfilingLabel label("GEMV");
2267 matrix_c.col(0).noalias() = matrix_a * matrix_b.row(0).transpose();
2268 } else if (m == 1) {
2269 gemmlowp::ScopedProfilingLabel label("GEMV");
2270 matrix_c.row(0).noalias() = matrix_a.row(0) * matrix_b.transpose();
2271 } else {
2272 gemmlowp::ScopedProfilingLabel label("GEMM");
2273 matrix_c.noalias() = matrix_a * matrix_b.transpose();
2274 }
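  // The noalias() calls tell Eigen that the destination does not overlap the
  // operands, letting it write results directly rather than through a
  // temporary.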
2275
2276 #endif // defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
2277
2278 optimized_ops::AddBiasAndEvalActivationFunction(
2279 output_activation_min, output_activation_max, bias_shape, bias_data,
2280 output_shape, output_data);
2281 }
2282
2283 inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
2284 const RuntimeShape& input_shape,
2285 const int8_t* input_data,
2286 const RuntimeShape& filter_shape,
2287 const int8_t* filter_data,
2288 const RuntimeShape& bias_shape, const float* bias_data,
2289 const RuntimeShape& output_shape, float* output_data,
2290 const RuntimeShape& im2col_shape, int8_t* im2col_data) {
2291 const int stride_width = params.stride_width;
2292 const int stride_height = params.stride_height;
2293 const float output_activation_min = params.float_activation_min;
2294 const float output_activation_max = params.float_activation_max;
2295 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2296 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2297 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2298
2299 const int batch_size = input_shape.Dims(0);
2300 const int filter_width = filter_shape.Dims(2);
2301 const int filter_height = filter_shape.Dims(1);
2302
2303 const int8_t* gemm_input_data = nullptr;
2304 int num_input;
2305 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2306 filter_width != 1 || filter_height != 1;
2307
2308 if (need_im2col) {
2309 TFLITE_DCHECK(im2col_data);
2310     // Symmetric quantization assumes a zero point of 0.
2311 const int input_zero_point = 0;
2312
2313 Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
2314 input_data, im2col_shape, im2col_data);
2315 gemm_input_data = im2col_data;
2316 num_input = im2col_shape.FlatSize();
2317 } else {
2318 TFLITE_DCHECK(!im2col_data);
2319 gemm_input_data = input_data;
2320 num_input = input_shape.FlatSize();
2321 }
2322
2323 // Flatten 4D matrices into 2D matrices for matrix multiplication.
2324
2325 // Flatten so that each filter has its own row.
2326 const int filter_rows = filter_shape.Dims(0);
2327 const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
2328
2329 // In MatrixBatchVectorMultiplyAccumulate, each output value is the
2330 // dot product of one row of the first matrix with one row of the second
2331 // matrix. Therefore, the number of cols in each matrix are equivalent.
2332 //
2333 // After Im2Col, each input patch becomes a row.
2334 const int gemm_input_cols = filter_cols;
2335 const int gemm_input_rows = num_input / gemm_input_cols;
2336
2337 const int output_cols = output_shape.Dims(3);
2338 const int output_rows = FlatSizeSkipDim(output_shape, 3);
2339 TFLITE_DCHECK_EQ(output_cols, filter_rows);
2340 TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
2341 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_cols);
2342
2343 // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
2344 // input matrix has its own scale factor. This code duplicates the scale
2345 // factors for each row in the same batch.
2346 const int rows_per_batch = gemm_input_rows / batch_size;
2347 for (int i = gemm_input_rows - 1; i >= 0; --i) {
2348 scaling_factors_ptr[i] = scaling_factors_ptr[i / rows_per_batch];
2349 }
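  // For example (hypothetical sizes): with batch_size = 2 and
  // gemm_input_rows = 200, rows_per_batch = 100, so entries [0, 99] of
  // scaling_factors_ptr all receive the first batch's scale and entries
  // [100, 199] the second batch's, matching the per-row scales expected by
  // MatrixBatchVectorMultiplyAccumulate below.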
2350
2351 tensor_utils::ZeroVector(output_data, output_rows * output_cols);
2352
2353 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
2354 filter_data, filter_rows, filter_cols, gemm_input_data,
2355 scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data,
2356 /*result_stride=*/1);
2357
2358 AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
2359 bias_shape, bias_data, output_shape,
2360 output_data);
2361 }
2362
2363 inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
2364 const uint8* input_data, const RuntimeShape& filter_shape,
2365 const uint8* filter_data, const RuntimeShape& bias_shape,
2366 const int32* bias_data, const RuntimeShape& output_shape,
2367 uint8* output_data, const RuntimeShape& im2col_shape,
2368 uint8* im2col_data, gemmlowp::GemmContext* gemm_context) {
2369 gemmlowp::ScopedProfilingLabel label("Conv/8bit");
2370 const int stride_width = params.stride_width;
2371 const int stride_height = params.stride_height;
2372 const int dilation_width_factor = params.dilation_width_factor;
2373 const int dilation_height_factor = params.dilation_height_factor;
2374 const int32 input_offset = params.input_offset;
2375 const int32 filter_offset = params.weights_offset;
2376 const int32 output_offset = params.output_offset;
2377 const int32 output_multiplier = params.output_multiplier;
2378 const int output_shift = params.output_shift;
2379 const int32 output_activation_min = params.quantized_activation_min;
2380 const int32 output_activation_max = params.quantized_activation_max;
2381 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2382 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2383 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2384
2385 const uint8* gemm_input_data = nullptr;
2386 const RuntimeShape* gemm_input_shape = nullptr;
2387 const int filter_width = filter_shape.Dims(2);
2388 const int filter_height = filter_shape.Dims(1);
2389 const bool need_dilated_im2col =
2390 dilation_width_factor != 1 || dilation_height_factor != 1;
2391 const bool need_im2col = stride_width != 1 || stride_height != 1 ||
2392 filter_width != 1 || filter_height != 1;
2393 if (need_dilated_im2col) {
2394 TFLITE_DCHECK(im2col_data);
2395 const int input_zero_point = -input_offset;
2396 TFLITE_DCHECK_GE(input_zero_point, 0);
2397 TFLITE_DCHECK_LE(input_zero_point, 255);
2398 DilatedIm2col(params, input_zero_point, input_shape, input_data,
2399 filter_shape, output_shape, im2col_data);
2400 gemm_input_data = im2col_data;
2401 gemm_input_shape = &im2col_shape;
2402 } else if (need_im2col) {
2403 TFLITE_DCHECK(im2col_data);
2404 const int input_zero_point = -input_offset;
2405 TFLITE_DCHECK_GE(input_zero_point, 0);
2406 TFLITE_DCHECK_LE(input_zero_point, 255);
2407 Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
2408 input_data, im2col_shape, im2col_data);
2409 gemm_input_data = im2col_data;
2410 gemm_input_shape = &im2col_shape;
2411 } else {
2412 TFLITE_DCHECK(!im2col_data);
2413 gemm_input_data = input_data;
2414 gemm_input_shape = &input_shape;
2415 }
2416
2417 const int gemm_input_rows = gemm_input_shape->Dims(3);
2418   // Using FlatSizeSkipDim causes a segfault in some contexts (see b/79927784),
2419   // though the root cause has not yet been identified. The same applies to the
2420   // other commented-out calls below. This is a partial rollback of cl/196819423.
2421 // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
2422 const int gemm_input_cols = gemm_input_shape->Dims(0) *
2423 gemm_input_shape->Dims(1) *
2424 gemm_input_shape->Dims(2);
2425 const int filter_rows = filter_shape.Dims(0);
2426 // See b/79927784.
2427 // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
2428 const int filter_cols =
2429 filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
2430 const int output_rows = output_shape.Dims(3);
2431 // See b/79927784.
2432 // const int output_cols = FlatSizeSkipDim(output_shape, 3);
2433 const int output_cols =
2434 output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
2435 TFLITE_DCHECK_EQ(output_rows, filter_rows);
2436 TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
2437 TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
2438 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
2439
2440 #ifdef USE_NEON
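  // Fast path: with a single input column the GEMM degenerates to a
  // matrix * vector product, so dispatch to the NEON GEMV kernel.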
2441 if (gemm_input_cols == 1 && output_rows >= 4) {
2442 RuntimeShape fc_filter_shape{
2443 filter_shape.Dims(0),
2444 filter_shape.Dims(filter_shape.DimensionsCount() - 1)};
2445
2446 return FullyConnectedAsGEMV(
2447 *gemm_input_shape, gemm_input_data, input_offset, fc_filter_shape,
2448 filter_data, filter_offset, bias_shape, bias_data, output_offset,
2449 output_multiplier, output_shift, output_activation_min,
2450 output_activation_max, output_shape, output_data, gemm_context);
2451 }
2452 #endif
2453
2454 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
2455 filter_data, filter_rows, filter_cols);
2456 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
2457 gemm_input_data, gemm_input_rows, gemm_input_cols);
2458 gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
2459 output_data, output_rows, output_cols);
2460 const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
2461 bias_data, output_rows, output_offset, output_multiplier, output_shift,
2462 output_activation_min, output_activation_max);
2463 gemmlowp::GemmWithOutputPipeline<uint8, uint8,
2464 gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
2465 gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
2466 input_offset, output_pipeline);
2467 }
2468
2469 template <typename T>
2470 inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
2471 const RuntimeShape& unextended_input_shape,
2472 const T* input_data,
2473 const RuntimeShape& unextended_output_shape,
2474 T* output_data) {
2475 gemmlowp::ScopedProfilingLabel label("DepthToSpace");
2476
2477 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2478 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
2479 const RuntimeShape input_shape =
2480 RuntimeShape::ExtendedShape(4, unextended_input_shape);
2481 const RuntimeShape output_shape =
2482 RuntimeShape::ExtendedShape(4, unextended_output_shape);
2483
2484 const int input_depth = input_shape.Dims(3);
2485 const int input_width = input_shape.Dims(2);
2486 const int input_height = input_shape.Dims(1);
2487
2488 const int output_depth = output_shape.Dims(3);
2489 const int batch_size = output_shape.Dims(0);
2490
2491   // Number of contiguous values that we can copy in one iteration.
2492 const int stride = op_params.block_size * output_depth;
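  // e.g. with block_size = 2 and output_depth = 4 (so input_depth = 16), each
  // memcpy below moves stride = 8 contiguous values.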
2493
2494 for (int batch = 0; batch < batch_size; ++batch) {
2495 for (int in_h = 0; in_h < input_height; ++in_h) {
2496 const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0);
2497 for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
2498 const T* src = input_ptr;
2499 for (int in_w = 0; in_w < input_width; ++in_w) {
2500 memcpy(output_data, src, stride * sizeof(T));
2501 output_data += stride;
2502 src += input_depth;
2503 }
2504 input_ptr += stride;
2505 }
2506 }
2507 }
2508 }
2509
2510 template <typename T>
2511 inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
2512 const RuntimeShape& unextended_input_shape,
2513 const T* input_data,
2514 const RuntimeShape& unextended_output_shape,
2515 T* output_data) {
2516 gemmlowp::ScopedProfilingLabel label("SpaceToDepth");
2517
2518 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
2519 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
2520 const RuntimeShape input_shape =
2521 RuntimeShape::ExtendedShape(4, unextended_input_shape);
2522 const RuntimeShape output_shape =
2523 RuntimeShape::ExtendedShape(4, unextended_output_shape);
2524
2525 const int output_depth = output_shape.Dims(3);
2526 const int output_width = output_shape.Dims(2);
2527 const int output_height = output_shape.Dims(1);
2528
2529 const int input_depth = input_shape.Dims(3);
2530 const int batch_size = input_shape.Dims(0);
2531
2532   // Number of contiguous values that we can copy in one iteration.
2533 const int stride = op_params.block_size * input_depth;
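  // e.g. with block_size = 2 and input_depth = 4 (so output_depth = 16), each
  // memcpy below moves stride = 8 contiguous values.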
2534
2535 for (int batch = 0; batch < batch_size; ++batch) {
2536 for (int out_h = 0; out_h < output_height; ++out_h) {
2537 T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
2538 for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
2539 T* dst = output_ptr;
2540 for (int out_w = 0; out_w < output_width; ++out_w) {
2541 memcpy(dst, input_data, stride * sizeof(T));
2542 input_data += stride;
2543 dst += output_depth;
2544 }
2545 output_ptr += stride;
2546 }
2547 }
2548 }
2549 }
2550
2551 inline void Relu(const RuntimeShape& input_shape, const float* input_data,
2552 const RuntimeShape& output_shape, float* output_data) {
2553 gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
2554
2555 const auto input = MapAsVector(input_data, input_shape);
2556 auto output = MapAsVector(output_data, output_shape);
2557 output = input.cwiseMax(0.0f);
2558 }
2559
2560 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
2561 const RuntimeShape& input_shape,
2562 const float* input_data,
2563 const RuntimeShape& output_shape,
2564 float* output_data) {
2565 gemmlowp::ScopedProfilingLabel label("L2Normalization");
2566 const int trailing_dim = input_shape.DimensionsCount() - 1;
2567 const int outer_size =
2568 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
2569 const int depth =
2570 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
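  // Each trailing-dimension vector x is replaced by x / sqrt(sum of x[c]^2).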
2571 for (int i = 0; i < outer_size; ++i) {
2572 float squared_l2_norm = 0;
2573 for (int c = 0; c < depth; ++c) {
2574 const float val = input_data[c];
2575 squared_l2_norm += val * val;
2576 }
2577 const float l2_norm = std::sqrt(squared_l2_norm);
2578 for (int c = 0; c < depth; ++c) {
2579 *output_data = *input_data / l2_norm;
2580 ++output_data;
2581 ++input_data;
2582 }
2583 }
2584 }
2585
2586 inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
2587 const RuntimeShape& input_shape,
2588 const uint8* input_data,
2589 const RuntimeShape& output_shape,
2590 uint8* output_data) {
2591 gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit");
2592 const int trailing_dim = input_shape.DimensionsCount() - 1;
2593 const int depth =
2594 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
2595 const int outer_size =
2596 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
2597 const int32 input_zero_point = op_params.input_zero_point;
2598 for (int i = 0; i < outer_size; ++i) {
2599 int32 square_l2_norm = 0;
2600 for (int c = 0; c < depth; c++) {
2601 // Note that input_data advances by depth in the second pass below.
2602 int32 diff = input_data[c] - input_zero_point;
2603 square_l2_norm += diff * diff;
2604 }
2605 int32 inv_l2norm_multiplier;
2606 int inv_l2norm_shift;
2607 GetInvSqrtQuantizedMultiplierExp(square_l2_norm, kReverseShift,
2608 &inv_l2norm_multiplier, &inv_l2norm_shift);
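    // The uint8 output is centered on a fixed zero point of 128: the normalized
    // value in [-1, 1] is scaled by 128 (hence 128 * diff below) and offset by
    // 128, so +/-1 maps to the ends of the [0, 255] range.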
2609
2610 for (int c = 0; c < depth; c++) {
2611 int32 diff = *input_data - input_zero_point;
2612 int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
2613 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
2614 int32 unclamped_output_val = 128 + rescaled_diff;
2615 int32 output_val = std::min(255, std::max(0, unclamped_output_val));
2616 *output_data = static_cast<uint8>(output_val);
2617 ++input_data;
2618 ++output_data;
2619 }
2620 }
2621 }
2622
2623 inline void Add(const ArithmeticParams& params,
2624 const RuntimeShape& input1_shape, const float* input1_data,
2625 const RuntimeShape& input2_shape, const float* input2_data,
2626 const RuntimeShape& output_shape, float* output_data) {
2627 gemmlowp::ScopedProfilingLabel label("Add");
2628
2629 int i = 0;
2630 const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
2631 #ifdef USE_NEON
2632 const auto activation_min = vdupq_n_f32(params.float_activation_min);
2633 const auto activation_max = vdupq_n_f32(params.float_activation_max);
2634 for (; i <= size - 16; i += 16) {
2635 auto a10 = vld1q_f32(input1_data + i);
2636 auto a11 = vld1q_f32(input1_data + i + 4);
2637 auto a12 = vld1q_f32(input1_data + i + 8);
2638 auto a13 = vld1q_f32(input1_data + i + 12);
2639 auto a20 = vld1q_f32(input2_data + i);
2640 auto a21 = vld1q_f32(input2_data + i + 4);
2641 auto a22 = vld1q_f32(input2_data + i + 8);
2642 auto a23 = vld1q_f32(input2_data + i + 12);
2643 auto x0 = vaddq_f32(a10, a20);
2644 auto x1 = vaddq_f32(a11, a21);
2645 auto x2 = vaddq_f32(a12, a22);
2646 auto x3 = vaddq_f32(a13, a23);
2647 x0 = vmaxq_f32(activation_min, x0);
2648 x1 = vmaxq_f32(activation_min, x1);
2649 x2 = vmaxq_f32(activation_min, x2);
2650 x3 = vmaxq_f32(activation_min, x3);
2651 x0 = vminq_f32(activation_max, x0);
2652 x1 = vminq_f32(activation_max, x1);
2653 x2 = vminq_f32(activation_max, x2);
2654 x3 = vminq_f32(activation_max, x3);
2655 vst1q_f32(output_data + i, x0);
2656 vst1q_f32(output_data + i + 4, x1);
2657 vst1q_f32(output_data + i + 8, x2);
2658 vst1q_f32(output_data + i + 12, x3);
2659 }
2660 for (; i <= size - 4; i += 4) {
2661 auto a1 = vld1q_f32(input1_data + i);
2662 auto a2 = vld1q_f32(input2_data + i);
2663 auto x = vaddq_f32(a1, a2);
2664 x = vmaxq_f32(activation_min, x);
2665 x = vminq_f32(activation_max, x);
2666 vst1q_f32(output_data + i, x);
2667 }
2668 #endif // NEON
2669
2670 for (; i < size; i++) {
2671 auto x = input1_data[i] + input2_data[i];
2672 output_data[i] = ActivationFunctionWithMinMax(
2673 x, params.float_activation_min, params.float_activation_max);
2674 }
2675 }
2676
2677 // Element-wise add that can often be used for inner loop of broadcast add as
2678 // well as the non-broadcast add.
2679 inline void AddElementwise(int size, const ArithmeticParams& params,
2680 const uint8* input1_data, const uint8* input2_data,
2681 uint8* output_data) {
2682 gemmlowp::ScopedProfilingLabel label("AddElementwise/8bit");
2683 int i = 0;
2684 TFLITE_DCHECK_GT(params.input1_offset, -256);
2685 TFLITE_DCHECK_GT(params.input2_offset, -256);
2686 TFLITE_DCHECK_LT(params.input1_offset, 256);
2687 TFLITE_DCHECK_LT(params.input2_offset, 256);
2688 #ifdef USE_NEON
2689 const uint8x8_t output_activation_min_vector =
2690 vdup_n_u8(params.quantized_activation_min);
2691 const uint8x8_t output_activation_max_vector =
2692 vdup_n_u8(params.quantized_activation_max);
2693 for (; i <= size - 8; i += 8) {
2694 const uint8x8_t input1_val_original = vld1_u8(input1_data + i);
2695 const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
2696 const int16x8_t input1_val_s16 =
2697 vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
2698 const int16x8_t input2_val_s16 =
2699 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2700 const int16x8_t input1_val =
2701 vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
2702 const int16x8_t input2_val =
2703 vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
2704 const int16x4_t input1_val_high = vget_high_s16(input1_val);
2705 const int16x4_t input1_val_low = vget_low_s16(input1_val);
2706 const int16x4_t input2_val_high = vget_high_s16(input2_val);
2707 const int16x4_t input2_val_low = vget_low_s16(input2_val);
2708 int32x4_t x11 = vmovl_s16(input1_val_low);
2709 int32x4_t x12 = vmovl_s16(input1_val_high);
2710 int32x4_t x21 = vmovl_s16(input2_val_low);
2711 int32x4_t x22 = vmovl_s16(input2_val_high);
2712 const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
2713 x11 = vshlq_s32(x11, left_shift_dup);
2714 x12 = vshlq_s32(x12, left_shift_dup);
2715 x21 = vshlq_s32(x21, left_shift_dup);
2716 x22 = vshlq_s32(x22, left_shift_dup);
2717 x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
2718 x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
2719 x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
2720 x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
2721 const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
2722 const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
2723 x11 = vshlq_s32(x11, input1_shift_dup);
2724 x12 = vshlq_s32(x12, input1_shift_dup);
2725 x21 = vshlq_s32(x21, input2_shift_dup);
2726 x22 = vshlq_s32(x22, input2_shift_dup);
2727 int32x4_t s1 = vaddq_s32(x11, x21);
2728 int32x4_t s2 = vaddq_s32(x12, x22);
2729 s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
2730 s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
2731 using gemmlowp::RoundingDivideByPOT;
2732 s1 = RoundingDivideByPOT(s1, -params.output_shift);
2733 s2 = RoundingDivideByPOT(s2, -params.output_shift);
2734 const int16x4_t s1_narrowed = vmovn_s32(s1);
2735 const int16x4_t s2_narrowed = vmovn_s32(s2);
2736 const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
2737 vdupq_n_s16(params.output_offset));
2738 const uint8x8_t clamped =
2739 vmax_u8(output_activation_min_vector,
2740 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
2741 vst1_u8(output_data + i, clamped);
2742 }
2743 #endif // NEON
2744
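  // Scalar tail, mirroring the NEON path above: each input is shifted to its
  // real zero point, left-shifted for headroom, rescaled by its per-input
  // multiplier, summed, rescaled to the output scale, offset, and clamped.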
2745 for (; i < size; ++i) {
2746 const int32 input1_val = params.input1_offset + input1_data[i];
2747 const int32 input2_val = params.input2_offset + input2_data[i];
2748 const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
2749 const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
2750 const int32 scaled_input1_val =
2751 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2752 shifted_input1_val, params.input1_multiplier, params.input1_shift);
2753 const int32 scaled_input2_val =
2754 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2755 shifted_input2_val, params.input2_multiplier, params.input2_shift);
2756 const int32 raw_sum = scaled_input1_val + scaled_input2_val;
2757 const int32 raw_output =
2758 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2759 raw_sum, params.output_multiplier, params.output_shift) +
2760 params.output_offset;
2761 const int32 clamped_output =
2762 std::min(params.quantized_activation_max,
2763 std::max(params.quantized_activation_min, raw_output));
2764 output_data[i] = static_cast<uint8>(clamped_output);
2765 }
2766 }
2767
2768 // Scalar-broadcast add that can be used for inner loop of more general
2769 // broadcast add, so that, for example, scalar-broadcast with batch will still
2770 // be fast.
2771 inline void AddScalarBroadcast(int size, const ArithmeticParams& params,
2772 uint8 input1_data, const uint8* input2_data,
2773 uint8* output_data) {
2774 using gemmlowp::RoundingDivideByPOT;
2775
2776 gemmlowp::ScopedProfilingLabel label("AddScalarBroadcast/8bit");
2777 TFLITE_DCHECK_GT(params.input1_offset, -256);
2778 TFLITE_DCHECK_GT(params.input2_offset, -256);
2779 TFLITE_DCHECK_LT(params.input1_offset, 256);
2780 TFLITE_DCHECK_LT(params.input2_offset, 256);
2781
2782 int i = 0;
2783
2784 #ifdef USE_NEON
2785 const int32x4_t left_shift_dup = vdupq_n_s32(params.left_shift);
2786 const uint8x8_t output_activation_min_vector =
2787 vdup_n_u8(params.quantized_activation_min);
2788 const uint8x8_t output_activation_max_vector =
2789 vdup_n_u8(params.quantized_activation_max);
2790
2791 // Process broadcast scalar.
2792 const uint8x8_t input1_val_original = vdup_n_u8(input1_data);
2793 const int16x8_t input1_val_s16 =
2794 vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
2795 const int16x8_t input1_val =
2796 vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset));
2797 const int16x4_t input1_val_high = vget_high_s16(input1_val);
2798 const int16x4_t input1_val_low = vget_low_s16(input1_val);
2799 int32x4_t x11 = vmovl_s16(input1_val_low);
2800 int32x4_t x12 = vmovl_s16(input1_val_high);
2801 x11 = vshlq_s32(x11, left_shift_dup);
2802 x12 = vshlq_s32(x12, left_shift_dup);
2803 x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier);
2804 x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier);
2805 const int32x4_t input1_shift_dup = vdupq_n_s32(params.input1_shift);
2806 x11 = vshlq_s32(x11, input1_shift_dup);
2807 x12 = vshlq_s32(x12, input1_shift_dup);
2808
2809 for (; i <= size - 8; i += 8) {
2810 const uint8x8_t input2_val_original = vld1_u8(input2_data + i);
2811 const int16x8_t input2_val_s16 =
2812 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
2813 const int16x8_t input2_val =
2814 vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset));
2815 const int16x4_t input2_val_high = vget_high_s16(input2_val);
2816 const int16x4_t input2_val_low = vget_low_s16(input2_val);
2817 int32x4_t x21 = vmovl_s16(input2_val_low);
2818 int32x4_t x22 = vmovl_s16(input2_val_high);
2819 x21 = vshlq_s32(x21, left_shift_dup);
2820 x22 = vshlq_s32(x22, left_shift_dup);
2821 x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier);
2822 x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier);
2823 const int32x4_t input2_shift_dup = vdupq_n_s32(params.input2_shift);
2824 x21 = vshlq_s32(x21, input2_shift_dup);
2825 x22 = vshlq_s32(x22, input2_shift_dup);
2826 int32x4_t s1 = vaddq_s32(x11, x21);
2827 int32x4_t s2 = vaddq_s32(x12, x22);
2828 s1 = vqrdmulhq_n_s32(s1, params.output_multiplier);
2829 s2 = vqrdmulhq_n_s32(s2, params.output_multiplier);
2830 s1 = RoundingDivideByPOT(s1, -params.output_shift);
2831 s2 = RoundingDivideByPOT(s2, -params.output_shift);
2832 const int16x4_t s1_narrowed = vmovn_s32(s1);
2833 const int16x4_t s2_narrowed = vmovn_s32(s2);
2834 const int16x8_t s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
2835 vdupq_n_s16(params.output_offset));
2836 const uint8x8_t clamped =
2837 vmax_u8(output_activation_min_vector,
2838 vmin_u8(output_activation_max_vector, vqmovun_s16(s)));
2839 vst1_u8(output_data + i, clamped);
2840 }
2841 #endif // NEON
2842
2843 if (i < size) {
2844 // Process broadcast scalar.
2845 const int32 input1_val = params.input1_offset + input1_data;
2846 const int32 shifted_input1_val = input1_val * (1 << params.left_shift);
2847 const int32 scaled_input1_val =
2848 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2849 shifted_input1_val, params.input1_multiplier, params.input1_shift);
2850
2851 for (; i < size; ++i) {
2852 const int32 input2_val = params.input2_offset + input2_data[i];
2853 const int32 shifted_input2_val = input2_val * (1 << params.left_shift);
2854 const int32 scaled_input2_val =
2855 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2856 shifted_input2_val, params.input2_multiplier,
2857 params.input2_shift);
2858 const int32 raw_sum = scaled_input1_val + scaled_input2_val;
2859 const int32 raw_output =
2860 MultiplyByQuantizedMultiplierSmallerThanOneExp(
2861 raw_sum, params.output_multiplier, params.output_shift) +
2862 params.output_offset;
2863 const int32 clamped_output =
2864 std::min(params.quantized_activation_max,
2865 std::max(params.quantized_activation_min, raw_output));
2866 output_data[i] = static_cast<uint8>(clamped_output);
2867 }
2868 }
2869 }
2870
2871 inline void Add(const ArithmeticParams& params,
2872 const RuntimeShape& input1_shape, const uint8* input1_data,
2873 const RuntimeShape& input2_shape, const uint8* input2_data,
2874 const RuntimeShape& output_shape, uint8* output_data) {
2875 TFLITE_DCHECK_LE(params.quantized_activation_min,
2876 params.quantized_activation_max);
2877 gemmlowp::ScopedProfilingLabel label("Add/8bit");
2878 const int flat_size =
2879 MatchingFlatSize(input1_shape, input2_shape, output_shape);
2880
2881 TFLITE_DCHECK_GT(params.input1_offset, -256);
2882 TFLITE_DCHECK_GT(params.input2_offset, -256);
2883 TFLITE_DCHECK_LT(params.input1_offset, 256);
2884 TFLITE_DCHECK_LT(params.input2_offset, 256);
2885 AddElementwise(flat_size, params, input1_data, input2_data, output_data);
2886 }
2887
2888 inline void Add(const ArithmeticParams& params,
2889 const RuntimeShape& input1_shape, const int16* input1_data,
2890 const RuntimeShape& input2_shape, const int16* input2_data,
2891 const RuntimeShape& output_shape, int16* output_data) {
2892 gemmlowp::ScopedProfilingLabel label("Add/Int16");
2893 TFLITE_DCHECK_LE(params.quantized_activation_min,
2894 params.quantized_activation_max);
2895
2896 const int input1_shift = params.input1_shift;
2897 const int flat_size =
2898 MatchingFlatSize(output_shape, input1_shape, input2_shape);
2899 const int16 output_activation_min = params.quantized_activation_min;
2900 const int16 output_activation_max = params.quantized_activation_max;
2901
2902 TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
2903 TFLITE_DCHECK_LE(input1_shift, 0);
2904 TFLITE_DCHECK_LE(params.input2_shift, 0);
2905 const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
2906 const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
2907 const int input_right_shift =
2908 input1_shift == 0 ? -params.input2_shift : -input1_shift;
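  // At most one input carries a non-zero (negative) shift; that input is
  // rounding-right-shifted into the other input's Q0.15 scale before the
  // saturating fixed-point add below.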
2909
2910 for (int i = 0; i < flat_size; i++) {
2911 // F0 uses 0 integer bits, range [-1, 1].
2912 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
2913
2914 F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
2915 F0 scaled_input = F0::FromRaw(
2916 gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
2917 F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled);
2918 const int16 raw_output = result.raw();
2919 const int16 clamped_output = std::min(
2920 output_activation_max, std::max(output_activation_min, raw_output));
2921 output_data[i] = clamped_output;
2922 }
2923 }
2924
2925 inline void Add(const ArithmeticParams& params,
2926 const RuntimeShape& input1_shape, const int32* input1_data,
2927 const RuntimeShape& input2_shape, const int32* input2_data,
2928 const RuntimeShape& output_shape, int32* output_data) {
2929 gemmlowp::ScopedProfilingLabel label("Add/int32");
2930
2931 auto input1_map = MapAsVector(input1_data, input1_shape);
2932 auto input2_map = MapAsVector(input2_data, input2_shape);
2933 auto output_map = MapAsVector(output_data, output_shape);
2934 if (input1_shape == input2_shape) {
2935 output_map.array() = input1_map.array() + input2_map.array();
2936 } else if (input2_shape.FlatSize() == 1) {
2937 auto scalar = input2_data[0];
2938 output_map.array() = input1_map.array() + scalar;
2939 } else if (input1_shape.FlatSize() == 1) {
2940 auto scalar = input1_data[0];
2941 output_map.array() = scalar + input2_map.array();
2942 } else {
2943     // Should not get here: unhandled broadcast case.
2944 TFLITE_DCHECK(false);
2945 }
2946 output_map = output_map.cwiseMax(params.quantized_activation_min);
2947 output_map = output_map.cwiseMin(params.quantized_activation_max);
2948 }
2949
2950 inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
2951 const RuntimeShape& unswitched_input1_shape,
2952 const uint8* unswitched_input1_data,
2953 const RuntimeShape& unswitched_input2_shape,
2954 const uint8* unswitched_input2_data,
2955 const RuntimeShape& output_shape,
2956 uint8* output_data) {
2957 gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit");
2958
2959 ArithmeticParams switched_params = unswitched_params;
2960 switched_params.input1_offset = unswitched_params.input2_offset;
2961 switched_params.input1_multiplier = unswitched_params.input2_multiplier;
2962 switched_params.input1_shift = unswitched_params.input2_shift;
2963 switched_params.input2_offset = unswitched_params.input1_offset;
2964 switched_params.input2_multiplier = unswitched_params.input1_multiplier;
2965 switched_params.input2_shift = unswitched_params.input1_shift;
2966
2967 const bool use_unswitched =
2968 unswitched_params.broadcast_category ==
2969 tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
2970
2971 const ArithmeticParams& params =
2972 use_unswitched ? unswitched_params : switched_params;
2973 const uint8* input1_data =
2974 use_unswitched ? unswitched_input1_data : unswitched_input2_data;
2975 const uint8* input2_data =
2976 use_unswitched ? unswitched_input2_data : unswitched_input1_data;
2977
2978 // Fivefold nested loops. The second input resets its position for each
2979 // iteration of the second loop. The first input resets its position at the
2980 // beginning of the fourth loop. The innermost loop is an elementwise add of
2981 // sections of the arrays.
2982 uint8* output_data_ptr = output_data;
2983 const uint8* input1_data_ptr = input1_data;
2984 const uint8* input2_data_reset = input2_data;
2985   // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so are shared
2986   // between the input shapes. y3 is always broadcast for input 1, so its
2987   // dimension there is 1, whereas y1 may optionally be broadcast for input 2.
2988 // Put another way,
2989 // input1.shape.FlatSize = y0 * y1 * y2 * y4,
2990 // input2.shape.FlatSize = y0 * y2 * y3 * y4.
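  // For illustration, broadcast_shape = {2, 3, 4, 5, 8} means
  // input1.FlatSize() = 2*3*4*8 = 192, input2.FlatSize() = 2*4*5*8 = 320, and
  // the output has 2*3*4*5*8 = 960 elements.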
2991 int y0 = params.broadcast_shape[0];
2992 int y1 = params.broadcast_shape[1];
2993 int y2 = params.broadcast_shape[2];
2994 int y3 = params.broadcast_shape[3];
2995 int y4 = params.broadcast_shape[4];
2996 if (y4 > 1) {
2997 // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner
2998 // dimension.
2999 for (int i0 = 0; i0 < y0; ++i0) {
3000 const uint8* input2_data_ptr = nullptr;
3001 for (int i1 = 0; i1 < y1; ++i1) {
3002 input2_data_ptr = input2_data_reset;
3003 for (int i2 = 0; i2 < y2; ++i2) {
3004 for (int i3 = 0; i3 < y3; ++i3) {
3005 AddElementwise(y4, params, input1_data_ptr, input2_data_ptr,
3006 output_data_ptr);
3007 input2_data_ptr += y4;
3008 output_data_ptr += y4;
3009 }
3010           // We have broadcast the y4 values of input1 y3 times; now move on.
3011 input1_data_ptr += y4;
3012 }
3013 }
3014       // We have broadcast y2*y3*y4 values of input2 y1 times; now move on.
3015 input2_data_reset = input2_data_ptr;
3016 }
3017 } else {
3018 // Special case of y4 == 1, in which the innermost loop is a single element
3019 // and can be combined with the next (y3) as an inner broadcast.
3020 //
3021 // Note that this handles the case of pure scalar broadcast when
3022 // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar
3023 // broadcast with batch (as y2 > 1).
3024 //
3025     // NOTE: The process is the same as the general case above, except that it
3026     // is simplified for y4 == 1 and the loop over y3 is contained within the
3027     // AddScalarBroadcast function.
3028 for (int i0 = 0; i0 < y0; ++i0) {
3029 const uint8* input2_data_ptr = nullptr;
3030 for (int i1 = 0; i1 < y1; ++i1) {
3031 input2_data_ptr = input2_data_reset;
3032 for (int i2 = 0; i2 < y2; ++i2) {
3033 AddScalarBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
3034 output_data_ptr);
3035 input2_data_ptr += y3;
3036 output_data_ptr += y3;
3037 input1_data_ptr += 1;
3038 }
3039 }
3040 input2_data_reset = input2_data_ptr;
3041 }
3042 }
3043 }
3044
3045 inline void Mul(const ArithmeticParams& params,
3046 const RuntimeShape& input1_shape, const float* input1_data,
3047 const RuntimeShape& input2_shape, const float* input2_data,
3048 const RuntimeShape& output_shape, float* output_data) {
3049 gemmlowp::ScopedProfilingLabel label("Mul");
3050 const float output_activation_min = params.float_activation_min;
3051 const float output_activation_max = params.float_activation_max;
3052
3053 int i = 0;
3054 const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
3055 #ifdef USE_NEON
3056 const auto activation_min = vdupq_n_f32(output_activation_min);
3057 const auto activation_max = vdupq_n_f32(output_activation_max);
3058 for (; i <= size - 16; i += 16) {
3059 auto a10 = vld1q_f32(input1_data + i);
3060 auto a11 = vld1q_f32(input1_data + i + 4);
3061 auto a12 = vld1q_f32(input1_data + i + 8);
3062 auto a13 = vld1q_f32(input1_data + i + 12);
3063 auto a20 = vld1q_f32(input2_data + i);
3064 auto a21 = vld1q_f32(input2_data + i + 4);
3065 auto a22 = vld1q_f32(input2_data + i + 8);
3066 auto a23 = vld1q_f32(input2_data + i + 12);
3067 auto x0 = vmulq_f32(a10, a20);
3068 auto x1 = vmulq_f32(a11, a21);
3069 auto x2 = vmulq_f32(a12, a22);
3070 auto x3 = vmulq_f32(a13, a23);
3071
3072 x0 = vmaxq_f32(activation_min, x0);
3073 x1 = vmaxq_f32(activation_min, x1);
3074 x2 = vmaxq_f32(activation_min, x2);
3075 x3 = vmaxq_f32(activation_min, x3);
3076 x0 = vminq_f32(activation_max, x0);
3077 x1 = vminq_f32(activation_max, x1);
3078 x2 = vminq_f32(activation_max, x2);
3079 x3 = vminq_f32(activation_max, x3);
3080
3081 vst1q_f32(output_data + i, x0);
3082 vst1q_f32(output_data + i + 4, x1);
3083 vst1q_f32(output_data + i + 8, x2);
3084 vst1q_f32(output_data + i + 12, x3);
3085 }
3086 for (; i <= size - 4; i += 4) {
3087 auto a1 = vld1q_f32(input1_data + i);
3088 auto a2 = vld1q_f32(input2_data + i);
3089 auto x = vmulq_f32(a1, a2);
3090
3091 x = vmaxq_f32(activation_min, x);
3092 x = vminq_f32(activation_max, x);
3093
3094 vst1q_f32(output_data + i, x);
3095 }
3096 #endif // NEON
3097
3098 for (; i < size; i++) {
3099 auto x = input1_data[i] * input2_data[i];
3100 output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
3101 output_activation_max);
3102 }
3103 }
3104
3105 inline void Mul(const ArithmeticParams& params,
3106 const RuntimeShape& input1_shape, const int32* input1_data,
3107 const RuntimeShape& input2_shape, const int32* input2_data,
3108 const RuntimeShape& output_shape, int32* output_data) {
3109 gemmlowp::ScopedProfilingLabel label("Mul/int32/activation");
3110
3111 const int flat_size =
3112 MatchingFlatSize(input1_shape, input2_shape, output_shape);
3113 const int32 output_activation_min = params.quantized_activation_min;
3114 const int32 output_activation_max = params.quantized_activation_max;
3115 for (int i = 0; i < flat_size; ++i) {
3116 output_data[i] = ActivationFunctionWithMinMax(
3117 input1_data[i] * input2_data[i], output_activation_min,
3118 output_activation_max);
3119 }
3120 }
3121
3122 inline void MulNoActivation(const ArithmeticParams& params,
3123 const RuntimeShape& input1_shape,
3124 const int32* input1_data,
3125 const RuntimeShape& input2_shape,
3126 const int32* input2_data,
3127 const RuntimeShape& output_shape,
3128 int32* output_data) {
3129 gemmlowp::ScopedProfilingLabel label("Mul/int32");
3130
3131 auto input1_map = MapAsVector(input1_data, input1_shape);
3132 auto input2_map = MapAsVector(input2_data, input2_shape);
3133 auto output_map = MapAsVector(output_data, output_shape);
3134 if (input1_shape == input2_shape) {
3135 output_map.array() = input1_map.array() * input2_map.array();
3136 } else if (input2_shape.FlatSize() == 1) {
3137 auto scalar = input2_data[0];
3138 output_map.array() = input1_map.array() * scalar;
3139 } else if (input1_shape.FlatSize() == 1) {
3140 auto scalar = input1_data[0];
3141 output_map.array() = scalar * input2_map.array();
3142 } else {
3143     // Should not get here: unhandled broadcast case.
3144 TFLITE_DCHECK(false);
3145 }
3146 }
3147
3148 inline void Mul(const ArithmeticParams& params,
3149 const RuntimeShape& input1_shape, const int16* input1_data,
3150 const RuntimeShape& input2_shape, const int16* input2_data,
3151 const RuntimeShape& output_shape, int16* output_data) {
3152 gemmlowp::ScopedProfilingLabel label("Mul/Int16/NoActivation");
3153 // This is a copy of the reference implementation. We do not currently have a
3154 // properly optimized version.
3155
3156 const int flat_size =
3157 MatchingFlatSize(input1_shape, input2_shape, output_shape);
3158
3159 for (int i = 0; i < flat_size; i++) {
3160 // F0 uses 0 integer bits, range [-1, 1].
3161 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
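    // The product of two F0 (Q0.15) values is computed with a saturating
    // rounding doubling-high multiply and is itself F0, so the raw result can
    // be stored directly.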
3162
3163 F0 unclamped_result =
3164 F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
3165 output_data[i] = unclamped_result.raw();
3166 }
3167 }
3168
3169 inline void Mul(const ArithmeticParams& params,
3170 const RuntimeShape& input1_shape, const int16* input1_data,
3171 const RuntimeShape& input2_shape, const int16* input2_data,
3172 const RuntimeShape& output_shape, uint8* output_data) {
3173 gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8");
3174 // This is a copy of the reference implementation. We do not currently have a
3175 // properly optimized version.
3176 const int32 output_activation_min = params.quantized_activation_min;
3177 const int32 output_activation_max = params.quantized_activation_max;
3178 const int32 output_offset = params.output_offset;
3179 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
3180
3181 const int flat_size =
3182 MatchingFlatSize(input1_shape, input2_shape, output_shape);
3183
3184 for (int i = 0; i < flat_size; i++) {
3185 // F0 uses 0 integer bits, range [-1, 1].
3186 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3187
3188 F0 unclamped_result =
3189 F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
3190 int16 rescaled_result =
3191 gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
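    // Dividing the Q0.15 raw product by 2^8 maps it into roughly [-128, 128),
    // which is then clamped and re-centered on output_offset to fit uint8.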
3192 int16 clamped_result =
3193 std::min<int16>(output_activation_max - output_offset, rescaled_result);
3194 clamped_result =
3195 std::max<int16>(output_activation_min - output_offset, clamped_result);
3196 output_data[i] = output_offset + clamped_result;
3197 }
3198 }
3199
3200 // Element-wise mul that can often be used for inner loop of broadcast Mul as
3201 // well as the non-broadcast Mul.
3202 inline void MulElementwise(int size, const ArithmeticParams& params,
3203 const uint8* input1_data, const uint8* input2_data,
3204 uint8* output_data) {
3205 int i = 0;
3206 TFLITE_DCHECK_GT(params.input1_offset, -256);
3207 TFLITE_DCHECK_LT(params.input1_offset, 256);
3208 TFLITE_DCHECK_GT(params.input2_offset, -256);
3209 TFLITE_DCHECK_LT(params.input2_offset, 256);
3210 TFLITE_DCHECK_GT(params.output_offset, -256);
3211 TFLITE_DCHECK_LT(params.output_offset, 256);
3212 #ifdef USE_NEON
3213 const auto input1_offset_vector = vdupq_n_s16(params.input1_offset);
3214 const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
3215 const auto output_offset_vector = vdupq_n_s16(params.output_offset);
3216 const auto output_activation_min_vector =
3217 vdup_n_u8(params.quantized_activation_min);
3218 const auto output_activation_max_vector =
3219 vdup_n_u8(params.quantized_activation_max);
3220 for (; i <= size - 8; i += 8) {
3221 // We load / store 8 at a time, multiplying as two sets of 4 int32s.
3222 const auto input1_val_original = vld1_u8(input1_data + i);
3223 const auto input2_val_original = vld1_u8(input2_data + i);
3224 const auto input1_val_s16 =
3225 vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
3226 const auto input2_val_s16 =
3227 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
3228 const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector);
3229 const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
3230
3231 const auto input1_val_low = vget_low_s16(input1_val);
3232 const auto input1_val_high = vget_high_s16(input1_val);
3233 const auto input2_val_low = vget_low_s16(input2_val);
3234 const auto input2_val_high = vget_high_s16(input2_val);
3235
3236 auto p1 = vmull_s16(input2_val_low, input1_val_low);
3237 auto p2 = vmull_s16(input2_val_high, input1_val_high);
3238
3239 p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
3240 p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
3241 using gemmlowp::RoundingDivideByPOT;
3242 p1 = RoundingDivideByPOT(p1, -params.output_shift);
3243 p2 = RoundingDivideByPOT(p2, -params.output_shift);
3244
3245 const auto p1_narrowed = vmovn_s32(p1);
3246 const auto p2_narrowed = vmovn_s32(p2);
3247 const auto p =
3248 vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
3249 const auto clamped =
3250 vmax_u8(output_activation_min_vector,
3251 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
3252 vst1_u8(output_data + i, clamped);
3253 }
3254 #endif // NEON
3255
3256 for (; i < size; ++i) {
3257 const int32 input1_val = params.input1_offset + input1_data[i];
3258 const int32 input2_val = params.input2_offset + input2_data[i];
3259 const int32 unclamped_result =
3260 params.output_offset +
3261 MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
3262 params.output_multiplier,
3263 params.output_shift);
3264 const int32 clamped_output =
3265 std::min(params.quantized_activation_max,
3266 std::max(params.quantized_activation_min, unclamped_result));
3267 output_data[i] = static_cast<uint8>(clamped_output);
3268 }
3269 }
3270
3271 // Scalar-broadcast mul that can often be used for the inner loop of a more
3271 // general broadcast Mul.
3272 inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
3273 const uint8 broadcast_value,
3274 const uint8* input2_data, uint8* output_data) {
3275 const int16 input1_val = params.input1_offset + broadcast_value;
3276
3277 int i = 0;
3278 TFLITE_DCHECK_GT(params.input1_offset, -256);
3279 TFLITE_DCHECK_LT(params.input1_offset, 256);
3280 TFLITE_DCHECK_GT(params.input2_offset, -256);
3281 TFLITE_DCHECK_LT(params.input2_offset, 256);
3282 TFLITE_DCHECK_GT(params.output_offset, -256);
3283 TFLITE_DCHECK_LT(params.output_offset, 256);
3284 #ifdef USE_NEON
3285 const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
3286 const auto output_offset_vector = vdupq_n_s16(params.output_offset);
3287 const auto output_activation_min_vector =
3288 vdup_n_u8(params.quantized_activation_min);
3289 const auto output_activation_max_vector =
3290 vdup_n_u8(params.quantized_activation_max);
3291 for (; i <= size - 8; i += 8) {
3292 // We load / store 8 at a time, multiplying as two sets of 4 int32s.
3293 const auto input2_val_original = vld1_u8(input2_data + i);
3294 const auto input2_val_s16 =
3295 vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
3296 const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
3297
3298 const auto input2_val_low = vget_low_s16(input2_val);
3299 const auto input2_val_high = vget_high_s16(input2_val);
3300
3301 auto p1 = vmull_n_s16(input2_val_low, input1_val);
3302 auto p2 = vmull_n_s16(input2_val_high, input1_val);
3303
3304 p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
3305 p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
3306 using gemmlowp::RoundingDivideByPOT;
3307 p1 = RoundingDivideByPOT(p1, -params.output_shift);
3308 p2 = RoundingDivideByPOT(p2, -params.output_shift);
3309
3310 const auto p1_narrowed = vmovn_s32(p1);
3311 const auto p2_narrowed = vmovn_s32(p2);
3312 const auto p =
3313 vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
3314 const auto clamped =
3315 vmax_u8(output_activation_min_vector,
3316 vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
3317 vst1_u8(output_data + i, clamped);
3318 }
3319 #endif // NEON
3320
3321 for (; i < size; ++i) {
3322 const int32 input2_val = params.input2_offset + input2_data[i];
3323 const int32 unclamped_result =
3324 params.output_offset +
3325 MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
3326 params.output_multiplier,
3327 params.output_shift);
3328 const int32 clamped_output =
3329 std::min(params.quantized_activation_max,
3330 std::max(params.quantized_activation_min, unclamped_result));
3331 output_data[i] = static_cast<uint8>(clamped_output);
3332 }
3333 }
3334
3335 inline void Mul(const ArithmeticParams& params,
3336 const RuntimeShape& input1_shape, const uint8* input1_data,
3337 const RuntimeShape& input2_shape, const uint8* input2_data,
3338 const RuntimeShape& output_shape, uint8* output_data) {
3339 TFLITE_DCHECK_LE(params.quantized_activation_min,
3340 params.quantized_activation_max);
3341 gemmlowp::ScopedProfilingLabel label("Mul/8bit");
3342 const int flat_size =
3343 MatchingFlatSize(input1_shape, input2_shape, output_shape);
3344
3345 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
3346 }
3347
3348 inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
3349 const RuntimeShape& unswitched_input1_shape,
3350 const uint8* unswitched_input1_data,
3351 const RuntimeShape& unswitched_input2_shape,
3352 const uint8* unswitched_input2_data,
3353 const RuntimeShape& output_shape,
3354 uint8* output_data) {
3355 gemmlowp::ScopedProfilingLabel label("BroadcastMulFivefold/8bit");
3356
3357 ArithmeticParams switched_params = unswitched_params;
3358 switched_params.input1_offset = unswitched_params.input2_offset;
3359 switched_params.input2_offset = unswitched_params.input1_offset;
3360
3361 const bool use_unswitched =
3362 unswitched_params.broadcast_category ==
3363 tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
3364
3365 const ArithmeticParams& params =
3366 use_unswitched ? unswitched_params : switched_params;
3367 const uint8* input1_data =
3368 use_unswitched ? unswitched_input1_data : unswitched_input2_data;
3369 const uint8* input2_data =
3370 use_unswitched ? unswitched_input2_data : unswitched_input1_data;
3371
3372 // Fivefold nested loops. The second input resets its position for each
3373 // iteration of the second loop. The first input resets its position at the
3374 // beginning of the fourth loop. The innermost loop is an elementwise Mul of
3375 // sections of the arrays.
3376 uint8* output_data_ptr = output_data;
3377 const uint8* input1_data_ptr = input1_data;
3378 const uint8* input2_data_reset = input2_data;
3379 int y0 = params.broadcast_shape[0];
3380 int y1 = params.broadcast_shape[1];
3381 int y2 = params.broadcast_shape[2];
3382 int y3 = params.broadcast_shape[3];
3383 int y4 = params.broadcast_shape[4];
3384 if (y4 > 1) {
3385 for (int i0 = 0; i0 < y0; ++i0) {
3386 const uint8* input2_data_ptr = nullptr;
3387 for (int i1 = 0; i1 < y1; ++i1) {
3388 input2_data_ptr = input2_data_reset;
3389 for (int i2 = 0; i2 < y2; ++i2) {
3390 for (int i3 = 0; i3 < y3; ++i3) {
3391 MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
3392 output_data_ptr);
3393 input2_data_ptr += y4;
3394 output_data_ptr += y4;
3395 }
3396 input1_data_ptr += y4;
3397 }
3398 }
3399 input2_data_reset = input2_data_ptr;
3400 }
3401 } else {
3402 for (int i0 = 0; i0 < y0; ++i0) {
3403 const uint8* input2_data_ptr = nullptr;
3404 for (int i1 = 0; i1 < y1; ++i1) {
3405 input2_data_ptr = input2_data_reset;
3406 for (int i2 = 0; i2 < y2; ++i2) {
3407 MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr,
3408 output_data_ptr);
3409 input2_data_ptr += y3;
3410 output_data_ptr += y3;
3411 ++input1_data_ptr;
3412 }
3413 }
3414 input2_data_reset = input2_data_ptr;
3415 }
3416 }
3417 }
3418
3419 // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
3420 // dimensionality if the runtime code does a single loop over one dimension
3421 // that handles broadcasting as the base case. The code generator would then
3422 // generate max(D1, D2) nested for loops.
3423 // TODO(benoitjacob): BroadcastDiv is intentionally duplicated from
3424 // reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
3425 // is no longer referenced in this file, move NdArrayDesc<T> from types.h to
3426 // reference_ops.h.
3427 template <typename T>
3428 void BroadcastDiv4DSlow(const ArithmeticParams& params,
3429 const RuntimeShape& unextended_input1_shape,
3430 const T* input1_data,
3431 const RuntimeShape& unextended_input2_shape,
3432 const T* input2_data,
3433 const RuntimeShape& unextended_output_shape,
3434 T* output_data) {
3435 gemmlowp::ScopedProfilingLabel label("BroadcastDiv4DSlow");
3436 T output_activation_min;
3437 T output_activation_max;
3438 GetActivationParams(params, &output_activation_min, &output_activation_max);
3439
3440 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
3441 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
3442 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
3443 const RuntimeShape output_shape =
3444 RuntimeShape::ExtendedShape(4, unextended_output_shape);
3445
3446 NdArrayDesc<4> desc1;
3447 NdArrayDesc<4> desc2;
3448 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
3449 unextended_input2_shape, &desc1, &desc2);
3450
3451   // In TensorFlow, the dimensions are canonically named (batch_number, row,
3452   // col, channel), with extents (batches, height, width, depth); the trailing
3453   // dimension changes most rapidly (channels has the smallest stride, typically
3454   // 1 element).
3455   //
3456   // In generated C code, we store arrays with the dimensions reversed: the
3457   // first dimension has the smallest stride.
3458   //
3459   // We name our variables by the TensorFlow convention, but generate C code
3460   // nesting loops such that the innermost loop has the smallest stride, for the
3461   // best cache behavior.
3462 for (int b = 0; b < output_shape.Dims(0); ++b) {
3463 for (int y = 0; y < output_shape.Dims(1); ++y) {
3464 for (int x = 0; x < output_shape.Dims(2); ++x) {
3465 for (int c = 0; c < output_shape.Dims(3); ++c) {
3466 output_data[Offset(output_shape, b, y, x, c)] =
3467 ActivationFunctionWithMinMax(
3468 input1_data[SubscriptToIndex(desc1, b, y, x, c)] /
3469 input2_data[SubscriptToIndex(desc2, b, y, x, c)],
3470 output_activation_min, output_activation_max);
3471 }
3472 }
3473 }
3474 }
3475 }
3476
3477 // TODO(aselle): This is not actually optimized yet.
3478 inline void SubNonBroadcast(const ArithmeticParams& params,
3479 const RuntimeShape& input1_shape,
3480 const float* input1_data,
3481 const RuntimeShape& input2_shape,
3482 const float* input2_data,
3483 const RuntimeShape& output_shape,
3484 float* output_data) {
3485 gemmlowp::ScopedProfilingLabel label("SubNonBroadcast");
3486 const int flat_size =
3487 MatchingFlatSize(input1_shape, input2_shape, output_shape);
3488 for (int i = 0; i < flat_size; ++i) {
3489 output_data[i] = ActivationFunctionWithMinMax(
3490 input1_data[i] - input2_data[i], params.float_activation_min,
3491 params.float_activation_max);
3492 }
3493 }
3494
3495 inline void SubWithActivation(const ArithmeticParams& params,
3496 const RuntimeShape& input1_shape,
3497 const int32* input1_data,
3498 const RuntimeShape& input2_shape,
3499 const int32* input2_data,
3500 const RuntimeShape& output_shape,
3501 int32* output_data) {
3502 gemmlowp::ScopedProfilingLabel label("SubWithActivation/int32");
3503 const int flat_size =
3504       MatchingFlatSize(input1_shape, input2_shape, output_shape);
3505 for (int i = 0; i < flat_size; ++i) {
3506 output_data[i] = ActivationFunctionWithMinMax(
3507 input1_data[i] - input2_data[i], params.quantized_activation_min,
3508 params.quantized_activation_max);
3509 }
3510 }
3511
3512 inline void SubWithActivation(const ArithmeticParams& params,
3513 const RuntimeShape& input1_shape,
3514 const float* input1_data,
3515 const RuntimeShape& input2_shape,
3516 const float* input2_data,
3517 const RuntimeShape& output_shape,
3518 float* output_data) {
3519 gemmlowp::ScopedProfilingLabel label("SubWithActivation/float");
3520 const int flat_size =
3521       MatchingFlatSize(input1_shape, input2_shape, output_shape);
3522 for (int i = 0; i < flat_size; ++i) {
3523 output_data[i] = ActivationFunctionWithMinMax(
3524 input1_data[i] - input2_data[i], params.float_activation_min,
3525 params.float_activation_max);
3526 }
3527 }
3528
3529 template <typename T>
3530 void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
3531 const T* input1_data, const RuntimeShape& input2_shape,
3532 const T* input2_data, const RuntimeShape& output_shape,
3533 T* output_data) {
3534 gemmlowp::ScopedProfilingLabel label("Sub");
3535
3536 auto input1_map = MapAsVector(input1_data, input1_shape);
3537 auto input2_map = MapAsVector(input2_data, input2_shape);
3538 auto output_map = MapAsVector(output_data, output_shape);
3539 if (input1_shape == input2_shape) {
3540 output_map.array() = input1_map.array() - input2_map.array();
3541 } else if (input1_shape.FlatSize() == 1) {
3542 auto scalar = input1_data[0];
3543 output_map.array() = scalar - input2_map.array();
3544 } else if (input2_shape.FlatSize() == 1) {
3545 auto scalar = input2_data[0];
3546 output_map.array() = input1_map.array() - scalar;
3547 } else {
3548 BroadcastSub4DSlow(params, input1_shape, input1_data, input2_shape,
3549 input2_data, output_shape, output_data);
3550 }
3551 }
3552
3553 inline void LstmCell(
3554 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
3555 const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
3556 const float* prev_activ_data, const RuntimeShape& weights_shape,
3557 const float* weights_data, const RuntimeShape& unextended_bias_shape,
3558 const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
3559 const float* prev_state_data,
3560 const RuntimeShape& unextended_output_state_shape, float* output_state_data,
3561 const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
3562 const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
3563 const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
3564 gemmlowp::ScopedProfilingLabel label("LstmCell");
3565 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
3566 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
3567 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
3568 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
3569 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
3570 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
3571 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
3572 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
3573 const RuntimeShape input_shape =
3574 RuntimeShape::ExtendedShape(4, unextended_input_shape);
3575 const RuntimeShape prev_activ_shape =
3576 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
3577 const RuntimeShape bias_shape =
3578 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
3579 const RuntimeShape prev_state_shape =
3580 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
3581 const RuntimeShape output_state_shape =
3582 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
3583 const RuntimeShape output_activ_shape =
3584 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
3585 const RuntimeShape concat_temp_shape =
3586 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
3587 const RuntimeShape activ_temp_shape =
3588 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
3589 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
3590
3591 const int weights_dim_count = weights_shape.DimensionsCount();
3592 MatchingDim( // batches
3593 input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
3594 output_state_shape, 0, output_activ_shape, 0);
3595 MatchingDim( // height
3596 input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
3597 output_state_shape, 1, output_activ_shape, 1);
3598 MatchingDim( // width
3599 input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
3600 output_state_shape, 2, output_activ_shape, 2);
3601 const int input_depth = input_shape.Dims(3);
3602 const int prev_activ_depth = prev_activ_shape.Dims(3);
3603 const int total_input_depth = prev_activ_depth + input_depth;
3604 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
3605 total_input_depth);
3606 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
3607 const int intern_activ_depth =
3608 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
3609 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
3610 intern_activ_depth * total_input_depth);
3611 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
3612 const int output_depth =
3613 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
3614 3, output_activ_shape, 3);
3615 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
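  // The fully-connected output computed below stacks the four LSTM gates
  // (input, input modulation, forget, output) along the depth dimension, which
  // is why intern_activ_depth must equal 4 * output_depth.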
3616
3617 // Concatenate prev_activ and input data together
3618 std::vector<float const*> concat_input_arrays_data;
3619 std::vector<RuntimeShape const*> concat_input_arrays_shapes;
3620 concat_input_arrays_data.push_back(input_data);
3621 concat_input_arrays_data.push_back(prev_activ_data);
3622 concat_input_arrays_shapes.push_back(&input_shape);
3623 concat_input_arrays_shapes.push_back(&prev_activ_shape);
3624 tflite::ConcatenationParams concat_params;
3625 concat_params.axis = 3;
3626 concat_params.inputs_count = concat_input_arrays_data.size();
3627 Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
3628 &(concat_input_arrays_data[0]), concat_temp_shape,
3629 concat_temp_data);
3630
3631 // Fully connected
3632 tflite::FullyConnectedParams fc_params;
3633 fc_params.float_activation_min = std::numeric_limits<float>::lowest();
3634 fc_params.float_activation_max = std::numeric_limits<float>::max();
3635 FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
3636 weights_data, bias_shape, bias_data, activ_temp_shape,
3637 activ_temp_data);
3638
3639 // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
3640 // operations.
3641 ArrayMap<float> activ_temp_map =
3642 MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
3643 auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
3644 activ_temp_map.cols());
3645 auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
3646 activ_temp_map.cols());
3647 auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
3648 activ_temp_map.cols());
3649 auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
3650 activ_temp_map.cols());
3651 ArrayMap<const float> prev_state_map =
3652 MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
3653 ArrayMap<float> output_state_map =
3654 MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
3655 ArrayMap<float> output_activ_map =
3656 MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
3657
3658 // Combined memory state and final output calculation
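  // This implements the usual LSTM cell update:
  //   new_state = sigmoid(input_gate) * tanh(new_input)
  //               + sigmoid(forget_gate) * prev_state
  //   new_activ = sigmoid(output_gate) * tanh(new_state)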
3659 gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput");
3660 output_state_map =
3661 input_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3662 new_input_sm.tanh() +
3663 forget_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3664 prev_state_map;
3665 output_activ_map =
3666 output_gate_sm.unaryExpr(Eigen::internal::scalar_logistic_op<float>()) *
3667 output_state_map.tanh();
3668 }
3669
// Quantized LSTM cell. Currently just a copy of the reference implementation
// in reference_ops.h. See the big function comment there; it is not
// replicated here.
3673 template <int StateIntegerBits>
inline void LstmCell(
3675 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
3676 const uint8* input_data_uint8,
3677 const RuntimeShape& unextended_prev_activ_shape,
3678 const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
3679 const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
3680 const int32* bias_data_int32,
3681 const RuntimeShape& unextended_prev_state_shape,
3682 const int16* prev_state_data_int16,
3683 const RuntimeShape& unextended_output_state_shape,
3684 int16* output_state_data_int16,
3685 const RuntimeShape& unextended_output_activ_shape,
3686 uint8* output_activ_data_uint8,
3687 const RuntimeShape& unextended_concat_temp_shape,
3688 uint8* concat_temp_data_uint8,
3689 const RuntimeShape& unextended_activ_temp_shape,
3690 int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
3691 gemmlowp::ScopedProfilingLabel label(
3692 "LstmCell/quantized (8bit external, 16bit internal)");
3693 int32 weights_zero_point = params.weights_zero_point;
3694 int32 accum_multiplier = params.accum_multiplier;
3695 int accum_shift = params.accum_shift;
3696 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
3697 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
3698 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
3699 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
3700 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
3701 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
3702 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
3703 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
3704 const RuntimeShape input_shape =
3705 RuntimeShape::ExtendedShape(4, unextended_input_shape);
3706 const RuntimeShape prev_activ_shape =
3707 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
3708 const RuntimeShape bias_shape =
3709 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
3710 const RuntimeShape prev_state_shape =
3711 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
3712 const RuntimeShape output_state_shape =
3713 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
3714 const RuntimeShape output_activ_shape =
3715 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
3716 const RuntimeShape concat_temp_shape =
3717 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
3718 const RuntimeShape activ_temp_shape =
3719 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
3720 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
3721
3722 // Gather dimensions information, and perform consistency checks.
3723 const int weights_dim_count = weights_shape.DimensionsCount();
3724 const int outer_size = MatchingFlatSizeSkipDim(
3725 input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
3726 output_activ_shape);
3727 const int input_depth = input_shape.Dims(3);
3728 const int prev_activ_depth = prev_activ_shape.Dims(3);
3729 const int total_input_depth = prev_activ_depth + input_depth;
3730 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
3731 total_input_depth);
3732 const int intern_activ_depth =
3733 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
3734 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
3735 intern_activ_depth * total_input_depth);
3736 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
3737 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
3738 const int output_depth =
3739 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
3740 3, output_activ_shape, 3);
3741 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
3742 const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
3743 const int fc_output_depth =
3744 MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
3745 const int fc_accum_depth = total_input_depth;
3746 TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
3747
3748 // Depth-concatenate prev_activ and input data together.
3749 uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
3750 prev_activ_data_uint8};
3751 const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
3752 &prev_activ_shape};
3753 tflite::ConcatenationParams concat_params;
3754 concat_params.axis = 3;
3755 concat_params.inputs_count = 2;
3756 Concatenation(concat_params, concat_input_arrays_shapes,
3757 concat_input_arrays_data, concat_temp_shape,
3758 concat_temp_data_uint8);
3759
3760 // Implementation of the fully connected node inside the LSTM cell.
3761 // The operands are 8-bit integers, the accumulators are internally 32bit
3762 // integers, and the output is 16-bit fixed-point with 3 integer bits so
3763 // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
3764 // is explained in the function comment above.
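  // As an illustration of that representation: with 3 integer bits an int16
  // carries 12 fractional bits, so a raw value of 4096 (== 2^12) stands for
  // 1.0 and the full int16 range maps to approximately [-8, +8).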
3765 bool gemm_already_performed = false;
3766 #ifdef GEMMLOWP_NEON
3767 if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) {
3768 GEMVForLstmCell(concat_temp_shape, concat_temp_data_uint8, weights_shape,
3769 weights_data_uint8, weights_zero_point, bias_shape,
3770 bias_data_int32, accum_multiplier, accum_shift,
3771 activ_temp_shape, activ_temp_data_int16);
3772 gemm_already_performed = true;
3773 }
3774 #endif
3775 if (!gemm_already_performed) {
3776 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor>
3777 weights_matrix(weights_data_uint8, fc_output_depth, fc_accum_depth);
3778 gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
3779 concat_temp_data_uint8, fc_accum_depth, fc_batches);
3780 gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
3781 activ_temp_data_int16, fc_output_depth, fc_batches);
3782 typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
3783 ColVectorMap;
3784 ColVectorMap bias_vector(bias_data_int32, fc_output_depth);
3785 gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
3786 bias_addition_stage.bias_vector = bias_vector;
3787 gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
3788 scale_stage.result_offset_after_shift = 0;
3789 scale_stage.result_fixedpoint_multiplier = accum_multiplier;
3790 scale_stage.result_exponent = accum_shift;
3791 gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
3792 auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
3793 saturating_cast_int16_stage);
3794 gemmlowp::GemmWithOutputPipeline<
3795 uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
3796 gemm_context, weights_matrix, input_matrix, &output_matrix,
3797 -weights_zero_point, -128, output_pipeline);
3798 }
3799
3800 // Rest of the LSTM cell: tanh and logistic math functions, and some adds
3801 // and muls, all done in 16-bit fixed-point.
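  // For each batch element, activ_temp holds the four gate pre-activations
  // laid out contiguously along the depth as
  // [input | input modulation | forget | output], each of width output_depth.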
3802 const int16* input_gate_input_ptr = activ_temp_data_int16;
3803 const int16* input_modulation_gate_input_ptr =
3804 activ_temp_data_int16 + output_depth;
3805 const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
3806 const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
3807 const int16* prev_state_ptr = prev_state_data_int16;
3808 int16* output_state_data_ptr = output_state_data_int16;
3809 uint8* output_activ_data_ptr = output_activ_data_uint8;
3810
3811 for (int b = 0; b < outer_size; ++b) {
3812 int c = 0;
3813 #ifdef GEMMLOWP_NEON
3814 for (; c <= output_depth - 8; c += 8) {
3815 // Define the fixed-point data types that we will use here. All use
3816 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3817 // They only differ by the number of integral vs. fractional bits,
3818 // determining the range of values that they can represent.
3819 //
3820 // F0 uses 0 integer bits, range [-1, 1].
3821 // This is the return type of math functions such as tanh, logistic,
3822 // whose range is in [-1, 1].
3823 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
3824 // F3 uses 3 integer bits, range [-8, 8].
3825 // This is the range of the previous fully-connected node's output,
3826 // which is our input here.
3827 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
3828 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3829 // 2^StateIntegerBits]. It's used to represent the internal state, whose
3830 // number of integer bits is currently dictated by the model. See comment
3831 // on the StateIntegerBits template parameter above.
3832 using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
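      // For example, with StateIntegerBits == 4 there are 11 fractional bits,
      // so a raw int16 value of 2048 stands for 1.0 and the representable
      // range is approximately [-16, +16).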
3833 // Implementation of input gate, using fixed-point logistic function.
3834 F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
3835 input_gate_input_ptr += 8;
3836 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3837 // Implementation of input modulation gate, using fixed-point tanh
3838 // function.
3839 F3 input_modulation_gate_input =
3840 F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
3841 input_modulation_gate_input_ptr += 8;
3842 F0 input_modulation_gate_output =
3843 gemmlowp::tanh(input_modulation_gate_input);
3844 // Implementation of forget gate, using fixed-point logistic function.
3845 F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
3846 forget_gate_input_ptr += 8;
3847 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3848 // Implementation of output gate, using fixed-point logistic function.
3849 F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
3850 output_gate_input_ptr += 8;
3851 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3852 // Implementation of internal multiplication nodes, still in fixed-point.
3853 F0 input_times_input_modulation =
3854 input_gate_output * input_modulation_gate_output;
3855 FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
3856 prev_state_ptr += 8;
3857 FS prev_state_times_forget_state = forget_gate_output * prev_state;
3858 // Implementation of internal addition node, saturating.
3859 FS new_state = gemmlowp::SaturatingAdd(
3860 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3861 prev_state_times_forget_state);
3862 // Implementation of last internal Tanh node, still in fixed-point.
3863 // Since a Tanh fixed-point implementation is specialized for a given
      // number of integer bits, and each specialization can have a substantial
3865 // code size, and we already used above a Tanh on an input with 3 integer
3866 // bits, and per the table in the above function comment there is no
3867 // significant accuracy to be lost by clamping to [-8, +8] for a
3868 // 3-integer-bits representation, let us just do that. This helps people
3869 // porting this to targets where code footprint must be minimized.
3870 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3871 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3872 // Store the new internal state back to memory, as 16-bit integers.
3873 // Note: here we store the original value with StateIntegerBits, not
3874 // the rescaled 3-integer-bits value fed to tanh.
3875 vst1q_s16(output_state_data_ptr, new_state.raw());
3876 output_state_data_ptr += 8;
3877 // Down-scale the output activations to 8-bit integers, saturating,
3878 // and store back to memory.
3879 int16x8_t rescaled_output_activ =
3880 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3881 int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
3882 uint8x8_t uint8_output_activ =
3883 vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
3884 vst1_u8(output_activ_data_ptr, uint8_output_activ);
3885 output_activ_data_ptr += 8;
3886 }
3887 #endif
3888 for (; c < output_depth; ++c) {
3889 // Define the fixed-point data types that we will use here. All use
3890 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
3891 // They only differ by the number of integral vs. fractional bits,
3892 // determining the range of values that they can represent.
3893 //
3894 // F0 uses 0 integer bits, range [-1, 1].
3895 // This is the return type of math functions such as tanh, logistic,
3896 // whose range is in [-1, 1].
3897 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
3898 // F3 uses 3 integer bits, range [-8, 8].
3899 // This is the range of the previous fully-connected node's output,
3900 // which is our input here.
3901 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
3902 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
3903 // 2^StateIntegerBits]. It's used to represent the internal state, whose
3904 // number of integer bits is currently dictated by the model. See comment
3905 // on the StateIntegerBits template parameter above.
3906 using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
3907 // Implementation of input gate, using fixed-point logistic function.
3908 F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
3909 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
3910 // Implementation of input modulation gate, using fixed-point tanh
3911 // function.
3912 F3 input_modulation_gate_input =
3913 F3::FromRaw(*input_modulation_gate_input_ptr++);
3914 F0 input_modulation_gate_output =
3915 gemmlowp::tanh(input_modulation_gate_input);
3916 // Implementation of forget gate, using fixed-point logistic function.
3917 F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
3918 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
3919 // Implementation of output gate, using fixed-point logistic function.
3920 F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
3921 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
3922 // Implementation of internal multiplication nodes, still in fixed-point.
3923 F0 input_times_input_modulation =
3924 input_gate_output * input_modulation_gate_output;
3925 FS prev_state = FS::FromRaw(*prev_state_ptr++);
3926 FS prev_state_times_forget_state = forget_gate_output * prev_state;
3927 // Implementation of internal addition node, saturating.
3928 FS new_state = gemmlowp::SaturatingAdd(
3929 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
3930 prev_state_times_forget_state);
3931 // Implementation of last internal Tanh node, still in fixed-point.
3932 // Since a Tanh fixed-point implementation is specialized for a given
      // number of integer bits, and each specialization can have a substantial
3934 // code size, and we already used above a Tanh on an input with 3 integer
3935 // bits, and per the table in the above function comment there is no
3936 // significant accuracy to be lost by clamping to [-8, +8] for a
3937 // 3-integer-bits representation, let us just do that. This helps people
3938 // porting this to targets where code footprint must be minimized.
3939 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
3940 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
3941 // Store the new internal state back to memory, as 16-bit integers.
3942 // Note: here we store the original value with StateIntegerBits, not
3943 // the rescaled 3-integer-bits value fed to tanh.
3944 *output_state_data_ptr++ = new_state.raw();
3945 // Down-scale the output activations to 8-bit integers, saturating,
3946 // and store back to memory.
3947 int16 rescaled_output_activ =
3948 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
3949 int16 clamped_output_activ =
3950 std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
3951 *output_activ_data_ptr++ = 128 + clamped_output_activ;
3952 }
3953 input_gate_input_ptr += 3 * output_depth;
3954 input_modulation_gate_input_ptr += 3 * output_depth;
3955 forget_gate_input_ptr += 3 * output_depth;
3956 output_gate_input_ptr += 3 * output_depth;
3957 }
3958 }
3959
inline int NodeOffset(int b, int h, int w, int height, int width) {
3961 return (b * height + h) * width + w;
3962 }
3963
inline void AveragePool(const PoolParams& params,
3965 const RuntimeShape& input_shape,
3966 const float* input_data,
3967 const RuntimeShape& output_shape, float* output_data) {
3968 gemmlowp::ScopedProfilingLabel label("AveragePool");
3969 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
3970 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
3971 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
3972 const int input_height = input_shape.Dims(1);
3973 const int input_width = input_shape.Dims(2);
3974 const int output_height = output_shape.Dims(1);
3975 const int output_width = output_shape.Dims(2);
3976 const int stride_height = params.stride_height;
3977 const int stride_width = params.stride_width;
3978
3979 // TODO(benoitjacob) make this a proper reference impl without Eigen!
3980 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
3981 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
3982 // TODO(benoitjacob) get rid of the dynamic memory allocation here!
3983 Eigen::VectorXf out_count(out_mat.cols());
3984 out_count.setZero();
3985 // Prefill the output to 0.
3986 out_mat.setZero();
3987 for (int b = 0; b < batches; ++b) {
3988 for (int h = 0; h < input_height; ++h) {
3989 for (int w = 0; w < input_width; ++w) {
3990 // (h_start, h_end) * (w_start, w_end) is the range that the input
3991 // vector projects to.
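        // As an illustrative example: with stride 2, filter 3 and padding 1,
        // input row h == 5 gives hpad == 6, hence h_start == 2 and h_end == 4
        // (assuming output_height >= 4), i.e. input row 5 only contributes to
        // output rows 2 and 3.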
3992 int hpad = h + params.padding_values.height;
3993 int wpad = w + params.padding_values.width;
3994 int h_start = (hpad < params.filter_height)
3995 ? 0
3996 : (hpad - params.filter_height) / stride_height + 1;
3997 int h_end = std::min(hpad / stride_height + 1, output_height);
3998 int w_start = (wpad < params.filter_width)
3999 ? 0
4000 : (wpad - params.filter_width) / stride_width + 1;
4001 int w_end = std::min(wpad / stride_width + 1, output_width);
4002 // compute elementwise sum
4003 for (int ph = h_start; ph < h_end; ++ph) {
4004 for (int pw = w_start; pw < w_end; ++pw) {
4005 int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
4006 out_mat.col(out_offset) +=
4007 in_mat.col(NodeOffset(b, h, w, input_height, input_width));
4008 out_count(out_offset)++;
4009 }
4010 }
4011 }
4012 }
4013 }
4014 // Divide the output by the actual number of elements being averaged over
4015 TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
4016 out_mat.array().rowwise() /= out_count.transpose().array();
4017
4018 const int flat_size = output_shape.FlatSize();
4019 for (int i = 0; i < flat_size; ++i) {
4020 output_data[i] = ActivationFunctionWithMinMax(output_data[i],
4021 params.float_activation_min,
4022 params.float_activation_max);
4023 }
4024 }
4025
inline void AveragePool(const PoolParams& params,
4027 const RuntimeShape& input_shape,
4028 const uint8* input_data,
4029 const RuntimeShape& output_shape, uint8* output_data) {
4030 gemmlowp::ScopedProfilingLabel label("AveragePool/8bit");
4031
4032 // Here, and in other pooling ops, in order to maintain locality of reference,
4033 // to minimize some recalculations, and to load into NEON vector registers, we
  // use an inner loop down the depth. Since depths can be large, the temporary
  // storage would otherwise need to be arbitrarily large, so we divide the
  // work up into depth tranches just within the batch loop.
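  // As an illustration with made-up numbers: with kPoolingAccTrancheSize ==
  // 256 and depth == 1000, each batch is processed as depth tranches of 256,
  // 256, 256 and 232 channels.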
4037 static constexpr int kPoolingAccTrancheSize = 256;
4038
4039 TFLITE_DCHECK_LE(params.quantized_activation_min,
4040 params.quantized_activation_max);
4041 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4042 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
4043 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
4044 const int depth = MatchingDim(input_shape, 3, output_shape, 3);
4045 const int input_height = input_shape.Dims(1);
4046 const int input_width = input_shape.Dims(2);
4047 const int output_height = output_shape.Dims(1);
4048 const int output_width = output_shape.Dims(2);
4049 const int stride_height = params.stride_height;
4050 const int stride_width = params.stride_width;
4051
4052 uint16 acc[kPoolingAccTrancheSize];
4053 for (int batch = 0; batch < batches; ++batch) {
4054 // We proceed through the depth in tranches (see comment above). The
4055 // depth_base is the depth at the beginning of the tranche. The
4056 // tranche_depth is the depth dimension of the tranche.
4057 for (int depth_base = 0; depth_base < depth;
4058 depth_base += kPoolingAccTrancheSize) {
4059 const int tranche_depth =
4060 std::min(depth - depth_base, kPoolingAccTrancheSize);
4061 for (int out_y = 0; out_y < output_height; ++out_y) {
4062 for (int out_x = 0; out_x < output_width; ++out_x) {
4063 const int in_x_origin =
4064 (out_x * stride_width) - params.padding_values.width;
4065 const int in_y_origin =
4066 (out_y * stride_height) - params.padding_values.height;
4067 const int filter_x_start = std::max(0, -in_x_origin);
4068 const int filter_x_end =
4069 std::min(params.filter_width, input_width - in_x_origin);
4070 const int filter_y_start = std::max(0, -in_y_origin);
4071 const int filter_y_end =
4072 std::min(params.filter_height, input_height - in_y_origin);
4073 const int filter_count =
4074 (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
4075 memset(acc, 0, tranche_depth * sizeof(acc[0]));
4076 const uint8* input_ptr =
4077 input_data + depth_base +
4078 depth * (in_x_origin +
4079 input_width * (in_y_origin + input_height * batch));
4080 for (int fy = filter_y_start; fy < filter_y_end; fy++) {
4081 const uint8* input_row_ptr =
4082 input_ptr + depth * (fy * input_width + filter_x_start);
4083 for (int fx = filter_x_start; fx < filter_x_end; fx++) {
4084 const uint8* input_channel_ptr = input_row_ptr;
4085 int channel = 0;
4086 #ifdef USE_NEON
4087 for (; channel <= tranche_depth - 16; channel += 16) {
4088 uint16x8_t acc_reg[2];
4089 for (int i = 0; i < 2; i++) {
4090 acc_reg[i] = vld1q_u16(acc + channel + 8 * i);
4091 }
4092 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
4093 input_channel_ptr += 16;
4094 acc_reg[0] = vaddw_u8(acc_reg[0], vget_low_u8(input_reg));
4095 acc_reg[1] = vaddw_u8(acc_reg[1], vget_high_u8(input_reg));
4096 for (int i = 0; i < 2; i++) {
4097 vst1q_u16(acc + channel + 8 * i, acc_reg[i]);
4098 }
4099 }
4100 for (; channel <= tranche_depth - 8; channel += 8) {
4101 uint16x8_t acc_reg = vld1q_u16(acc + channel);
4102 uint8x8_t input_reg = vld1_u8(input_channel_ptr);
4103 input_channel_ptr += 8;
4104 acc_reg = vaddw_u8(acc_reg, input_reg);
4105 vst1q_u16(acc + channel, acc_reg);
4106 }
4107 #endif
4108 for (; channel < tranche_depth; ++channel) {
4109 acc[channel] += *input_channel_ptr++;
4110 }
4111 input_row_ptr += depth;
4112 }
4113 }
4114 uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
4115 out_x, depth_base);
4116 int channel = 0;
4117 #ifdef USE_NEON
4118 #define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
4119 if (filter_count == FILTER_COUNT) { \
4120 for (; channel <= tranche_depth - 8; channel += 8) { \
4121 uint16 buf[8]; \
4122 for (int i = 0; i < 8; i++) { \
4123 buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
4124 } \
4125 uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
4126 buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \
4127 buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \
4128 vst1_u8(output_ptr + channel, buf8); \
4129 } \
4130 }
4131 AVGPOOL_DIVIDING_BY(9)
4132 AVGPOOL_DIVIDING_BY(15)
4133 #undef AVGPOOL_DIVIDING_BY
4134 for (; channel <= tranche_depth - 8; channel += 8) {
4135 uint16 buf[8];
4136 for (int i = 0; i < 8; i++) {
4137 buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
4138 }
4139 uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
4140 buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max));
4141 buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min));
4142 vst1_u8(output_ptr + channel, buf8);
4143 }
4144 #endif
4145 for (; channel < tranche_depth; ++channel) {
4146 uint16 a = (acc[channel] + filter_count / 2) / filter_count;
4147 a = std::max<uint16>(a, params.quantized_activation_min);
4148 a = std::min<uint16>(a, params.quantized_activation_max);
4149 output_ptr[channel] = static_cast<uint8>(a);
4150 }
4151 }
4152 }
4153 }
4154 }
4155 }
4156
inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
4158 const float* input_data, const RuntimeShape& output_shape,
4159 float* output_data) {
4160 gemmlowp::ScopedProfilingLabel label("MaxPool");
4161 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4162 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
4163 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
4164 const int input_height = input_shape.Dims(1);
4165 const int input_width = input_shape.Dims(2);
4166 const int output_height = output_shape.Dims(1);
4167 const int output_width = output_shape.Dims(2);
4168 const int stride_height = params.stride_height;
4169 const int stride_width = params.stride_width;
4170
4171 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
4172 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
4173 // Prefill the output to minimum representable float value
4174 out_mat.setConstant(std::numeric_limits<float>::lowest());
4175 for (int b = 0; b < batches; ++b) {
4176 for (int h = 0; h < input_height; ++h) {
4177 for (int w = 0; w < input_width; ++w) {
4178 // (h_start, h_end) * (w_start, w_end) is the range that the input
4179 // vector projects to.
4180 int hpad = h + params.padding_values.height;
4181 int wpad = w + params.padding_values.width;
4182 int h_start = (hpad < params.filter_height)
4183 ? 0
4184 : (hpad - params.filter_height) / stride_height + 1;
4185 int h_end = std::min(hpad / stride_height + 1, output_height);
4186 int w_start = (wpad < params.filter_width)
4187 ? 0
4188 : (wpad - params.filter_width) / stride_width + 1;
4189 int w_end = std::min(wpad / stride_width + 1, output_width);
        // compute elementwise max
4191 for (int ph = h_start; ph < h_end; ++ph) {
4192 for (int pw = w_start; pw < w_end; ++pw) {
4193 int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
4194 out_mat.col(out_offset) =
4195 out_mat.col(out_offset)
4196 .cwiseMax(in_mat.col(
4197 NodeOffset(b, h, w, input_height, input_width)));
4198 }
4199 }
4200 }
4201 }
4202 }
4203 const int flat_size = output_shape.FlatSize();
4204 for (int i = 0; i < flat_size; ++i) {
4205 output_data[i] = ActivationFunctionWithMinMax(output_data[i],
4206 params.float_activation_min,
4207 params.float_activation_max);
4208 }
4209 }
4210
inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
4212 const uint8* input_data, const RuntimeShape& output_shape,
4213 uint8* output_data) {
4214 gemmlowp::ScopedProfilingLabel label("MaxPool/8bit");
4215
4216 // Here, and in other pooling ops, in order to maintain locality of reference,
4217 // to minimize some recalculations, and to load into NEON vector registers, we
  // use an inner loop down the depth. Since depths can be large, the temporary
  // storage would otherwise need to be arbitrarily large, so we divide the
  // work up into depth tranches just within the batch loop.
4221 static constexpr int kPoolingAccTrancheSize = 256;
4222
4223 TFLITE_DCHECK_LE(params.quantized_activation_min,
4224 params.quantized_activation_max);
4225 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4226 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
4227 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
4228 const int depth = MatchingDim(input_shape, 3, output_shape, 3);
4229 const int input_height = input_shape.Dims(1);
4230 const int input_width = input_shape.Dims(2);
4231 const int output_height = output_shape.Dims(1);
4232 const int output_width = output_shape.Dims(2);
4233 const int stride_height = params.stride_height;
4234 const int stride_width = params.stride_width;
4235
4236 uint8 acc[kPoolingAccTrancheSize];
4237 for (int batch = 0; batch < batches; ++batch) {
4238 // We proceed through the depth in tranches (see comment above). The
4239 // depth_base is the depth at the beginning of the tranche. The
4240 // tranche_depth is the depth dimension of the tranche.
4241 for (int depth_base = 0; depth_base < depth;
4242 depth_base += kPoolingAccTrancheSize) {
4243 const int tranche_depth =
4244 std::min(depth - depth_base, kPoolingAccTrancheSize);
4245 for (int out_y = 0; out_y < output_height; ++out_y) {
4246 for (int out_x = 0; out_x < output_width; ++out_x) {
4247 const int in_x_origin =
4248 (out_x * stride_width) - params.padding_values.width;
4249 const int in_y_origin =
4250 (out_y * stride_height) - params.padding_values.height;
4251 const int filter_x_start = std::max(0, -in_x_origin);
4252 const int filter_x_end =
4253 std::min(params.filter_width, input_width - in_x_origin);
4254 const int filter_y_start = std::max(0, -in_y_origin);
4255 const int filter_y_end =
4256 std::min(params.filter_height, input_height - in_y_origin);
4257 memset(acc, 0, tranche_depth * sizeof(acc[0]));
4258 const uint8* input_ptr =
4259 input_data + depth_base +
4260 depth * (in_x_origin +
4261 input_width * (in_y_origin + input_height * batch));
4262 for (int fy = filter_y_start; fy < filter_y_end; fy++) {
4263 const uint8* input_row_ptr =
4264 input_ptr + depth * (fy * input_width + filter_x_start);
4265 for (int fx = filter_x_start; fx < filter_x_end; fx++) {
4266 const uint8* input_channel_ptr = input_row_ptr;
4267 int channel = 0;
4268 #ifdef USE_NEON
4269 for (; channel <= tranche_depth - 16; channel += 16) {
4270 uint8x16_t acc_reg = vld1q_u8(acc + channel);
4271 uint8x16_t input_reg = vld1q_u8(input_channel_ptr);
4272 input_channel_ptr += 16;
4273 acc_reg = vmaxq_u8(acc_reg, input_reg);
4274 vst1q_u8(acc + channel, acc_reg);
4275 }
4276
4277 for (; channel <= tranche_depth - 8; channel += 8) {
4278 uint8x8_t acc_reg = vld1_u8(acc + channel);
4279 uint8x8_t input_reg = vld1_u8(input_channel_ptr);
4280 input_channel_ptr += 8;
4281 acc_reg = vmax_u8(acc_reg, input_reg);
4282 vst1_u8(acc + channel, acc_reg);
4283 }
4284 #endif
4285 for (; channel < tranche_depth; ++channel) {
4286 acc[channel] = std::max(acc[channel], *input_channel_ptr++);
4287 }
4288 input_row_ptr += depth;
4289 }
4290 }
4291 uint8* output_ptr = output_data + Offset(output_shape, batch, out_y,
4292 out_x, depth_base);
4293 int channel = 0;
4294 #ifdef USE_NEON
4295 for (; channel <= tranche_depth - 16; channel += 16) {
4296 uint8x16_t a = vld1q_u8(acc + channel);
4297 a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max));
4298 a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min));
4299 vst1q_u8(output_ptr + channel, a);
4300 }
4301 for (; channel <= tranche_depth - 8; channel += 8) {
4302 uint8x8_t a = vld1_u8(acc + channel);
4303 a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max));
4304 a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min));
4305 vst1_u8(output_ptr + channel, a);
4306 }
4307 #endif
4308 for (; channel < tranche_depth; ++channel) {
4309 uint8 a = acc[channel];
4310 a = std::max<uint8>(a, params.quantized_activation_min);
4311 a = std::min<uint8>(a, params.quantized_activation_max);
4312 output_ptr[channel] = static_cast<uint8>(a);
4313 }
4314 }
4315 }
4316 }
4317 }
4318 }
4319
inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
4321 const float* input_data, const RuntimeShape& output_shape,
4322 float* output_data) {
4323 gemmlowp::ScopedProfilingLabel label("L2Pool");
4324 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
4325 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
4326 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
4327 const int input_height = input_shape.Dims(1);
4328 const int input_width = input_shape.Dims(2);
4329 const int output_height = output_shape.Dims(1);
4330 const int output_width = output_shape.Dims(2);
4331 const int stride_height = params.stride_height;
4332 const int stride_width = params.stride_width;
  // Actually carry out L2 Pool. Code is written in forward mode: we go through
  // the input values once, and write to all the pooled regions that they map
  // to.
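  // For each output position this effectively computes
  //   out = sqrt(sum_of_squares_over_window / window_size),
  // i.e. the root-mean-square of the inputs inside the pooling window.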
4335 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
4336 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
4337 Eigen::VectorXf in_square(in_mat.rows());
4338 Eigen::VectorXf out_count(out_mat.cols());
4339 out_count.setZero();
4340 // Prefill the output to 0.
4341 out_mat.setZero();
4342 for (int b = 0; b < batches; ++b) {
4343 for (int h = 0; h < input_height; ++h) {
4344 for (int w = 0; w < input_width; ++w) {
4345 // (h_start, h_end) * (w_start, w_end) is the range that the input
4346 // vector projects to.
4347 const int hpad = h + params.padding_values.height;
4348 const int wpad = w + params.padding_values.width;
4349 const int h_start =
4350 (hpad < params.filter_height)
4351 ? 0
4352 : (hpad - params.filter_height) / stride_height + 1;
4353 const int h_end = std::min(hpad / stride_height + 1, output_height);
4354 const int w_start =
4355 (wpad < params.filter_width)
4356 ? 0
4357 : (wpad - params.filter_width) / stride_width + 1;
4358 const int w_end = std::min(wpad / stride_width + 1, output_width);
4359 // pre-compute square
4360 const int in_offset = w + input_width * (h + input_height * b);
4361 in_square =
4362 in_mat.col(in_offset).array() * in_mat.col(in_offset).array();
4363 // compute elementwise sum of squares
4364 for (int ph = h_start; ph < h_end; ++ph) {
4365 for (int pw = w_start; pw < w_end; ++pw) {
4366 const int out_offset = pw + output_width * (ph + output_height * b);
4367 out_mat.col(out_offset) += in_square;
4368 out_count(out_offset)++;
4369 }
4370 }
4371 }
4372 }
4373 }
4374
4375 out_count = out_count.array().inverse();
4376 out_mat =
4377 (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
4378
4379 const int flat_size = output_shape.FlatSize();
4380 for (int i = 0; i < flat_size; ++i) {
4381 output_data[i] = ActivationFunctionWithMinMax(output_data[i],
4382 params.float_activation_min,
4383 params.float_activation_max);
4384 }
4385 }
4386
inline void LocalResponseNormalization(
4388 const tflite::LocalResponseNormalizationParams& op_params,
4389 const RuntimeShape& input_shape, const float* input_data,
4390 const RuntimeShape& output_shape, float* output_data) {
4391 gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization");
4392 MatchingFlatSize(input_shape, output_shape);
4393
4394 const auto data_in = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
4395 auto data_out = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
4396
4397 // Carry out local response normalization, vector by vector.
  // Since the data are stored column-major, row-wise operations would likely
  // not be memory efficient anyway, so we do an explicit for loop over the
  // columns.
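  // Per column r, the loop below computes
  //   data_out(i, r) = bias + alpha * sum_{j = i - range .. i + range} data_in(j, r)^2
  // (with out-of-range j contributing zero), and the final step rescales it to
  //   data_in(i, r) * data_out(i, r)^(-beta).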
4401 const int double_range = op_params.range * 2;
4402 Eigen::VectorXf padded_square(data_in.rows() + double_range);
4403 padded_square.setZero();
4404 for (int r = 0; r < data_in.cols(); ++r) {
4405 // Do local response normalization for data_in(:, r)
    // first, compute the squares and store them in a buffer for repeated use
4407 padded_square.block(op_params.range, 0, data_in.rows(), 1) =
4408 data_in.col(r).cwiseProduct(data_in.col(r)) * op_params.alpha;
    // Then, compute the scale and write it to data_out
4410 float accumulated_scale = 0;
4411 for (int i = 0; i < double_range; ++i) {
4412 accumulated_scale += padded_square(i);
4413 }
4414 for (int i = 0; i < data_in.rows(); ++i) {
4415 accumulated_scale += padded_square(i + double_range);
4416 data_out(i, r) = op_params.bias + accumulated_scale;
4417 accumulated_scale -= padded_square(i);
4418 }
4419 }
4420
4421 // In a few cases, the pow computation could benefit from speedups.
4422 if (op_params.beta == 1) {
4423 data_out.array() = data_in.array() * data_out.array().inverse();
4424 } else if (op_params.beta == 0.5) {
4425 data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
4426 } else {
4427 data_out.array() = data_in.array() * data_out.array().pow(-op_params.beta);
4428 }
4429 }
4430
inline void Softmax(const SoftmaxParams& params,
4432 const RuntimeShape& input_shape, const float* input_data,
4433 const RuntimeShape& output_shape, float* output_data) {
4434 gemmlowp::ScopedProfilingLabel label("Softmax");
4435 MatchingFlatSize(input_shape, output_shape);
4436
4437 const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
4438 auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
4439 // Compute the exponential first, removing the max coefficient for numerical
4440 // stability.
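  // i.e. this computes softmax(x)_i =
  //     exp(beta * (x_i - max_j x_j)) / sum_k exp(beta * (x_k - max_j x_j)),
  // which is mathematically identical to the softmax of beta * x.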
4441 out_mat =
4442 (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
4443 // We are separating out the exp function so that exp can be vectorized.
4444 out_mat = out_mat.array().exp();
4445 // Normalize to get the activations.
4446 Eigen::Array<float, 1, Eigen::Dynamic> scale =
4447 out_mat.array().colwise().sum().inverse();
4448 out_mat.array().rowwise() *= scale;
4449 }
4450
inline void Softmax(const SoftmaxParams& params,
4452 const RuntimeShape& input_shape, const uint8* input_data,
4453 const RuntimeShape& output_shape, uint8* output_data) {
4454 const int32 input_beta_multiplier = params.input_multiplier;
4455 const int32 input_beta_left_shift = params.input_left_shift;
4456 const int diff_min = params.diff_min;
4457 // The representation chosen for the input to the exp() function is Q5.26.
4458 // We need to leave extra space since values that we skip might be as large as
4459 // -32 before multiplying by input_beta_multiplier, and therefore as large as
4460 // -16 afterwards. Note that exp(-8) is definitely not insignificant to
4461 // accumulation, but exp(-16) definitely is.
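  // (Q5.26 here means an int32 with 5 integer bits and 26 fractional bits, so
  // a raw value r stands for r / 2^26 and the representable range is roughly
  // [-32, +32).)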
4462 static const int kScaledDiffIntegerBits = 5;
4463 static const int kAccumulationIntegerBits = 12;
4464 using FixedPointScaledDiff =
4465 gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
4466 using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
4467 using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4468
4469 gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
4470 const int trailing_dim = input_shape.DimensionsCount() - 1;
4471 const int outer_size =
4472 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4473 const int depth =
4474 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4475
4476 for (int b = 0; b < outer_size; ++b) {
4477 const uint8* input_data_ptr = input_data + b * depth;
4478 uint8* output_data_ptr = output_data + b * depth;
4479
4480 // Determine the largest entry in the current row
4481 uint8 max_in_row = 0;
4482 {
4483 int c = 0;
4484 #ifdef USE_NEON
4485 uint8x16_t max16_0 = vdupq_n_u8(0);
4486 uint8x16_t max16_1 = vdupq_n_u8(0);
4487 for (; c <= depth - 32; c += 32) {
4488 max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
4489 max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
4490 }
4491 uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
4492 if (c <= depth - 16) {
4493 max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
4494 c += 16;
4495 }
4496 uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
4497 if (c <= depth - 8) {
4498 max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
4499 c += 8;
4500 }
4501 uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
4502 uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
4503 uint8x8_t max1 = vpmax_u8(max2, max2);
4504 max_in_row = vget_lane_u8(max1, 0);
4505 #endif
4506 for (; c < depth; ++c) {
4507 max_in_row = std::max(max_in_row, input_data_ptr[c]);
4508 }
4509 }
4510
4511 #ifdef USE_NEON
4512 using FixedPointAccumInt32x4 =
4513 gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
4514 using FixedPointScaledDiffInt32x4 =
4515 gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
4516 using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
4517 FixedPoint0Int32x4 input_beta_multiplier_f0 =
4518 FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
4519 int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
4520 #endif
4521
4522 // Compute the sum of exponentials of the differences of entries in the
4523 // current row from the largest entry in the current row.
4524 FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
4525 {
4526 int c = 0;
4527 #ifdef USE_NEON
4528 int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
4529 FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
4530 FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
4531 FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
4532 for (; c <= depth - 8; c += 8) {
4533 uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
4534 int16x8_t input_diff_s16 =
4535 vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
4536 int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
4537 int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
4538 int32x4_t mask_0 =
4539 gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
4540 int32x4_t mask_1 =
4541 gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
4542 FixedPointScaledDiffInt32x4 scaled_diff_0 =
4543 input_beta_multiplier_f0 *
4544 FixedPointScaledDiffInt32x4::FromRaw(
4545 gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
4546 FixedPointScaledDiffInt32x4 scaled_diff_1 =
4547 input_beta_multiplier_f0 *
4548 FixedPointScaledDiffInt32x4::FromRaw(
4549 gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
4550 FixedPointAccumInt32x4 exps_0 =
4551 gemmlowp::Rescale<kAccumulationIntegerBits>(
4552 exp_on_negative_values(scaled_diff_0));
4553 FixedPointAccumInt32x4 exps_1 =
4554 gemmlowp::Rescale<kAccumulationIntegerBits>(
4555 exp_on_negative_values(scaled_diff_1));
4556 FixedPointAccumInt32x4 masked_exps_0 =
4557 SelectUsingMask(mask_0, exps_0, zeros);
4558 FixedPointAccumInt32x4 masked_exps_1 =
4559 SelectUsingMask(mask_1, exps_1, zeros);
4560 sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
4561 sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
4562 }
4563 int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
4564 int32x2_t sum_of_exps_reduced_2 =
4565 vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
4566 vget_high_s32(sum_of_exps_reduced_4));
4567 int32x2_t sum_of_exps_reduced_1 =
4568 vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
4569 sum_of_exps =
4570 FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
4571 #endif
4572 for (; c < depth; ++c) {
4573 int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
4574 if (input_diff >= diff_min) {
4575 const int32 input_diff_rescaled =
4576 MultiplyByQuantizedMultiplierGreaterThanOne(
4577 input_diff, input_beta_multiplier, input_beta_left_shift);
4578 const FixedPointScaledDiff scaled_diff_f8 =
4579 FixedPointScaledDiff::FromRaw(input_diff_rescaled);
4580 sum_of_exps =
4581 sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
4582 exp_on_negative_values(scaled_diff_f8));
4583 }
4584 }
4585 }
4586
4587 // Compute the fixed-point multiplier and shift that we need to apply to
4588 // perform a division by the above-computed sum-of-exponentials.
4589 int32 fixed_sum_of_exps = sum_of_exps.raw();
4590 int headroom_plus_one =
4591 CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
4592 // This is the number of bits to the left of the binary point above 1.0.
4593 // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
4594 // no later adjustment will be needed.
4595 int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
4596 int32 shifted_sum_minus_one = static_cast<int32>(
4597 (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
4598 (static_cast<uint32>(1) << 31));
4599 FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
4600 FixedPoint0::FromRaw(shifted_sum_minus_one));
4601
4602 // Compute the quotients of exponentials of differences of entries in the
4603 // current row from the largest entry, over the previously-computed sum of
4604 // exponentials.
4605 {
4606 int c = 0;
4607 #ifdef USE_NEON
4608 int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
4609 for (; c <= depth - 8; c += 8) {
4610 uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
4611 int16x8_t input_diff_s16 =
4612 vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
4613 int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
4614 int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
4615 uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
4616 FixedPointScaledDiffInt32x4 scaled_diff_0 =
4617 input_beta_multiplier_f0 *
4618 FixedPointScaledDiffInt32x4::FromRaw(
4619 gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
4620 FixedPointScaledDiffInt32x4 scaled_diff_1 =
4621 input_beta_multiplier_f0 *
4622 FixedPointScaledDiffInt32x4::FromRaw(
4623 gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
4624 FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
4625 FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
4626 int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
4627 vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
4628 num_bits_over_unit + 31 - 8);
4629 int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
4630 vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
4631 num_bits_over_unit + 31 - 8);
4632 int16x8_t output_s16 =
4633 vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
4634 uint8x8_t output_u8 = vqmovun_s16(output_s16);
4635 uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
4636 vst1_u8(output_data_ptr + c, masked_output);
4637 }
4638 #endif
4639 for (; c < depth; ++c) {
4640 int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
4641 if (input_diff >= diff_min) {
4642 const int32 input_diff_rescaled =
4643 MultiplyByQuantizedMultiplierGreaterThanOne(
4644 input_diff, input_beta_multiplier, input_beta_left_shift);
4645 const FixedPointScaledDiff scaled_diff_f8 =
4646 FixedPointScaledDiff::FromRaw(input_diff_rescaled);
4647
4648 FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
4649 int32 unsat_output = gemmlowp::RoundingDivideByPOT(
4650 (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
4651
4652 output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);
4653
4654 } else {
4655 output_data_ptr[c] = 0;
4656 }
4657 }
4658 }
4659 }
4660 }
4661
4662 // TODO(myenik): This is the same as the reference implementation, not actually
4663 // optimized yet.
inline void LogSoftmax(const SoftmaxParams& params,
4665 const RuntimeShape& input_shape, const float* input_data,
4666 const RuntimeShape& output_shape, float* output_data) {
4667 gemmlowp::ScopedProfilingLabel label("LogSoftmax");
4668 const int trailing_dim = input_shape.DimensionsCount() - 1;
4669 const int outer_size =
4670 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4671 const int depth =
4672 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4673
4674 for (int i = 0; i < outer_size; ++i) {
4675 const float* block_input_data = input_data + i * depth;
4676 float* block_output_data = output_data + i * depth;
4677 // Find max element value which we'll use to ensure numerical stability
4678 // taking advantage of the following equality:
4679 // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
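    // With C == -max this yields the numerically stable form computed below:
    //   log_softmax(x)[i] = x[i] - max - log(sum_j exp(x[j] - max))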
4680 float max = std::numeric_limits<float>::lowest();
4681 for (int c = 0; c < depth; ++c) {
4682 max = std::max(max, block_input_data[c]);
4683 }
4684
4685 // Compute sum.
4686 float sum = 0.f;
4687 for (int c = 0; c < depth; ++c) {
4688 sum += std::exp(block_input_data[c] - max);
4689 }
4690
4691 // Compute result.
4692 const float log_sum = std::log(sum);
4693 for (int c = 0; c < depth; ++c) {
4694 block_output_data[c] = block_input_data[c] - max - log_sum;
4695 }
4696 }
4697 }
4698
4699 // Currently just a copy of the reference code.
inline void LogSoftmax(const SoftmaxParams& params,
4701 const RuntimeShape& input_shape, const uint8* input_data,
4702 const RuntimeShape& output_shape, uint8* output_data) {
4703 gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8");
4704 const int32 input_multiplier = params.input_multiplier;
4705 const int32 input_left_shift = params.input_left_shift;
4706 const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
4707 const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
4708 const int diff_min = params.diff_min;
4709 // The representation chosen for the input to the exp() function is Q5.26.
4710 // We need to leave extra space since values that we skip might be as large as
4711 // -32 before multiplying by input_beta_multiplier, and therefore as large as
4712 // -16 afterwards. Note that exp(-8) is definitely not insignificant to
4713 // accumulation, but exp(-16) definitely is.
4714 static constexpr int kScaledDiffIntegerBits = 5;
4715 static constexpr int kAccumulationIntegerBits = 12;
4716 static constexpr int kOutputIntegerBits = 4;
4717 using FixedPointScaledDiff =
4718 gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
4719 using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
4720
4721 const int trailing_dim = input_shape.DimensionsCount() - 1;
4722 const int outer_size =
4723 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
4724 const int depth =
4725 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
4726
4727 for (int i = 0; i < outer_size; ++i) {
4728 const uint8* block_input_data = input_data + i * depth;
4729 uint8* block_output_data = output_data + i * depth;
4730 uint8 max_in_row = 0;
4731 for (int c = 0; c < depth; ++c) {
4732 max_in_row = std::max(max_in_row, block_input_data[c]);
4733 }
4734
4735 FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
4736 for (int c = 0; c < depth; ++c) {
4737 int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
4738 if (input_diff >= diff_min) {
4739 const int32 input_diff_rescaled =
4740 MultiplyByQuantizedMultiplierGreaterThanOne(
4741 input_diff, input_multiplier, input_left_shift);
4742 const FixedPointScaledDiff scaled_diff_f8 =
4743 FixedPointScaledDiff::FromRaw(input_diff_rescaled);
4744 sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
4745 exp_on_negative_values(scaled_diff_f8));
4746 }
4747 }
4748
4749 const int32 fixed_log_sum_of_exps =
4750 log_x_for_x_greater_than_or_equal_to_1<kScaledDiffIntegerBits>(
4751 sum_of_exps)
4752 .raw();
4753
4754 // rescaled_diff_min is the smallest value representable in
4755 // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the
4756 // log-sum-of-exps that will be subtracted in the loop.
4757 //
4758 // The thresholds diff_min, etc are negative.
4759 const int rescaled_diff_min =
4760 fixed_log_sum_of_exps + std::numeric_limits<int32>::lowest();
4761 const int adjusted_diff_min =
4762 std::max(diff_min - 1, // Note use of > below instead of >= above.
4763 MultiplyByQuantizedMultiplierSmallerThanOneExp(
4764 rescaled_diff_min, reverse_scaling_divisor,
4765 -reverse_scaling_right_shift));
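4766 // Roughly speaking: any input_diff at or below adjusted_diff_min would
4766 // either fall below the original diff_min threshold or make
4766 // (input_diff_rescaled - fixed_log_sum_of_exps) underflow int32, so the
4766 // loop below maps such entries straight to 0.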
4766
4767 for (int c = 0; c < depth; ++c) {
4768 int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
4769 if (input_diff > adjusted_diff_min) {
4770 const int32 input_diff_rescaled =
4771 MultiplyByQuantizedMultiplierGreaterThanOne(
4772 input_diff, input_multiplier, input_left_shift);
4773 int32 unsat_output =
4774 gemmlowp::RoundingDivideByPOT(
4775 (input_diff_rescaled - fixed_log_sum_of_exps),
4776 31 - kScaledDiffIntegerBits - kOutputIntegerBits) +
4777 255;
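4778 // (Roughly, the expression above equals
4778 // 255 + 2^kOutputIntegerBits * (scaled_diff - log(sum_of_exps)), since
4778 // dividing the Q5.26 raw value by 2^(31 - 5 - 4) rescales it by 2^4;
4778 // the clamp below maps the result into [0, 255].)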
4778
4779 block_output_data[c] = static_cast<uint8>(
4780 std::max(std::min(unsat_output, static_cast<int32>(255)), 0));
4781 } else {
4782 // Set output to smallest value.
4783 block_output_data[c] = 0;
4784 }
4785 }
4786 }
4787 }
4788
4789 inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
4790 const RuntimeShape& output_shape, float* output_data) {
4791 gemmlowp::ScopedProfilingLabel label("Logistic");
4792 auto input_map = MapAsVector(input_data, input_shape);
4793 auto output_map = MapAsVector(output_data, output_shape);
4794 output_map.array() =
4795 input_map.array().unaryExpr(Eigen::internal::scalar_logistic_op<float>());
4796 }
4797
4798 // Convenience version that allows, for example, generated-code calls to be
4799 // uniform between data types.
4800 inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
4801 const float* input_data, const RuntimeShape& output_shape,
4802 float* output_data) {
4803 // Drop params: not needed.
4804 Logistic(input_shape, input_data, output_shape, output_data);
4805 }
4806
4807 inline void Logistic(const LogisticParams& params,
4808 const RuntimeShape& input_shape, const uint8* input_data,
4809 const RuntimeShape& output_shape, uint8* output_data) {
4810 gemmlowp::ScopedProfilingLabel label("Logistic/Uint8");
4811 const int32 input_zero_point = params.input_zero_point;
4812 const int32 input_range_radius = params.input_range_radius;
4813 const int32 input_multiplier = params.input_multiplier;
4814 const int input_left_shift = params.input_left_shift;
4815 const int size = MatchingFlatSize(input_shape, output_shape);
4816
4817 int c = 0;
4818 #ifdef USE_NEON
4819 // Handle 16 values at a time
4820 for (; c <= size - 16; c += 16) {
4821 // Read input uint8 values, cast to int16 and subtract input_zero_point
4822 uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
4823 int16x8_t input_val_centered_0 =
4824 vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
4825 vdupq_n_s16(input_zero_point));
4826 int16x8_t input_val_centered_1 =
4827 vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
4828 vdupq_n_s16(input_zero_point));
4829
4830 // Prepare the bit masks that we will use at the end to implement the logic
4831 // that was expressed in the scalar code with branching:
4832 // if (input_val_centered < -input_range_radius) {
4833 // output_val = 0;
4834 // } else if (input_val_centered > input_range_radius) {
4835 // output_val = 255;
4836 // } else {
4837 // ...
4838 uint16x8_t mask_rightclamp_0 =
4839 vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
4840 uint16x8_t mask_rightclamp_1 =
4841 vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
4842 uint16x8_t mask_leftclamp_0 =
4843 vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
4844 uint16x8_t mask_leftclamp_1 =
4845 vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
4846 uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
4847 vshrn_n_u16(mask_rightclamp_1, 8));
4848 uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
4849 vshrn_n_u16(mask_leftclamp_1, 8));
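4850 // Each 16-bit comparison lane is 0x0000 or 0xFFFF; shifting right by 8
4850 // and narrowing turns it into a per-byte mask of 0x00 or 0xFF, which is
4850 // what the vorrq/vandq at the end of this loop body expect.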
4850
4851 // This performs what is expressed in the scalar code as
4852 // const int32 input_val_rescaled =
4853 // MultiplyByQuantizedMultiplierGreaterThanOne(
4854 // input_val_centered, input_multiplier, input_left_shift);
4855 int32x4_t input_val_rescaled_0 =
4856 vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
4857 vdupq_n_s32(input_left_shift));
4858 int32x4_t input_val_rescaled_1 =
4859 vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
4860 vdupq_n_s32(input_left_shift));
4861 int32x4_t input_val_rescaled_2 =
4862 vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
4863 vdupq_n_s32(input_left_shift));
4864 int32x4_t input_val_rescaled_3 =
4865 vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
4866 vdupq_n_s32(input_left_shift));
4867 input_val_rescaled_0 =
4868 vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
4869 input_val_rescaled_1 =
4870 vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
4871 input_val_rescaled_2 =
4872 vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
4873 input_val_rescaled_3 =
4874 vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
4875
4876 // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
4877 using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
4878 using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
4879 const FixedPoint4 input_val_f4_0 =
4880 FixedPoint4::FromRaw(input_val_rescaled_0);
4881 const FixedPoint4 input_val_f4_1 =
4882 FixedPoint4::FromRaw(input_val_rescaled_1);
4883 const FixedPoint4 input_val_f4_2 =
4884 FixedPoint4::FromRaw(input_val_rescaled_2);
4885 const FixedPoint4 input_val_f4_3 =
4886 FixedPoint4::FromRaw(input_val_rescaled_3);
4887 const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
4888 const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
4889 const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
4890 const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
4891
4892 // Divide by 2^23 as in the scalar code
4893 using gemmlowp::RoundingDivideByPOT;
4894 int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
4895 int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
4896 int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
4897 int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);
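4898 // 2^31 / 2^23 == 2^8, so this maps the Q0.31 logistic output in [0, 1)
4898 // to an integer in [0, 256]; the saturating narrow below clamps 256 to
4898 // 255, mirroring the scalar leftover loop.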
4898
4899 // Cast output values to uint8, saturating
4900 int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
4901 vqmovn_s32(output_val_s32_1));
4902 int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
4903 vqmovn_s32(output_val_s32_3));
4904 uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
4905 vqmovun_s16(output_val_s16_1));
4906
4907 // Perform the bit-masking with the bit masks computed at the beginning,
4908 // see the comment there.
4909 output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
4910 output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
4911
4912 // Store back to memory
4913 vst1q_u8(output_data + c, output_val_u8);
4914 }
4915 #endif
4916 // Leftover loop: handle one value at a time with scalar code.
4917 for (; c < size; ++c) {
4918 const uint8 input_val_u8 = input_data[c];
4919 const int32 input_val_centered =
4920 static_cast<int32>(input_val_u8) - input_zero_point;
4921 uint8 output_val;
4922 if (input_val_centered < -input_range_radius) {
4923 output_val = 0;
4924 } else if (input_val_centered > input_range_radius) {
4925 output_val = 255;
4926 } else {
4927 const int32 input_val_rescaled =
4928 MultiplyByQuantizedMultiplierGreaterThanOne(
4929 input_val_centered, input_multiplier, input_left_shift);
4930 using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
4931 using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
4932 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
4933 const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
4934 using gemmlowp::RoundingDivideByPOT;
4935 int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
4936 if (output_val_s32 == 256) {
4937 output_val_s32 = 255;
4938 }
4939 TFLITE_DCHECK_GE(output_val_s32, 0);
4940 TFLITE_DCHECK_LE(output_val_s32, 255);
4941 output_val = static_cast<uint8>(output_val_s32);
4942 }
4943 output_data[c] = output_val;
4944 }
4945 }
4946
4947 inline void Logistic(const LogisticParams& params,
4948 const RuntimeShape& input_shape, const int16* input_data,
4949 const RuntimeShape& output_shape, int16* output_data) {
4950 gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
4951 const int flat_size = MatchingFlatSize(input_shape, output_shape);
4952
4956 int c = 0;
4957 const int16* input_data_ptr = input_data;
4958 int16* output_data_ptr = output_data;
4959 #ifdef GEMMLOWP_NEON
4960 {
4961 // F0 uses 0 integer bits, range [-1, 1].
4962 // This is the return type of math functions such as tanh, logistic,
4963 // whose range is in [-1, 1].
4964 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
4965 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4966 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
4967
4968 for (; c <= flat_size - 16; c += 16) {
4969 F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
4970 F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
4971 F0 output0 = gemmlowp::logistic(input0);
4972 F0 output1 = gemmlowp::logistic(input1);
4973 vst1q_s16(output_data_ptr, output0.raw());
4974 vst1q_s16(output_data_ptr + 8, output1.raw());
4975
4976 input_data_ptr += 16;
4977 output_data_ptr += 16;
4978 }
4979 for (; c <= flat_size - 8; c += 8) {
4980 F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
4981 F0 output = gemmlowp::logistic(input);
4982 vst1q_s16(output_data_ptr, output.raw());
4983
4984 input_data_ptr += 8;
4985 output_data_ptr += 8;
4986 }
4987 }
4988 #endif
4989 {
4990 // F0 uses 0 integer bits, range [-1, 1].
4991 // This is the return type of math functions such as tanh, logistic,
4992 // whose range is in [-1, 1].
4993 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
4994 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
4995 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
4996
4997 for (; c < flat_size; ++c) {
4998 F3 input = F3::FromRaw(*input_data_ptr);
4999 F0 output = gemmlowp::logistic(input);
5000 *output_data_ptr = output.raw();
5001
5002 ++input_data_ptr;
5003 ++output_data_ptr;
5004 }
5005 }
5006 }
5007
5008 inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
5009 const RuntimeShape& output_shape, float* output_data) {
5010 gemmlowp::ScopedProfilingLabel label("Tanh");
5011 auto input_map = MapAsVector(input_data, input_shape);
5012 auto output_map = MapAsVector(output_data, output_shape);
5013 output_map.array() = input_map.array().tanh();
5014 }
5015
5016 // Convenience version that allows, for example, generated-code calls to be
5017 // uniform between data types.
5018 inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
5019 const float* input_data, const RuntimeShape& output_shape,
5020 float* output_data) {
5021 // Drop params: not needed.
5022 Tanh(input_shape, input_data, output_shape, output_data);
5023 }
5024
5025 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
5026 const uint8* input_data, const RuntimeShape& output_shape,
5027 uint8* output_data) {
5028 // Note that this is almost the exact same code as in Logistic().
5029 gemmlowp::ScopedProfilingLabel label("Tanh");
5030 const int32 input_zero_point = params.input_zero_point;
5031 const int32 input_range_radius = params.input_range_radius;
5032 const int32 input_multiplier = params.input_multiplier;
5033 const int input_left_shift = params.input_left_shift;
5034 const int size = MatchingFlatSize(input_shape, output_shape);
5035
5036 int c = 0;
5037 int32_t output_zero_point = 128;
5038 #ifdef USE_NEON
5039 // Handle 16 values at a time
5040 for (; c <= size - 16; c += 16) {
5041 // Read input uint8 values, cast to int16 and subtract input_zero_point
5042 uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
5043 int16x8_t input_val_centered_0 =
5044 vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
5045 vdupq_n_s16(input_zero_point));
5046 int16x8_t input_val_centered_1 =
5047 vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
5048 vdupq_n_s16(input_zero_point));
5049
5050 // Prepare the bit masks that we will use at the end to implement the logic
5051 // that was expressed in the scalar code with branching:
5052 // if (input_val_centered < -input_range_radius) {
5053 // output_val = 0;
5054 // } else if (input_val_centered > input_range_radius) {
5055 // output_val = 255;
5056 // } else {
5057 // ...
5058 uint16x8_t mask_rightclamp_0 =
5059 vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
5060 uint16x8_t mask_rightclamp_1 =
5061 vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
5062 uint16x8_t mask_leftclamp_0 =
5063 vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
5064 uint16x8_t mask_leftclamp_1 =
5065 vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
5066 uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
5067 vshrn_n_u16(mask_rightclamp_1, 8));
5068 uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
5069 vshrn_n_u16(mask_leftclamp_1, 8));
5070
5071 // This performs what is expressed in the scalar code as
5072 // const int32 input_val_rescaled =
5073 // MultiplyByQuantizedMultiplierGreaterThanOne(
5074 // input_val_centered, input_multiplier, input_left_shift);
5075 int32x4_t input_val_rescaled_0 =
5076 vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
5077 vdupq_n_s32(input_left_shift));
5078 int32x4_t input_val_rescaled_1 =
5079 vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
5080 vdupq_n_s32(input_left_shift));
5081 int32x4_t input_val_rescaled_2 =
5082 vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
5083 vdupq_n_s32(input_left_shift));
5084 int32x4_t input_val_rescaled_3 =
5085 vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
5086 vdupq_n_s32(input_left_shift));
5087 input_val_rescaled_0 =
5088 vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
5089 input_val_rescaled_1 =
5090 vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
5091 input_val_rescaled_2 =
5092 vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
5093 input_val_rescaled_3 =
5094 vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
5095
5096 // Invoke gemmlowp::tanh on FixedPoint wrapping int32x4_t
5097 using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
5098 using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
5099 const FixedPoint4 input_val_f4_0 =
5100 FixedPoint4::FromRaw(input_val_rescaled_0);
5101 const FixedPoint4 input_val_f4_1 =
5102 FixedPoint4::FromRaw(input_val_rescaled_1);
5103 const FixedPoint4 input_val_f4_2 =
5104 FixedPoint4::FromRaw(input_val_rescaled_2);
5105 const FixedPoint4 input_val_f4_3 =
5106 FixedPoint4::FromRaw(input_val_rescaled_3);
5107 const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
5108 const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
5109 const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
5110 const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
5111
5112 // Divide by 2^24 as in the scalar code
5113 using gemmlowp::RoundingDivideByPOT;
5114 int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 24);
5115 int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 24);
5116 int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 24);
5117 int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 24);
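5118 // 2^31 / 2^24 == 2^7, so the Q0.31 tanh output in (-1, 1) becomes an
5118 // integer in [-128, 128]; adding the output zero point of 128 below
5118 // shifts it into the uint8 range before the saturating narrow.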
5118
5119 // Add the output zero point
5120 int32x4_t output_zero_point_s32 = vdupq_n_s32(output_zero_point);
5121 output_val_s32_0 = vaddq_s32(output_val_s32_0, output_zero_point_s32);
5122 output_val_s32_1 = vaddq_s32(output_val_s32_1, output_zero_point_s32);
5123 output_val_s32_2 = vaddq_s32(output_val_s32_2, output_zero_point_s32);
5124 output_val_s32_3 = vaddq_s32(output_val_s32_3, output_zero_point_s32);
5125
5126 // Cast output values to uint8, saturating
5127 int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
5128 vqmovn_s32(output_val_s32_1));
5129 int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
5130 vqmovn_s32(output_val_s32_3));
5131 uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
5132 vqmovun_s16(output_val_s16_1));
5133
5134 // Perform the bit-masking with the bit masks computed at the beginning,
5135 // see the comment there.
5136 output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
5137 output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
5138
5139 // Store back to memory
5140 vst1q_u8(output_data + c, output_val_u8);
5141 }
5142 #endif
5143 // Leftover loop: handle one value at a time with scalar code.
5144 for (; c < size; ++c) {
5145 const uint8 input_val_u8 = input_data[c];
5146 const int32 input_val_centered =
5147 static_cast<int32>(input_val_u8) - input_zero_point;
5148 uint8 output_val;
5149 if (input_val_centered < -input_range_radius) {
5150 output_val = 0;
5151 } else if (input_val_centered > input_range_radius) {
5152 output_val = 255;
5153 } else {
5154 const int32 input_val_rescaled =
5155 MultiplyByQuantizedMultiplierGreaterThanOne(
5156 input_val_centered, input_multiplier, input_left_shift);
5157 using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
5158 using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
5159 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
5160 const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
5161 using gemmlowp::RoundingDivideByPOT;
5162 int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24);
5163 output_val_s32 += output_zero_point;
5164 if (output_val_s32 == 256) {
5165 output_val_s32 = 255;
5166 }
5167 TFLITE_DCHECK_GE(output_val_s32, 0);
5168 TFLITE_DCHECK_LE(output_val_s32, 255);
5169 output_val = static_cast<uint8>(output_val_s32);
5170 }
5171 output_data[c] = output_val;
5172 }
5173 }
5174
5175 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
5176 const int16* input_data, const RuntimeShape& output_shape,
5177 int16* output_data) {
5178 gemmlowp::ScopedProfilingLabel label("Tanh/Int16");
5179 const int input_left_shift = params.input_left_shift;
5180 // Support for shifts is limited until we have a parameterized version of
5181 // SaturatingRoundingMultiplyByPOT().
5182 TFLITE_DCHECK_GE(input_left_shift, 0);
5183 TFLITE_DCHECK_LE(input_left_shift, 1);
5184
5185 const int flat_size = MatchingFlatSize(input_shape, output_shape);
5186
5187 int c = 0;
5188 const int16* input_data_ptr = input_data;
5189 int16* output_data_ptr = output_data;
5190 #ifdef GEMMLOWP_NEON
5191 {
5192 // F0 uses 0 integer bits, range [-1, 1].
5193 // This is the return type of math functions such as tanh, logistic,
5194 // whose range is in [-1, 1].
5195 using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
5196 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
5197 using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
5198
5199 if (input_left_shift == 0) {
5200 for (; c <= flat_size - 16; c += 16) {
5201 F3 input0 = F3::FromRaw(vld1q_s16(input_data_ptr));
5202 F3 input1 = F3::FromRaw(vld1q_s16(input_data_ptr + 8));
5203 F0 output0 = gemmlowp::tanh(input0);
5204 F0 output1 = gemmlowp::tanh(input1);
5205 vst1q_s16(output_data_ptr, output0.raw());
5206 vst1q_s16(output_data_ptr + 8, output1.raw());
5207
5208 input_data_ptr += 16;
5209 output_data_ptr += 16;
5210 }
5211 for (; c <= flat_size - 8; c += 8) {
5212 F3 input = F3::FromRaw(vld1q_s16(input_data_ptr));
5213 F0 output = gemmlowp::tanh(input);
5214 vst1q_s16(output_data_ptr, output.raw());
5215
5216 input_data_ptr += 8;
5217 output_data_ptr += 8;
5218 }
5219 } else {
5220 for (; c <= flat_size - 16; c += 16) {
5221 F3 input0 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
5222 vld1q_s16(input_data_ptr)));
5223 F3 input1 = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
5224 vld1q_s16(input_data_ptr + 8)));
5225 F0 output0 = gemmlowp::tanh(input0);
5226 F0 output1 = gemmlowp::tanh(input1);
5227 vst1q_s16(output_data_ptr, output0.raw());
5228 vst1q_s16(output_data_ptr + 8, output1.raw());
5229
5230 input_data_ptr += 16;
5231 output_data_ptr += 16;
5232 }
5233 for (; c <= flat_size - 8; c += 8) {
5234 F3 input = F3::FromRaw(gemmlowp::SaturatingRoundingMultiplyByPOT<1>(
5235 vld1q_s16(input_data_ptr)));
5236 F0 output = gemmlowp::tanh(input);
5237 vst1q_s16(output_data_ptr, output.raw());
5238
5239 input_data_ptr += 8;
5240 output_data_ptr += 8;
5241 }
5242 }
5243 }
5244 #endif
5245 {
5246 // F0 uses 0 integer bits, range [-1, 1].
5247 // This is the return type of math functions such as tanh, logistic,
5248 // whose range is in [-1, 1].
5249 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
5250 // F3 uses 3 integer bits, range [-8, 8], the input range expected here.
5251 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
5252
5253 if (input_left_shift == 0) {
5254 for (; c < flat_size; ++c) {
5255 F3 input = F3::FromRaw(*input_data_ptr);
5256 F0 output = gemmlowp::tanh(input);
5257 *output_data_ptr = output.raw();
5258
5259 ++input_data_ptr;
5260 ++output_data_ptr;
5261 }
5262 } else {
5263 for (; c < flat_size; ++c) {
5264 F3 input = F3::FromRaw(
5265 gemmlowp::SaturatingRoundingMultiplyByPOT<1>(*input_data_ptr));
5266 F0 output = gemmlowp::tanh(input);
5267 *output_data_ptr = output.raw();
5268
5269 ++input_data_ptr;
5270 ++output_data_ptr;
5271 }
5272 }
5273 }
5274 }
5275
5276 template <typename SrcT, typename DstT>
5277 inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
5278 const RuntimeShape& output_shape, DstT* output_data) {
5279 gemmlowp::ScopedProfilingLabel label("Cast");
5280 auto input_map = MapAsVector(input_data, input_shape);
5281 auto output_map = MapAsVector(output_data, output_shape);
5282 output_map.array() = input_map.array().template cast<DstT>();
5283 }
5284
5285 inline void Floor(const RuntimeShape& input_shape, const float* input_data,
5286 const RuntimeShape& output_shape, float* output_data) {
5287 gemmlowp::ScopedProfilingLabel label("Floor");
5288 auto input_map = MapAsVector(input_data, input_shape);
5289 auto output_map = MapAsVector(output_data, output_shape);
5290 output_map.array() = Eigen::floor(input_map.array());
5291 }
5292
5293 inline void Ceil(const RuntimeShape& input_shape, const float* input_data,
5294 const RuntimeShape& output_shape, float* output_data) {
5295 gemmlowp::ScopedProfilingLabel label("Ceil");
5296 auto input_map = MapAsVector(input_data, input_shape);
5297 auto output_map = MapAsVector(output_data, output_shape);
5298 output_map.array() = Eigen::ceil(input_map.array());
5299 }
5300
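5301 // ResizeBilinearKernel accumulates output_ptr[i] += input_ptr[i] * scale for
5301 // i in [0, depth). The NEON variant below unrolls this over blocks of
5301 // 32/16/8/4 channels; the fallback further down is the plain scalar loop.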
5301 #ifdef USE_NEON
5302 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
5303 float scale, float* output_ptr) {
5304 int ic = 0;
5305 // Handle 32 input channels at a time.
5306 for (; ic <= depth - 32; ic += 32) {
5307 float32x4x2_t input[4];
5308 for (int i = 0; i < 4; i++) {
5309 input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
5310 input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
5311 }
5312 float32x4x2_t acc[4];
5313 for (int i = 0; i < 4; i++) {
5314 acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
5315 acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
5316 }
5317 for (int i = 0; i < 4; i++) {
5318 acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
5319 acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
5320 }
5321 for (int i = 0; i < 4; i++) {
5322 vst1q_f32(output_ptr, acc[i].val[0]);
5323 vst1q_f32(output_ptr + 4, acc[i].val[1]);
5324 output_ptr += 8;
5325 }
5326 input_ptr += 32;
5327 }
5328 // Handle 16 input channels at a time.
5329 for (; ic <= depth - 16; ic += 16) {
5330 float32x4x2_t input[2];
5331 for (int i = 0; i < 2; i++) {
5332 input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
5333 input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
5334 }
5335 float32x4x2_t acc[2];
5336 for (int i = 0; i < 2; i++) {
5337 acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
5338 acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
5339 }
5340 for (int i = 0; i < 2; i++) {
5341 acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
5342 acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
5343 }
5344 for (int i = 0; i < 2; i++) {
5345 vst1q_f32(output_ptr, acc[i].val[0]);
5346 vst1q_f32(output_ptr + 4, acc[i].val[1]);
5347 output_ptr += 8;
5348 }
5349 input_ptr += 16;
5350 }
5351 // Handle 8 input channels at a time.
5352 for (; ic <= depth - 8; ic += 8) {
5353 float32x4x2_t input;
5354 input.val[0] = vld1q_f32(input_ptr);
5355 input.val[1] = vld1q_f32(input_ptr + 4);
5356
5357 float32x4x2_t acc;
5358 acc.val[0] = vld1q_f32(output_ptr);
5359 acc.val[1] = vld1q_f32(output_ptr + 4);
5360 acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale);
5361 acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale);
5362
5363 vst1q_f32(output_ptr, acc.val[0]);
5364 vst1q_f32(output_ptr + 4, acc.val[1]);
5365
5366 input_ptr += 8;
5367 output_ptr += 8;
5368 }
5369 // Handle 4 input channels at a time.
5370 for (; ic <= depth - 4; ic += 4) {
5371 float32x4_t input = vld1q_f32(input_ptr);
5372 float32x4_t acc = vld1q_f32(output_ptr);
5373
5374 acc = vmlaq_n_f32(acc, input, scale);
5375 vst1q_f32(output_ptr, acc);
5376
5377 input_ptr += 4;
5378 output_ptr += 4;
5379 }
5380 // Handle 1 input channel at a time.
5381 for (; ic < depth; ic++) {
5382 *output_ptr += *input_ptr * scale;
5383 output_ptr++;
5384 input_ptr++;
5385 }
5386 }
5387 #else
5388 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
5389 float scale, float* output_ptr) {
5390 for (int32 i = 0; i < depth; i++) {
5391 *output_ptr += *input_ptr * scale;
5392 output_ptr++;
5393 input_ptr++;
5394 }
5395 }
5396 #endif
5397
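5398 // For an exact 2x upsample the bilinear weights reduce to simple averages:
5398 // top-left = x0y0, top-right = (x0y0 + x1y0) / 2,
5398 // bottom-left = (x0y0 + x0y1) / 2, and bottom-right = the average of all
5398 // four neighbors. Both the NEON and scalar paths below compute exactly that.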
5398 inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
5399 int32 x, int32 y, int32 depth, int32 batch,
5400 const RuntimeShape& input_shape,
5401 const float* input_data,
5402 const RuntimeShape& output_shape,
5403 float* output_data) {
5404 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
5405 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
5406 const int32 input_width = input_shape.Dims(2);
5407 const int32 output_width = output_shape.Dims(2);
5408
5409 const int32 input_x_offset = (x1 - x0) * depth;
5410 const int32 input_y_offset = (y1 - y0) * depth * input_width;
5411 const int32 output_x_offset = depth;
5412 const int32 output_y_offset = depth * output_width;
5413
5414 #ifdef USE_NEON
5415 TFLITE_DCHECK(x1 >= x0);
5416 TFLITE_DCHECK(y1 >= y0);
5417
5418 int ic = 0;
5419 // Handle 8 input channels at a time.
5420 for (; ic <= depth - 8; ic += 8) {
5421 const float* input_ptr = nullptr;
5422
5423 float32x4x2_t x0y0;
5424 input_ptr = &input_data[Offset(input_shape, batch, y0, x0, ic)];
5425 x0y0.val[0] = vld1q_f32(input_ptr);
5426 x0y0.val[1] = vld1q_f32(input_ptr + 4);
5427
5428 float32x4x2_t x1y0;
5429 input_ptr += input_x_offset;
5430 x1y0.val[0] = vld1q_f32(input_ptr);
5431 x1y0.val[1] = vld1q_f32(input_ptr + 4);
5432
5433 float32x4x2_t x0y1;
5434 input_ptr += -input_x_offset + input_y_offset;
5435 x0y1.val[0] = vld1q_f32(input_ptr);
5436 x0y1.val[1] = vld1q_f32(input_ptr + 4);
5437
5438 float32x4x2_t x1y1;
5439 input_ptr += input_x_offset;
5440 x1y1.val[0] = vld1q_f32(input_ptr);
5441 x1y1.val[1] = vld1q_f32(input_ptr + 4);
5442
5443 // Top left corner.
5444 float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
5445 vst1q_f32(output_ptr, x0y0.val[0]);
5446 vst1q_f32(output_ptr + 4, x0y0.val[1]);
5447
5448 // Top right corner.
5449 output_ptr += output_x_offset;
5450 float32x4x2_t tr;
5451 tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]);
5452 tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]);
5453 tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f);
5454 tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f);
5455
5456 vst1q_f32(output_ptr, tr.val[0]);
5457 vst1q_f32(output_ptr + 4, tr.val[1]);
5458
5459 // Bottom left corner.
5460 output_ptr += -output_x_offset + output_y_offset;
5461 float32x4x2_t bl;
5462 bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]);
5463 bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]);
5464 bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f);
5465 bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f);
5466 vst1q_f32(output_ptr, bl.val[0]);
5467 vst1q_f32(output_ptr + 4, bl.val[1]);
5468
5469 // Bottom right corner.
5470 output_ptr += output_x_offset;
5471 float32x4x2_t br;
5472 br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]);
5473 br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]);
5474 br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f);
5475 br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f);
5476 br.val[0] = vmulq_n_f32(br.val[0], 0.5f);
5477 br.val[1] = vmulq_n_f32(br.val[1], 0.5f);
5478 vst1q_f32(output_ptr, br.val[0]);
5479 vst1q_f32(output_ptr + 4, br.val[1]);
5480 }
5481 // Handle 4 input channels at a time.
5482 for (; ic <= depth - 4; ic += 4) {
5483 const float* input_ptr =
5484 &input_data[Offset(input_shape, batch, y0, x0, ic)];
5485 float32x4_t x0y0 = vld1q_f32(input_ptr);
5486 float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset);
5487 float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset);
5488 float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset);
5489
5490 // Top left corner.
5491 float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
5492 vst1q_f32(output_ptr, x0y0);
5493
5494 // Top right corner.
5495 output_ptr += output_x_offset;
5496 float32x4_t tr = vaddq_f32(x0y0, x1y0);
5497 tr = vmulq_n_f32(tr, 0.5f);
5498 vst1q_f32(output_ptr, tr);
5499
5500 // Bottom left corner.
5501 output_ptr += -output_x_offset + output_y_offset;
5502 float32x4_t bl = vaddq_f32(x0y0, x0y1);
5503 bl = vmulq_n_f32(bl, 0.5f);
5504 vst1q_f32(output_ptr, bl);
5505
5506 // Bottom right corner.
5507 output_ptr += output_x_offset;
5508 float32x4_t br = vaddq_f32(x1y0, x1y1);
5509 br = vmlaq_n_f32(bl, br, 0.5f);
5510 br = vmulq_n_f32(br, 0.5f);
5511 vst1q_f32(output_ptr, br);
5512 }
5513 // Handle one input channel at a time.
5514 for (; ic < depth; ic++) {
5515 const int32 input_offset = Offset(input_shape, batch, y0, x0, ic);
5516
5517 float x0y0 = input_data[input_offset];
5518 float x1y0 = input_data[input_offset + input_x_offset];
5519 float x0y1 = input_data[input_offset + input_y_offset];
5520 float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
5521
5522 // Top left corner.
5523 const int32 output_offset = Offset(output_shape, batch, y, x, ic);
5524 output_data[output_offset] = x0y0;
5525
5526 // Top right corner.
5527 output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
5528
5529 // Bottom left corner.
5530 float output = (x0y0 + x0y1) / 2;
5531 output_data[output_offset + output_y_offset] = output;
5532
5533 // Bottom right corner.
5534 output_data[output_offset + output_x_offset + output_y_offset] =
5535 (output + ((x1y0 + x1y1) / 2)) / 2;
5536 }
5537 #else
5538 for (int ch = 0; ch < depth; ch++) {
5539 const int32 input_offset = Offset(input_shape, batch, y0, x0, ch);
5540
5541 float x0y0 = input_data[input_offset];
5542 float x1y0 = input_data[input_offset + input_x_offset];
5543 float x0y1 = input_data[input_offset + input_y_offset];
5544 float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
5545
5546 // Top left corner.
5547 const int32 output_offset = Offset(output_shape, batch, y, x, ch);
5548 output_data[output_offset] = x0y0;
5549
5550 // Top right corner.
5551 output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
5552
5553 // Bottom left corner.
5554 float output = (x0y0 + x0y1) / 2;
5555 output_data[output_offset + output_y_offset] = output;
5556
5557 // Bottom right corner.
5558 output_data[output_offset + output_x_offset + output_y_offset] =
5559 (output + ((x1y0 + x1y1) / 2)) / 2;
5560 }
5561 #endif
5562 }
5563
5564 inline void ResizeBilinear2x2(int32 batches, int32 input_height,
5565 int32 input_width, int32 depth,
5566 int32 output_height, int32 output_width,
5567 const RuntimeShape& input_shape,
5568 const float* input_data,
5569 const RuntimeShape& output_shape,
5570 float* output_data) {
5571 for (int b = 0; b < batches; b++) {
5572 for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) {
5573 for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) {
5574 int32 x1 = std::min(x0 + 1, input_width - 1);
5575 int32 y1 = std::min(y0 + 1, input_height - 1);
5576 ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_shape,
5577 input_data, output_shape, output_data);
5578 }
5579 }
5580 }
5581 }
5582
5583 inline void ResizeBilinearGeneric(
5584 int32 batches, int32 input_height, int32 input_width, int32 depth,
5585 int32 output_height, int32 output_width, float height_scale,
5586 float width_scale, const RuntimeShape& input_shape, const float* input_data,
5587 const RuntimeShape& output_shape, float* output_data) {
5588 memset(output_data, 0,
5589 batches * output_height * output_width * depth * sizeof(float));
5590
5591 int32 output_offset = 0;
5592 for (int b = 0; b < batches; ++b) {
5593 for (int y = 0; y < output_height; ++y) {
5594 float input_y = y * height_scale;
5595 int32 y0 = static_cast<int32>(std::floor(input_y));
5596 int32 y1 = std::min(y0 + 1, input_height - 1);
5597 for (int x = 0; x < output_width; ++x) {
5598 float input_x = x * width_scale;
5599 int32 x0 = static_cast<int32>(input_x);
5600 int32 x1 = std::min(x0 + 1, input_width - 1);
5601 float* output_ptr = &output_data[output_offset];
5602
5603 // Run kernel on the 4 corners of the bilinear resize algorithm.
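5604 // The four weights below are (1-dy)*(1-dx), (1-dy)*dx, dy*(1-dx) and
5604 // dy*dx, with dy = input_y - y0 and dx = input_x - x0; they sum to 1.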
5604 int32 input_offset = Offset(input_shape, b, y0, x0, 0);
5605 float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
5606 const float* input_ptr = &input_data[input_offset];
5607 ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
5608
5609 input_offset = Offset(input_shape, b, y0, x1, 0);
5610 scale = (1 - (input_y - y0)) * (input_x - x0);
5611 input_ptr = &input_data[input_offset];
5612 ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
5613
5614 input_offset = Offset(input_shape, b, y1, x0, 0);
5615 scale = (input_y - y0) * (1 - (input_x - x0));
5616 input_ptr = &input_data[input_offset];
5617 ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
5618
5619 input_offset = Offset(input_shape, b, y1, x1, 0);
5620 scale = (input_y - y0) * (input_x - x0);
5621 input_ptr = &input_data[input_offset];
5622 ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
5623
5624 output_offset += depth;
5625 }
5626 }
5627 }
5628 }
5629
5630 template <typename T>
5631 inline void ResizeBilinearGenericSmallChannel(
5632 int32 batches, int32 input_height, int32 input_width, int32 depth,
5633 int32 output_height, int32 output_width, float height_scale,
5634 float width_scale, const RuntimeShape& input_shape, const T* input_data,
5635 const RuntimeShape& output_shape, T* output_data) {
5636 T* output_ptr = &output_data[0];
5637 for (int b = 0; b < batches; ++b) {
5638 for (int y = 0; y < output_height; ++y) {
5639 float input_y = y * height_scale;
5640 int32 y0 = static_cast<int32>(std::floor(input_y));
5641 int32 y1 = std::min(y0 + 1, input_height - 1);
5642 for (int x = 0; x < output_width; ++x) {
5643 float input_x = x * width_scale;
5644 int32 x0 = static_cast<int32>(std::floor((input_x)));
5645 int32 x1 = std::min(x0 + 1, input_width - 1);
5646
5647 int32 input_offset[4] = {Offset(input_shape, b, y0, x0, 0),
5648 Offset(input_shape, b, y0, x1, 0),
5649 Offset(input_shape, b, y1, x0, 0),
5650 Offset(input_shape, b, y1, x1, 0)};
5651 float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
5652 (1 - (input_y - y0)) * (input_x - x0),
5653 (input_y - y0) * (1 - (input_x - x0)),
5654 (input_y - y0) * (input_x - x0)};
5655
5656 for (int d = 0; d < depth; d++) {
5657 const T* input_ptr = &input_data[d];
5658 *output_ptr++ = static_cast<T>(input_ptr[input_offset[0]] * scale[0] +
5659 input_ptr[input_offset[1]] * scale[1] +
5660 input_ptr[input_offset[2]] * scale[2] +
5661 input_ptr[input_offset[3]] * scale[3]);
5662 }
5663 }
5664 }
5665 }
5666 }
5667
5668 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
5669 const RuntimeShape& unextended_input_shape,
5670 const float* input_data,
5671 const RuntimeShape& output_size_shape,
5672 const int32* output_size_data,
5673 const RuntimeShape& unextended_output_shape,
5674 float* output_data) {
5675 gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
5676 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
5677 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
5678 const RuntimeShape input_shape =
5679 RuntimeShape::ExtendedShape(4, unextended_input_shape);
5680 const RuntimeShape output_shape =
5681 RuntimeShape::ExtendedShape(4, unextended_output_shape);
5682
5683 int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
5684 int32 input_height = input_shape.Dims(1);
5685 int32 input_width = input_shape.Dims(2);
5686 int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
5687
5688 TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
5689 int32 output_height = output_size_data[0];
5690 int32 output_width = output_size_data[1];
5691
5692 // Specialize for 2x2 upsample.
5693 if (!op_params.align_corners && output_height == 2 * input_height &&
5694 output_width == 2 * input_width) {
5695 ResizeBilinear2x2(batches, input_height, input_width, depth, output_height,
5696 output_width, input_shape, input_data, output_shape,
5697 output_data);
5698 } else {
5699 float height_scale = static_cast<float>(input_height) / output_height;
5700 float width_scale = static_cast<float>(input_width) / output_width;
5701 if (op_params.align_corners && output_height > 1) {
5702 height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
5703 }
5704 if (op_params.align_corners && output_width > 1) {
5705 width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
5706 }
5707
5708 ResizeBilinearGeneric(batches, input_height, input_width, depth,
5709 output_height, output_width, height_scale,
5710 width_scale, input_shape, input_data, output_shape,
5711 output_data);
5712 }
5713 }
5714
5715 // TODO(prabhumk): This is not a real quantized bilinear. It does not use int8
5716 // or int16 arithmetic.
5717 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
5718 const RuntimeShape& unextended_input_shape,
5719 const uint8* input_data,
5720 const RuntimeShape& output_size_shape,
5721 const int32* output_size_data,
5722 const RuntimeShape& unextended_output_shape,
5723 uint8* output_data) {
5724 gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
5725 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
5726 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
5727 const RuntimeShape input_shape =
5728 RuntimeShape::ExtendedShape(4, unextended_input_shape);
5729 const RuntimeShape output_shape =
5730 RuntimeShape::ExtendedShape(4, unextended_output_shape);
5731
5732 int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
5733 int32 input_height = input_shape.Dims(1);
5734 int32 input_width = input_shape.Dims(2);
5735 int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
5736
5737 TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
5738 int32 output_height = output_size_data[0];
5739 int32 output_width = output_size_data[1];
5740
5741 float height_scale =
5742 (op_params.align_corners && output_height > 1)
5743 ? (static_cast<float>(input_height - 1) / (output_height - 1))
5744 : (static_cast<float>(input_height) / output_height);
5745
5746 float width_scale =
5747 (op_params.align_corners && output_width > 1)
5748 ? (static_cast<float>(input_width - 1) / (output_width - 1))
5749 : (static_cast<float>(input_width) / output_width);
5750
5751 ResizeBilinearGenericSmallChannel<uint8>(
5752 batches, input_height, input_width, depth, output_height, output_width,
5753 height_scale, width_scale, input_shape, input_data, output_shape,
5754 output_data);
5755 }
5756
5757 // Helper methods for BatchToSpaceND.
5758 // `spatial_index_dim` specifies post-crop offset index in this spatial
5759 // dimension, i.e. spatial offset introduced by flattening batch to spatial
5760 // dimension minus the crop size at beginning. `block_shape_dim` is the block
5761 // size in current dimension. `input_dim` and `output_dim` are input and output
5762 // size of BatchToSpaceND operation in current dimension.
5763 // Output start index is inclusive and end index is exclusive.
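5764 // For example, with spatial_index_dim = -2, block_shape_dim = 3,
5764 // input_dim = 4 and output_dim = 5, this yields start_index = 1 and
5764 // end_index = 3: out = in * 3 - 2 lands inside [0, 5) exactly for
5764 // in = 1 and in = 2.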
5764 inline void GetIndexRange(int spatial_index_dim, int block_shape_dim,
5765 int input_dim, int output_dim, int* start_index,
5766 int* end_index) {
5767 // (*start_index) * block_shape_dim is effectively rounded up to the next
5768 // multiple of block_shape_dim by the integer division.
5769 *start_index =
5770 std::max(0, (-spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
5771 // Similarly, (*end_index) * block_shape_dim is rounded up too (note that
5772 // end_index is exclusive).
5773 *end_index = std::min(
5774 input_dim,
5775 (output_dim - spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
5776 }
5777
5778 template <typename T>
5779 inline void BatchToSpaceND(
5780 const RuntimeShape& unextended_input1_shape, const T* input1_data,
5781 const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
5782 const RuntimeShape& unextended_input3_shape, const int32* crops_data,
5783 const RuntimeShape& unextended_output_shape, T* output_data) {
5784 gemmlowp::ScopedProfilingLabel label("BatchToSpaceND");
5785
5786 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
5787 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
5788 const RuntimeShape input1_shape =
5789 RuntimeShape::ExtendedShape(4, unextended_input1_shape);
5790 const RuntimeShape output_shape =
5791 RuntimeShape::ExtendedShape(4, unextended_output_shape);
5792
5793 const int output_width = output_shape.Dims(2);
5794 const int output_height = output_shape.Dims(1);
5795 const int output_batch_size = output_shape.Dims(0);
5796
5797 const int depth = input1_shape.Dims(3);
5798 const int input_width = input1_shape.Dims(2);
5799 const int input_height = input1_shape.Dims(1);
5800 const int input_batch_size = input1_shape.Dims(0);
5801
5802 const int block_shape_width = block_shape_data[1];
5803 const int block_shape_height = block_shape_data[0];
5804 const int crops_top = crops_data[0];
5805 const int crops_left = crops_data[2];
5806
5807 for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
5808 const int out_batch = in_batch % output_batch_size;
5809 const int spatial_offset = in_batch / output_batch_size;
5810
5811 int in_h_start = 0;
5812 int in_h_end = 0;
5813 // GetIndexRange ensures start and end indices are in [0, output_height).
5814 GetIndexRange(spatial_offset / block_shape_width - crops_top,
5815 block_shape_height, input_height, output_height, &in_h_start,
5816 &in_h_end);
5817
5818 for (int in_h = in_h_start; in_h < in_h_end; ++in_h) {
5819 const int out_h = in_h * block_shape_height +
5820 spatial_offset / block_shape_width - crops_top;
5821 TFLITE_DCHECK_GE(out_h, 0);
5822 TFLITE_DCHECK_LT(out_h, output_height);
5823
5824 int in_w_start = 0;
5825 int in_w_end = 0;
5826 // GetIndexRange ensures start and end indices are in [0, output_width).
5827 GetIndexRange(spatial_offset % block_shape_width - crops_left,
5828 block_shape_width, input_width, output_width, &in_w_start,
5829 &in_w_end);
5830
5831 for (int in_w = in_w_start; in_w < in_w_end; ++in_w) {
5832 const int out_w = in_w * block_shape_width +
5833 spatial_offset % block_shape_width - crops_left;
5834 TFLITE_DCHECK_GE(out_w, 0);
5835 TFLITE_DCHECK_LT(out_w, output_width);
5836 T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
5837 const T* in =
5838 input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
5839 memcpy(out, in, depth * sizeof(T));
5840 }
5841 }
5842 }
5843 }
5844
5845 template <typename T>
5846 void TypedMemset(void* ptr, T value, size_t num) {
5847 // Optimization for common cases where memset() will suffice.
5848 if (value == 0 || std::is_same<T, uint8_t>::value) {
5849 memset(ptr, value, num * sizeof(T));
5850 } else {
5851 // Default implementation for cases where memset() will not preserve the
5852 // bytes, e.g., typically when sizeof(T) > sizeof(uint8_t).
5853 char* pos = static_cast<char*>(ptr);
5854 for (size_t i = 0; i < num; ++i) {
5855 memcpy(pos, &value, sizeof(T));
5856 pos = pos + sizeof(T);
5857 }
5858 }
5859 }
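5860 // E.g. TypedMemset<float>(dst, 1.5f, n) takes the element-wise memcpy path,
5860 // while a zero value of any type, or any uint8_t value, is handled by a
5860 // single memset.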
5860
5861 // This makes heavy use of Offset, along with conditional branches. There may be
5862 // opportunities for improvement.
5863 //
5864 // There are two versions of pad: Pad and PadV2. In PadV2 there is a second
5865 // scalar input that provides the padding value. Therefore pad_value_ptr can be
5866 // equivalent to a simple input1_data. For Pad, it should point to a zero
5867 // value.
5868 //
5869 // Note that two typenames are required, so that T=P=int32 is considered a
5870 // specialization distinct from P=int32.
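5871 //
5871 // The implementation below makes a single pass over the output: at each
5871 // dimension the left padding is written with pad_value first and the right
5871 // padding last, and in the innermost step one depth-row of the input is
5871 // memcpy'd between the depth paddings.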
5871 template <typename T, typename P>
5872 inline void PadImpl(const tflite::PadParams& op_params,
5873 const RuntimeShape& input_shape, const T* input_data,
5874 const P* pad_value_ptr, const RuntimeShape& output_shape,
5875 T* output_data) {
5876 gemmlowp::ScopedProfilingLabel label("Pad4DSlowImpl");
5877 const RuntimeShape ext_input_shape =
5878 RuntimeShape::ExtendedShape(4, input_shape);
5879 const RuntimeShape ext_output_shape =
5880 RuntimeShape::ExtendedShape(4, output_shape);
5881 TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
5882 TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
5883
5884 // Pad kernels are limited to max 4 dimensions. Copy inputs so we can pad them
5885 // to 4 dims (yes, we are "padding the padding").
5886 std::vector<int> left_padding_copy(4, 0);
5887 const int left_padding_extend = 4 - op_params.left_padding_count;
5888 for (int i = 0; i < op_params.left_padding_count; ++i) {
5889 left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
5890 }
5891 std::vector<int> right_padding_copy(4, 0);
5892 const int right_padding_extend = 4 - op_params.right_padding_count;
5893 for (int i = 0; i < op_params.right_padding_count; ++i) {
5894 right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
5895 }
5896
5897 const int output_batch = ext_output_shape.Dims(0);
5898 const int output_height = ext_output_shape.Dims(1);
5899 const int output_width = ext_output_shape.Dims(2);
5900 const int output_depth = ext_output_shape.Dims(3);
5901
5902 const int left_b_padding = left_padding_copy[0];
5903 const int left_h_padding = left_padding_copy[1];
5904 const int left_w_padding = left_padding_copy[2];
5905 const int left_d_padding = left_padding_copy[3];
5906
5907 const int right_b_padding = right_padding_copy[0];
5908 const int right_h_padding = right_padding_copy[1];
5909 const int right_w_padding = right_padding_copy[2];
5910 const int right_d_padding = right_padding_copy[3];
5911
5912 const int input_depth = ext_input_shape.Dims(3);
5913 const T pad_value = *pad_value_ptr;
5914
5915 if (left_b_padding != 0) {
5916 TypedMemset<T>(
5917 output_data, pad_value,
5918 left_b_padding * output_height * output_width * output_depth);
5919 }
5920 for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
5921 ++out_b) {
5922 if (left_h_padding != 0) {
5923 TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, 0, 0, 0),
5924 pad_value, left_h_padding * output_width * output_depth);
5925 }
5926 for (int out_h = left_h_padding; out_h < output_height - right_h_padding;
5927 ++out_h) {
5928 if (left_w_padding != 0) {
5929 TypedMemset<T>(
5930 output_data + Offset(ext_output_shape, out_b, out_h, 0, 0),
5931 pad_value, left_w_padding * output_depth);
5932 }
5933 for (int out_w = left_w_padding; out_w < output_width - right_w_padding;
5934 ++out_w) {
5935 if (left_d_padding != 0) {
5936 TypedMemset<T>(
5937 output_data + Offset(ext_output_shape, out_b, out_h, out_w, 0),
5938 pad_value, left_d_padding);
5939 }
5940
5941 T* out = output_data +
5942 Offset(ext_output_shape, out_b, out_h, out_w, left_d_padding);
5943 const T* in = input_data +
5944 Offset(ext_input_shape, out_b - left_b_padding,
5945 out_h - left_h_padding, out_w - left_w_padding, 0);
5946 memcpy(out, in, input_depth * sizeof(T));
5947
5948 if (right_d_padding != 0) {
5949 TypedMemset<T>(
5950 output_data + Offset(ext_output_shape, out_b, out_h, out_w,
5951 output_depth - right_d_padding),
5952 pad_value, right_d_padding);
5953 }
5954 }
5955 if (right_w_padding != 0) {
5956 TypedMemset<T>(output_data + Offset(ext_output_shape, out_b, out_h,
5957 output_width - right_w_padding, 0),
5958 pad_value, right_w_padding * output_depth);
5959 }
5960 }
5961 if (right_h_padding != 0) {
5962 TypedMemset<T>(
5963 output_data + Offset(ext_output_shape, out_b,
5964 output_height - right_h_padding, 0, 0),
5965 pad_value, right_h_padding * output_width * output_depth);
5966 }
5967 }
5968 if (right_b_padding != 0) {
5969 TypedMemset<T>(
5970 output_data +
5971 Offset(ext_output_shape, output_batch - right_b_padding, 0, 0, 0),
5972 pad_value,
5973 right_b_padding * output_height * output_width * output_depth);
5974 }
5975 }
5976
5977 template <typename T, typename P>
5978 inline void Pad(const tflite::PadParams& op_params,
5979 const RuntimeShape& input_shape, const T* input_data,
5980 const P* pad_value_ptr, const RuntimeShape& output_shape,
5981 T* output_data) {
5982 PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
5983 output_data);
5984 }
5985
5986 // The second (pad-value) input can be int32 when, say, the first is uint8.
5987 template <typename T>
5988 inline void Pad(const tflite::PadParams& op_params,
5989 const RuntimeShape& input_shape, const T* input_data,
5990 const int32* pad_value_ptr, const RuntimeShape& output_shape,
5991 T* output_data) {
5992 const T converted_pad_value = static_cast<T>(*pad_value_ptr);
5993 PadImpl(op_params, input_shape, input_data, &converted_pad_value,
5994 output_shape, output_data);
5995 }
5996
5997 // This version avoids conflicting template matching.
5998 template <>
5999 inline void Pad(const tflite::PadParams& op_params,
6000 const RuntimeShape& input_shape, const int32* input_data,
6001 const int32* pad_value_ptr, const RuntimeShape& output_shape,
6002 int32* output_data) {
6003 PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
6004 output_data);
6005 }
6006
6007 // TODO(b/117643175): Optimize. (This is an introductory copy of standard Pad.)
6008 //
6009 // This pad requires that (a) left and right paddings are in the 4D patterns
6010 // {0, h_pad, w_pad, 0}, and (b) memset can be used: *pad_value_ptr == 0 and/or
6011 // T is uint8.
6012 //
6013 // There are two versions of pad: Pad and PadV2. In PadV2 there is a second
6014 // scalar input that provides the padding value. Therefore pad_value_ptr can be
6015 // equivalent to a simple input1_data. For Pad, it should point to a zero
6016 // value.
6017 //
6018 // Note that two typenames are required, so that the T=P=int32 case is a
6019 // specialization distinct from the generic-T case with P=int32.
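// Illustrative sketch (hypothetical shapes, not part of the contract above):
// for an NHWC input of shape [1, 4, 4, 3] with left_padding = {0, 1, 2, 0}
// and right_padding = {0, 1, 2, 0}, the output shape is [1, 6, 8, 3].
// Conceptually, each interior output row is a left memset, a memcpy of one
// input row, and a right memset; the implementation below merges adjacent
// memset calls.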
6020 template <typename T, typename P>
6021 inline void PadImageStyleMemset(const tflite::PadParams& op_params,
6022 const RuntimeShape& input_shape,
6023 const T* input_data, const P* pad_value_ptr,
6024 const RuntimeShape& output_shape,
6025 T* output_data) {
6026 gemmlowp::ScopedProfilingLabel label("PadImageStyle");
6027 const RuntimeShape ext_input_shape =
6028 RuntimeShape::ExtendedShape(4, input_shape);
6029 const RuntimeShape ext_output_shape =
6030 RuntimeShape::ExtendedShape(4, output_shape);
6031 TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
6032 TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
6033
6034 // Pad kernels are limited to max 4 dimensions. Copy inputs so we can pad them
6035 // to 4 dims (yes, we are "padding the padding").
6036 std::vector<int> left_padding_copy(4, 0);
6037 const int left_padding_extend = 4 - op_params.left_padding_count;
6038 for (int i = 0; i < op_params.left_padding_count; ++i) {
6039 left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
6040 }
6041 std::vector<int> right_padding_copy(4, 0);
6042 const int right_padding_extend = 4 - op_params.right_padding_count;
6043 for (int i = 0; i < op_params.right_padding_count; ++i) {
6044 right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
6045 }
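// For example (illustrative values): left_padding_count = 3 with
// left_padding = {2, 3, 0} becomes left_padding_copy = {0, 2, 3, 0}; the
// missing leading (batch) dimension gets zero padding.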
6046 // The following padding restrictions are contractual requirements, and
6047 // embody what it means for a padding op to be "image-style".
6048 TFLITE_DCHECK_EQ(left_padding_copy[0], 0);
6049 TFLITE_DCHECK_EQ(left_padding_copy[3], 0);
6050 TFLITE_DCHECK_EQ(right_padding_copy[0], 0);
6051 TFLITE_DCHECK_EQ(right_padding_copy[3], 0);
6052
6053 const int batch = MatchingDim(ext_input_shape, 0, ext_output_shape, 0);
6054 const int output_height = ext_output_shape.Dims(1);
6055 const int output_width = ext_output_shape.Dims(2);
6056 const int input_height = ext_input_shape.Dims(1);
6057 const int input_width = ext_input_shape.Dims(2);
6058 const int depth = MatchingDim(ext_input_shape, 3, ext_output_shape, 3);
6059
6060 const int left_h_padding = left_padding_copy[1];
6061 const int left_w_padding = left_padding_copy[2];
6062 const int right_h_padding = right_padding_copy[1];
6063 const int right_w_padding = right_padding_copy[2];
6064
6065 TFLITE_DCHECK_EQ(output_height,
6066 input_height + left_h_padding + right_h_padding);
6067 TFLITE_DCHECK_EQ(output_width,
6068 input_width + left_w_padding + right_w_padding);
6069
6070 const T pad_value = *pad_value_ptr;
6071 const int top_block_size = left_h_padding * output_width * depth;
6072 const size_t num_top_block_bytes = top_block_size * sizeof(T);
6073 const int bottom_block_size = right_h_padding * output_width * depth;
6074 const size_t num_bottom_block_bytes = bottom_block_size * sizeof(T);
6075 const int left_blocks_size = left_w_padding * depth;
6076 const size_t num_left_block_bytes = left_blocks_size * sizeof(T);
6077 const int right_blocks_size = right_w_padding * depth;
6078 const size_t num_right_block_bytes = right_blocks_size * sizeof(T);
6079 const int inner_line_size = input_width * depth;
6080 const size_t num_inner_line_bytes = inner_line_size * sizeof(T);
6081
6082 if (input_height == 0) {
6083 memset(output_data, pad_value,
6084 num_top_block_bytes + num_bottom_block_bytes);
6085 } else {
6086 for (int i = 0; i < batch; ++i) {
6087 // For each image in the batch, apply the top padding, then iterate
6088 // through rows, then apply the bottom padding.
6089 //
6090 // By unwinding one iteration, we can combine the first left-margin
6091 // padding with the top padding, and the last right-margin padding with
6092 // the bottom padding.
6093 memset(output_data, pad_value,
6094 num_top_block_bytes + num_left_block_bytes);
6095 output_data += top_block_size + left_blocks_size;
6096 memcpy(output_data, input_data, num_inner_line_bytes);
6097 input_data += inner_line_size;
6098 output_data += inner_line_size;
6099 // One iteration unwound.
6100 // Unwinding this loop affords the opportunity to reorder the loop work
6101 // and hence combine memset() calls.
6102 //
6103 // Before unwinding:
6104 // for (int j = 0; j < input_height; ++j) {
6105 // // Pad on left, copy central data, pad on right.
6106 // memset(output_data, pad_value, num_left_block_bytes);
6107 // output_data += left_blocks_size;
6108 // memcpy(output_data, input_data, num_inner_line_bytes);
6109 // input_data += inner_line_size;
6110 // output_data += inner_line_size;
6111 // memset(output_data, pad_value, num_right_block_bytes);
6112 // output_data += right_blocks_size;
6113 // }
6114 for (int j = 1; j < input_height; ++j) {
6115 memset(output_data, pad_value,
6116 num_right_block_bytes + num_left_block_bytes);
6117 output_data += right_blocks_size + left_blocks_size;
6118 memcpy(output_data, input_data, num_inner_line_bytes);
6119 input_data += inner_line_size;
6120 output_data += inner_line_size;
6121 }
6122 memset(output_data, pad_value,
6123 num_right_block_bytes + num_bottom_block_bytes);
6124 output_data += right_blocks_size + bottom_block_size;
6125 }
6126 }
6127 }
6128
6129 template <typename T, typename P>
6130 inline void PadImageStyle(const tflite::PadParams& op_params,
6131 const RuntimeShape& input_shape, const T* input_data,
6132 const P* pad_value_ptr,
6133 const RuntimeShape& output_shape, T* output_data) {
6134 TFLITE_ASSERT_FALSE;
6135 }
6136
6137 template <typename P>
6138 inline void PadImageStyle(const tflite::PadParams& op_params,
6139 const RuntimeShape& input_shape,
6140 const uint8* input_data, const P* pad_value_ptr,
6141 const RuntimeShape& output_shape,
6142 uint8* output_data) {
6143 PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
6144 output_shape, output_data);
6145 }
6146
6147 template <typename P>
6148 inline void PadImageStyle(const tflite::PadParams& op_params,
6149 const RuntimeShape& input_shape,
6150 const float* input_data, const P* pad_value_ptr,
6151 const RuntimeShape& output_shape,
6152 float* output_data) {
6153 const float converted_pad_value = static_cast<float>(*pad_value_ptr);
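// Descriptive note: memset writes a single repeated byte, so for float data
// it is only correct when the pad value is exactly 0.0f (an all-zero bit
// pattern); any other value goes through the generic PadImpl path.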
6154 if (converted_pad_value == 0.0f) {
6155 PadImageStyleMemset(op_params, input_shape, input_data, pad_value_ptr,
6156 output_shape, output_data);
6157 } else {
6158 PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
6159 output_data);
6160 }
6161 }
6162
6163 template <typename T>
6164 inline void Slice(const tflite::SliceParams& op_params,
6165 const RuntimeShape& input_shape, const T* input_data,
6166 const RuntimeShape& output_shape, T* output_data) {
6167 gemmlowp::ScopedProfilingLabel label("Slice");
6168 const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
6169 // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
6170 TFLITE_DCHECK_LE(op_params.begin_count, 4);
6171 TFLITE_DCHECK_LE(op_params.size_count, 4);
6172 const int begin_count = op_params.begin_count;
6173 const int size_count = op_params.size_count;
6174 // We front-pad the begin and size vectors.
6175 const int start_b = 4 - begin_count > 0 ? 0 : op_params.begin[0];
6176 const int stop_b = (4 - size_count > 0 || op_params.size[0] == -1)
6177 ? ext_shape.Dims(0) - start_b
6178 : start_b + op_params.size[0];
6179 const int start_h = begin_count < 3 ? 0 : op_params.begin[begin_count - 3];
6180 const int stop_h = (size_count < 3 || op_params.size[size_count - 3] == -1)
6181 ? ext_shape.Dims(1) - start_h
6182 : start_h + op_params.size[size_count - 3];
6183 const int start_w = begin_count < 2 ? 0 : op_params.begin[begin_count - 2];
6184 const int stop_w = (size_count < 2 || op_params.size[size_count - 2] == -1)
6185 ? ext_shape.Dims(2) - start_w
6186 : start_w + op_params.size[size_count - 2];
6187 const int start_d = begin_count < 1 ? 0 : op_params.begin[begin_count - 1];
6188 const int stop_d = (size_count < 1 || op_params.size[size_count - 1] == -1)
6189 ? ext_shape.Dims(3) - start_d
6190 : start_d + op_params.size[size_count - 1];
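// Illustrative sketch (hypothetical shapes): slicing a 2-D input of shape
// {6, 7} with begin = {1, 0} and size = {2, -1} is treated as the 4-D view
// {1, 1, 6, 7}; the computed bounds make the loops below copy rows 1 and 2 in
// full, producing an output of shape {2, 7}.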
6191
6192 T* out_ptr = output_data;
6193 for (int in_b = start_b; in_b < stop_b; ++in_b) {
6194 for (int in_h = start_h; in_h < stop_h; ++in_h) {
6195 for (int in_w = start_w; in_w < stop_w; ++in_w) {
6196 const int len = stop_d - start_d;
6197 memcpy(out_ptr,
6198 input_data + Offset(ext_shape, in_b, in_h, in_w, start_d),
6199 len * sizeof(T));
6200 out_ptr += len;
6201 }
6202 }
6203 }
6204 }
6205
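// Descriptive note: only input2_data[0] is read below; the result is the
// element-wise minimum of input1 and that scalar value. The Maximum overloads
// that follow mirror this behavior.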
6206 template <typename T>
6207 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
6208 const T* input2_data, const RuntimeShape& output_shape,
6209 T* output_data) {
6210 gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum");
6211 auto input1_map = MapAsVector(input1_data, input1_shape);
6212 auto output_map = MapAsVector(output_data, output_shape);
6213 auto min_value = input2_data[0];
6214 output_map.array() = input1_map.array().min(min_value);
6215 }
6216
6217 // Convenience version that allows, for example, generated-code calls to be
6218 // the same as other binary ops.
6219 template <typename T>
6220 inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
6221 const RuntimeShape&, const T* input2_data,
6222 const RuntimeShape& output_shape, T* output_data) {
6223 // Drop shape of second input: not needed.
6224 Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
6225 }
6226
6227 template <typename T>
6228 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
6229 const T* input2_data, const RuntimeShape& output_shape,
6230 T* output_data) {
6231 gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum");
6232 auto input1_map = MapAsVector(input1_data, input1_shape);
6233 auto output_map = MapAsVector(output_data, output_shape);
6234 auto max_value = input2_data[0];
6235 output_map.array() = input1_map.array().max(max_value);
6236 }
6237
6238 // Convenience version that allows, for example, generated-code calls to be
6239 // the same as other binary ops.
6240 template <typename T>
6241 inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
6242 const RuntimeShape&, const T* input2_data,
6243 const RuntimeShape& output_shape, T* output_data) {
6244 // Drop shape of second input: not needed.
6245 Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
6246 }
6247
6248 template <typename T>
6249 void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
6250 const RuntimeShape& input_shape, const T* input_data,
6251 const RuntimeShape& filter_shape,
6252 const RuntimeShape& output_shape, T* im2col_data) {
6253 gemmlowp::ScopedProfilingLabel label("TransposeIm2col");
6254 const int stride_width = params.stride_width;
6255 const int stride_height = params.stride_height;
6256 const int pad_width = params.padding_values.width;
6257 const int pad_height = params.padding_values.height;
6258 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
6259 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
6260 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
6261 TFLITE_DCHECK(im2col_data);
6262
6263 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
6264 const int input_height = input_shape.Dims(1);
6265 const int input_width = input_shape.Dims(2);
6266 const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
6267 const int filter_height = filter_shape.Dims(1);
6268 const int filter_width = filter_shape.Dims(2);
6269 const int output_height = output_shape.Dims(1);
6270 const int output_width = output_shape.Dims(2);
6271 MatchingDim(output_shape, 3, filter_shape, 0); // output_depth
6272
6273 // Construct the MxN sized im2col matrix.
6274 // The rows, M, are sub-ordered B x H x W
6275 const RuntimeShape row_shape({1, batches, output_height, output_width});
6276 // The columns, N, are sub-ordered Kh x Kw x Din
6277 const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
6278 // Use dimensions M and N to construct dims for indexing directly into im2col
6279 const RuntimeShape im2col_shape(
6280 {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
6281
6282 // Build the im2col matrix by looping through all the input pixels,
6283 // computing their influence on the output, rather than looping through all
6284 // the output pixels. We therefore must initialize the im2col array to zero.
6285 // This is potentially inefficient because we subsequently overwrite bytes
6286 // set here. However, in practice memset is very fast and the cost is negligible.
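// For example, with stride 1 a single input pixel can land in up to
// filter_height * filter_width output positions, so its depth slice is copied
// into up to that many rows of the im2col matrix.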
6287 memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
6288
6289 // Loop through the output batches
6290 for (int batch = 0; batch < batches; ++batch) {
6291 // Loop through input pixels one at a time.
6292 for (int in_y = 0; in_y < input_height; ++in_y) {
6293 for (int in_x = 0; in_x < input_width; ++in_x) {
6294 // Loop through the output pixels it will influence
6295 const int out_x_origin = (in_x * stride_width) - pad_width;
6296 const int out_y_origin = (in_y * stride_height) - pad_height;
6297 for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
6298 const int out_y = out_y_origin + filter_y;
6299 // Is output pixel within height bounds?
6300 if ((out_y >= 0) && (out_y < output_height)) {
6301 for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
6302 const int out_x = out_x_origin + filter_x;
6303 // Is output pixel within width bounds?
6304 if ((out_x >= 0) && (out_x < output_width)) {
6305 // Copy the input elements of this pixel
6306 T const* src =
6307 input_data + Offset(input_shape, batch, in_y, in_x, 0);
6308 int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
6309 int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
6310 T* dst = im2col_data +
6311 Offset(im2col_shape, 0, 0, row_offset, col_offset);
6312 memcpy(dst, src, input_depth * sizeof(T));
6313 }
6314 }
6315 }
6316 }
6317 }
6318 }
6319 }
6320 }
6321
6322 inline void TransposeConv(
6323 const ConvParams& params, const RuntimeShape& input_shape,
6324 const float* input_data, const RuntimeShape& filter_shape,
6325 const float* filter_data, const RuntimeShape& output_shape,
6326 float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
6327 gemmlowp::ScopedProfilingLabel label("TransposeConv");
6328 // The complexity of the reference implementation is input.flat_size() *
6329 // filter.flat_size() / in_channel.
6330 //
6331 // The complexity of the im2col->gemm implementation is batch *
6332 // output_height * output_width * (filter.flat_size() / out_channel)^2 *
6333 // out_channel.
6334 //
6335 // So if input.flat_size() * out_channel^2 is much smaller than
6336 // output.flat_size() * filter.flat_size() * in_channel, we should fall back
6337 // to the reference implementation.
6338 //
6339 // TODO(b/122331966): optimize the intuitive implementation.
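// Illustrative numbers (hypothetical shapes): input 1x4x4x256 (flat size
// 4096), filter 8x2x2x256 (flat size 8192), output 1x8x8x8 (flat size 512),
// so out_channel = 8 and in_channel = 256. Then 4096 * 8 * 8 * 4 = 1,048,576
// is much smaller than 8192 * 512 * 256 = 1,073,741,824, so the check below
// takes the reference path.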
6340 const int out_channel = output_shape.Dims(3);
6341 const int in_channel = input_shape.Dims(3);
6342 if ((input_shape.FlatSize() * out_channel * out_channel * 4) <
6343 (filter_shape.FlatSize() * output_shape.FlatSize() * in_channel)) {
6344 reference_ops::TransposeConv(params, input_shape, input_data, filter_shape,
6345 filter_data, output_shape, output_data,
6346 im2col_shape, im2col_data);
6347 return;
6348 }
6349 // Note: we could use transposed weights with a forward conv for unstrided
6350 // cases, but we are already getting good performance with this code as-is.
6351 TFLITE_DCHECK(im2col_data);
6352 TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
6353 output_shape, im2col_data);
6354
6355 const auto im2col_matrix_map =
6356 MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
6357 const auto filter_matrix_map =
6358 MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
6359 auto output_matrix_map =
6360 MapAsMatrixWithLastDimAsRows(output_data, output_shape);
6361
6362 Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
6363 }
6364
6365 // Integer-only version of ResizeNearestNeighbor. Since scales are represented
6366 // in fixed-point and thus approximated, |in_x| or |in_y| may differ from the
6367 // reference version. Debug checks are in place to test if this occurs.
6368 inline void ResizeNearestNeighbor(
6369 const tflite::ResizeNearestNeighborParams& op_params,
6370 const RuntimeShape& unextended_input_shape, const uint8* input_data,
6371 const RuntimeShape& output_size_shape, const int32* output_size_data,
6372 const RuntimeShape& unextended_output_shape, uint8* output_data) {
6373 // Align corners = true is not supported.
6374 TFLITE_DCHECK(!op_params.align_corners);
6375 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
6376 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
6377
6378 const RuntimeShape input_shape =
6379 RuntimeShape::ExtendedShape(4, unextended_input_shape);
6380 const RuntimeShape output_shape =
6381 RuntimeShape::ExtendedShape(4, unextended_output_shape);
6382
6383 int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
6384 int32 input_height = input_shape.Dims(1);
6385 int32 input_width = input_shape.Dims(2);
6386 int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
6387
6388 // The TensorFlow version of this op allows resizing along the width and
6389 // height axes only.
6390 TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
6391 int32 output_height = output_size_data[0];
6392 int32 output_width = output_size_data[1];
6393
6394 // Convert scales to fixed-point with 16 fractional bits. Integer division
6395 // truncates, so without the +1 the scale can round below the true ratio and
6396 // the computed source index can undershoot the reference floor(); the +1
6397 // biases the scale upward and also avoids a zero scale on extreme upscaling.
6398 int32 height_scale = (input_height << 16) / output_height + 1;
6399 int32 width_scale = (input_width << 16) / output_width + 1;
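// Worked example (illustrative values): input_height = 3, output_height = 9
// gives height_scale = (3 << 16) / 9 + 1 = 21846. At y = 3,
// (3 * 21846) >> 16 = 1, which matches floor(3 * 3 / 9) = 1; without the +1,
// (3 * 21845) >> 16 = 0 would undershoot.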
6400
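// Descriptive note: these are element strides into the NHWC input; col_offset
// steps across x, row_offset across y, and batch_offset across images.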
6401 const int col_offset = input_shape.Dims(3);
6402 const int row_offset = input_shape.Dims(2) * col_offset;
6403 const int batch_offset = input_shape.Dims(1) * row_offset;
6404
6405 const uint8* input_ptr = input_data;
6406 uint8* output_ptr = output_data;
6407 for (int b = 0; b < batches; ++b) {
6408 for (int y = 0; y < output_height; ++y) {
6409 int32 in_y = std::min((y * height_scale) >> 16, input_height - 1);
6410 // Check offset calculation is the same as the reference version. See
6411 // function comment for details. We check using a non-float version of:
6412 // TFLITE_DCHECK_EQ(in_y, std::floor(y * (static_cast<float>(input_height)
6413 // / output_height)));
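// Equivalently: in_y == floor(y * input_height / output_height) holds exactly
// when in_y * output_height <= y * input_height < (in_y + 1) * output_height,
// which is what the two integer checks below verify.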
6414 TFLITE_DCHECK_LT(y * input_height, output_height + in_y * output_height);
6415 TFLITE_DCHECK_GE(y * input_height, in_y * output_height);
6416 const uint8* y_input_ptr = input_ptr + in_y * row_offset;
6417 for (int x = 0; x < output_width; ++x) {
6418 int32 in_x = std::min((x * width_scale) >> 16, input_width - 1);
6419 // Check offset calculation is the same as the reference version. See
6420 // function comment for details. We check using a non-float version of:
6421 // TFLITE_DCHECK_EQ(in_x,
6422 // std::floor(x * (static_cast<float>(input_width)
6423 // / output_width)));
6424 TFLITE_DCHECK_LT(x * input_width, output_width + in_x * output_width);
6425 TFLITE_DCHECK_GE(x * input_width, in_x * output_width);
6426 const uint8* x_input_ptr = y_input_ptr + in_x * col_offset;
6427 memcpy(output_ptr, x_input_ptr, depth);
6428 output_ptr += depth;
6429 }
6430 }
6431 input_ptr += batch_offset;
6432 }
6433 }
6434
6435 } // namespace optimized_ops
6436 } // namespace tflite
6437
6438 #if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
6439 #undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
6440 #pragma GCC diagnostic pop
6441 #endif
6442
6443 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
6444