1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
17
18 #ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
19 #ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
20 #define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
21 #endif
22 #endif
23
24 #include <functional>
25
26 #include "fixedpoint/fixedpoint.h"
27 #include "tensorflow/lite/kernels/internal/cppmath.h"
28 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
29 #include "tensorflow/lite/kernels/internal/types.h"
30
31 namespace tflite {
32
33 constexpr int kReverseShift = -1;
34
35 inline void GetActivationMinMax(FusedActivationFunctionType ac,
36 float* output_activation_min,
37 float* output_activation_max) {
38 switch (ac) {
39 case FusedActivationFunctionType::kNone:
40 *output_activation_min = std::numeric_limits<float>::lowest();
41 *output_activation_max = std::numeric_limits<float>::max();
42 break;
43 case FusedActivationFunctionType::kRelu:
44 *output_activation_min = 0.f;
45 *output_activation_max = std::numeric_limits<float>::max();
46 break;
47 case FusedActivationFunctionType::kRelu1:
48 *output_activation_min = -1.f;
49 *output_activation_max = 1.f;
50 break;
51 case FusedActivationFunctionType::kRelu6:
52 *output_activation_min = 0.f;
53 *output_activation_max = 6.f;
54 break;
55 }
56 }
57
58 template <typename T>
59 inline T ActivationFunctionWithMinMax(T x, T output_activation_min,
60 T output_activation_max) {
61 using std::max;
62 using std::min;
63 return min(max(x, output_activation_min), output_activation_max);
64 }
65
66 // Legacy function, left for compatibility only.
67 template <FusedActivationFunctionType Ac>
68 float ActivationFunction(float x) {
69 float output_activation_min, output_activation_max;
70 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
71 return ActivationFunctionWithMinMax(x, output_activation_min,
72 output_activation_max);
73 }
74
75 inline void BiasAndClamp(float clamp_min, float clamp_max, int bias_size,
76 const float* bias_data, int array_size,
77 float* array_data) {
78 // Note: see b/132215220: in May 2019 we thought it would be OK to replace
79 // this with the Eigen one-liner:
80 // return (array.colwise() + bias).cwiseMax(clamp_min).cwiseMin(clamp_max).
81 // This turned out to severely regress performance: +4ms (i.e. 8%) on
82 // MobileNet v2 / 1.0 / 224. So we keep custom NEON code for now.
83 TFLITE_DCHECK_EQ((array_size % bias_size), 0);
84 #ifdef USE_NEON
85 float* array_ptr = array_data;
86 float* array_end_ptr = array_ptr + array_size;
87 const auto clamp_min_vec = vdupq_n_f32(clamp_min);
88 const auto clamp_max_vec = vdupq_n_f32(clamp_max);
89 for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
90 int i = 0;
91 for (; i <= bias_size - 16; i += 16) {
92 auto b0 = vld1q_f32(bias_data + i);
93 auto b1 = vld1q_f32(bias_data + i + 4);
94 auto b2 = vld1q_f32(bias_data + i + 8);
95 auto b3 = vld1q_f32(bias_data + i + 12);
96 auto a0 = vld1q_f32(array_ptr + i);
97 auto a1 = vld1q_f32(array_ptr + i + 4);
98 auto a2 = vld1q_f32(array_ptr + i + 8);
99 auto a3 = vld1q_f32(array_ptr + i + 12);
100 auto x0 = vaddq_f32(a0, b0);
101 auto x1 = vaddq_f32(a1, b1);
102 auto x2 = vaddq_f32(a2, b2);
103 auto x3 = vaddq_f32(a3, b3);
104 x0 = vmaxq_f32(clamp_min_vec, x0);
105 x1 = vmaxq_f32(clamp_min_vec, x1);
106 x2 = vmaxq_f32(clamp_min_vec, x2);
107 x3 = vmaxq_f32(clamp_min_vec, x3);
108 x0 = vminq_f32(clamp_max_vec, x0);
109 x1 = vminq_f32(clamp_max_vec, x1);
110 x2 = vminq_f32(clamp_max_vec, x2);
111 x3 = vminq_f32(clamp_max_vec, x3);
112 vst1q_f32(array_ptr + i, x0);
113 vst1q_f32(array_ptr + i + 4, x1);
114 vst1q_f32(array_ptr + i + 8, x2);
115 vst1q_f32(array_ptr + i + 12, x3);
116 }
117 for (; i <= bias_size - 4; i += 4) {
118 auto b = vld1q_f32(bias_data + i);
119 auto a = vld1q_f32(array_ptr + i);
120 auto x = vaddq_f32(a, b);
121 x = vmaxq_f32(clamp_min_vec, x);
122 x = vminq_f32(clamp_max_vec, x);
123 vst1q_f32(array_ptr + i, x);
124 }
125 for (; i < bias_size; i++) {
126 array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
127 clamp_min, clamp_max);
128 }
129 }
130 #else // not NEON
131 for (int array_offset = 0; array_offset < array_size;
132 array_offset += bias_size) {
133 for (int i = 0; i < bias_size; i++) {
134 array_data[array_offset + i] = ActivationFunctionWithMinMax(
135 array_data[array_offset + i] + bias_data[i], clamp_min, clamp_max);
136 }
137 }
138 #endif
139 }
140
141 inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(
142 int32_t x, int32_t quantized_multiplier, int left_shift) {
143 using gemmlowp::RoundingDivideByPOT;
144 using gemmlowp::SaturatingRoundingDoublingHighMul;
145 return RoundingDivideByPOT(
146 SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift);
147 }
148
149 inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(
150 int32_t x, int32_t quantized_multiplier, int left_shift) {
151 using gemmlowp::SaturatingRoundingDoublingHighMul;
152 return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
153 quantized_multiplier);
154 }
155
156 inline int32_t MultiplyByQuantizedMultiplier(int32_t x,
157 int32_t quantized_multiplier,
158 int shift) {
159 using gemmlowp::RoundingDivideByPOT;
160 using gemmlowp::SaturatingRoundingDoublingHighMul;
161 int left_shift = shift > 0 ? shift : 0;
162 int right_shift = shift > 0 ? 0 : -shift;
163 return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
164 x * (1 << left_shift), quantized_multiplier),
165 right_shift);
166 }
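// Illustrative note (not part of the upstream documentation): for a real
// multiplier m decomposed as m = q * 2^(shift - 31), with q roughly in
// [2^30, 2^31), this computes a rounded, saturating approximation of x * m.
// For example, m = 0.2 gives q = round(0.8 * 2^31) = 1717986918 and
// shift = -2, so MultiplyByQuantizedMultiplier(x, 1717986918, -2) ~= x * 0.2.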
167
168 inline int32_t MultiplyByQuantizedMultiplier(int64_t x,
169 int32_t quantized_multiplier,
170 int shift) {
171 // Inputs:
172 // - quantized_multiplier has fixed point at bit 31
173 // - shift is -31 to +7 (negative for right shift)
174 //
175 // Assumptions: The following input ranges are assumed
176 // - quantized_multiplier >= 0 (the usual range is (1<<30) to (1<<31)-1)
177 // - scaling is chosen so final scaled result fits in int32_t
178 // - input x is in the range -(1<<47) <= x < (1<<47)
179 assert(quantized_multiplier >= 0);
180 assert(shift >= -31 && shift < 8);
181 assert(x >= -(static_cast<int64_t>(1) << 47) &&
182 x < (static_cast<int64_t>(1) << 47));
183
184 int32_t reduced_multiplier = (quantized_multiplier < 0x7FFF0000)
185 ? ((quantized_multiplier + (1 << 15)) >> 16)
186 : 0x7FFF;
187 int total_shift = 15 - shift;
188 x = (x * (int64_t)reduced_multiplier) + ((int64_t)1 << (total_shift - 1));
189 int32_t result = x >> total_shift;
190 return result;
191 }
192
193 #ifdef USE_NEON
194 // Rounding is done with ARM's rounding shift right instruction (vrshlq_s32).
195 inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
196 int32x4x4_t input_val, int32_t quantized_multiplier, int shift) {
197 const int left_shift = std::max(shift, 0);
198 const int right_shift = std::min(shift, 0);
199 int32x4x4_t result;
200
201 int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
202 int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
203 int32x4_t right_shift_dup = vdupq_n_s32(right_shift);
204
205 result.val[0] =
206 vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup),
207 multiplier_dup),
208 right_shift_dup);
209
210 result.val[1] =
211 vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup),
212 multiplier_dup),
213 right_shift_dup);
214
215 result.val[2] =
216 vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup),
217 multiplier_dup),
218 right_shift_dup);
219
220 result.val[3] =
221 vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup),
222 multiplier_dup),
223 right_shift_dup);
224
225 return result;
226 }
227 #endif
228
229 template <typename T>
230 int CountLeadingZeros(T integer_input) {
231 static_assert(std::is_unsigned<T>::value,
232 "Only unsigned integer types handled.");
233 #if defined(__GNUC__)
234 return integer_input ? __builtin_clz(integer_input)
235 : std::numeric_limits<T>::digits;
236 #else
237 if (integer_input == 0) {
238 return std::numeric_limits<T>::digits;
239 }
240
241 const T one_in_leading_positive = static_cast<T>(1)
242 << (std::numeric_limits<T>::digits - 1);
243 int leading_zeros = 0;
244 while (integer_input < one_in_leading_positive) {
245 integer_input <<= 1;
246 ++leading_zeros;
247 }
248 return leading_zeros;
249 #endif
250 }
251
252 template <typename T>
253 inline int CountLeadingSignBits(T integer_input) {
254 static_assert(std::is_signed<T>::value, "Only signed integer types handled.");
255 #if defined(__GNUC__) && !defined(__clang__)
256 return integer_input ? __builtin_clrsb(integer_input)
257 : std::numeric_limits<T>::digits;
258 #else
259 using U = typename std::make_unsigned<T>::type;
260 return integer_input >= 0
261 ? CountLeadingZeros(static_cast<U>(integer_input)) - 1
262 : integer_input != std::numeric_limits<T>::min()
263 ? CountLeadingZeros(2 * static_cast<U>(-integer_input) - 1)
264 : 0;
265 #endif
266 }
267
268 // Use "count leading zeros" helper functions to do a fast Floor(log_2(x)).
269 template <typename Integer>
270 inline Integer FloorLog2(Integer n) {
271 static_assert(std::is_integral<Integer>::value, "");
272 static_assert(std::is_signed<Integer>::value, "");
273 static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, "");
274 TFLITE_CHECK_GT(n, 0);
275 if (sizeof(Integer) == 4) {
276 return 30 - CountLeadingSignBits(n);
277 } else {
278 return 62 - CountLeadingSignBits(n);
279 }
280 }
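// Illustrative examples (not exhaustive): FloorLog2(1) == 0, FloorLog2(8) == 3,
// and FloorLog2(9) == 3; the argument must be strictly positive.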
281
282 // The size of the LUT depends on the type of input. For int8 inputs a simple
283 // 256-entry LUT is used. For int16 inputs the high 9 bits are used for
284 // indexing and the 7 remaining bits are used for interpolation. We thus use a
285 // 513-entry LUT for the int16 case: 512 entries for the 9-bit indexing and 1
286 // extra entry to interpolate the last value.
287 template <typename LutInT>
288 constexpr int lut_size() {
289 static_assert(std::is_same<LutInT, int8_t>::value ||
290 std::is_same<LutInT, int16_t>::value,
291 "Only LUTs with int8 or int16 inputs are supported.");
292 return std::is_same<LutInT, int8_t>::value ? 256 : 513;
293 }
294
295 // Generate a LUT for 'func' which can be used to approximate functions like
296 // exp, log, ...
297 //
298 // - func: the function to build the LUT for (e.g. exp(x))
299 // - input_min, input_max: range of the func inputs
300 // - output_min, output_max: range of the func outputs
301 // - lut: pointer to the LUT to fill; it must be of size
302 //   lut_size<LutInT>()
303 template <typename FloatT, typename LutInT, typename LutOutT>
304 inline void gen_lut(FloatT (*func)(FloatT), FloatT input_min, FloatT input_max,
305 FloatT output_min, FloatT output_max, LutOutT* lut) {
306 static_assert(std::is_same<LutInT, int8_t>::value ||
307 std::is_same<LutInT, int16_t>::value,
308 "Only LUTs with int8 or int16 inputs are supported.");
309 static_assert(std::is_same<LutOutT, int8_t>::value ||
310 std::is_same<LutOutT, int16_t>::value,
311 "Only LUTs with int8 or int16 outputs are supported.");
312 static_assert(std::is_floating_point<FloatT>::value,
313 "FloatT must be a floating-point type.");
314
315 const int nb_steps = std::is_same<LutInT, int8_t>::value ? 256 : 512;
316 const FloatT step = (input_max - input_min) / nb_steps;
317 const FloatT half_step = step / 2;
318 const FloatT output_scaling_inv =
319 static_cast<FloatT>(std::numeric_limits<LutOutT>::max() -
320 std::numeric_limits<LutOutT>::min() + 1) /
321 (output_max - output_min);
322 const FloatT table_min =
323 static_cast<FloatT>(std::numeric_limits<LutOutT>::min());
324 const FloatT table_max =
325 static_cast<FloatT>(std::numeric_limits<LutOutT>::max());
326
327 for (int i = 0; i < nb_steps; i++) {
328 const FloatT val = func(input_min + i * step);
329 const FloatT val_midpoint = func(input_min + i * step + half_step);
330 const FloatT val_next = func(input_min + (i + 1) * step);
331
332 const FloatT sample_val = TfLiteRound(val * output_scaling_inv);
333 const FloatT midpoint_interp_val =
334 TfLiteRound((val_next * output_scaling_inv +
335 TfLiteRound(val * output_scaling_inv)) /
336 2);
337 const FloatT midpoint_val = TfLiteRound(val_midpoint * output_scaling_inv);
338 const FloatT midpoint_err = midpoint_interp_val - midpoint_val;
339 const FloatT bias = TfLiteRound(midpoint_err / 2);
340
341 lut[i] = static_cast<LutOutT>(std::min<FloatT>(
342 std::max<FloatT>(sample_val - bias, table_min), table_max));
343 }
344
345 const bool with_extra_interpolation_value =
346 std::is_same<LutInT, int16_t>::value;
347 if (with_extra_interpolation_value) {
348 lut[nb_steps] = static_cast<LutOutT>(std::min<FloatT>(
349 std::max<FloatT>(TfLiteRound(func(input_max) * output_scaling_inv),
350 table_min),
351 table_max));
352 }
353 }
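// Illustrative usage (a sketch, not upstream documentation): building an int16
// LUT for tanh over the input range [-3, 3] and output range [-1, 1], then
// querying it with the lut_lookup() overloads declared below:
//
//   int16_t tanh_lut[513];
//   gen_lut<double, int16_t, int16_t>(std::tanh, -3.0, 3.0, -1.0, 1.0,
//                                     tanh_lut);
//   // 'x' must already be rescaled so that the full int16 range maps onto
//   // [-3, 3] before the lookup.
//   int16_t y = lut_lookup(x, tanh_lut);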
354
355 // LUT must have 513 values
356 template <typename LutOutT>
357 inline LutOutT lut_lookup_with_interpolation(int16_t value,
358 const LutOutT* lut) {
359 static_assert(std::is_same<LutOutT, int8_t>::value ||
360 std::is_same<LutOutT, int16_t>::value,
361 "Only LUTs with int8 or int16 outputs are supported.");
362 // 512 base values; the extra entry (lut[512]) is only used to calculate the slope
363 const uint16_t index = static_cast<uint16_t>(256 + (value >> 7));
364 assert(index < 512 && "LUT index out of range.");
365 const int16_t offset = value & 0x7f;
366
367 // Base and slope are Q0.x
368 const LutOutT base = lut[index];
369 const LutOutT slope = lut[index + 1] - lut[index];
370
371 // Q0.x * Q0.7 = Q0.(x + 7)
372 // Round and convert from Q0.(x + 7) to Q0.x
373 const int delta = (slope * offset + 64) >> 7;
374
375 // Q0.15 + Q0.15
376 return static_cast<LutOutT>(base + delta);
377 }
378
379 // int16_t -> int16_t table lookup with interpolation
380 // LUT must have 513 values
381 inline int16_t lut_lookup(int16_t value, const int16_t* lut) {
382 return lut_lookup_with_interpolation(value, lut);
383 }
384
385 // int16_t -> int8_t table lookup with interpolation
386 // LUT must have 513 values
387 inline int8_t lut_lookup(int16_t value, const int8_t* lut) {
388 return lut_lookup_with_interpolation(value, lut);
389 }
390
391 // int8_t -> int8_t table lookup without interpolation
392 // LUT must have 256 values
393 inline int8_t lut_lookup(int8_t value, const int8_t* lut) {
394 return lut[128 + value];
395 }
396
397 // int8_t -> int16_t table lookup without interpolation
398 // LUT must have 256 values
399 inline int16_t lut_lookup(int8_t value, const int16_t* lut) {
400 return lut[128 + value];
401 }
402
403 // Table of sigmoid(i/24) in 0.16 fixed-point format - 256 elements.
404
405 // We use a combined sigmoid and tanh look-up table, since
406 // tanh(x) = 2*sigmoid(2*x) - 1.
407 // Both functions are symmetric, so the LUT is only needed
408 // for the absolute value of the input.
409 static const uint16_t sigmoid_table_uint16[256] = {
410 32768, 33451, 34133, 34813, 35493, 36169, 36843, 37513, 38180, 38841, 39498,
411 40149, 40794, 41432, 42064, 42688, 43304, 43912, 44511, 45102, 45683, 46255,
412 46817, 47369, 47911, 48443, 48964, 49475, 49975, 50464, 50942, 51409, 51865,
413 52311, 52745, 53169, 53581, 53983, 54374, 54755, 55125, 55485, 55834, 56174,
414 56503, 56823, 57133, 57433, 57724, 58007, 58280, 58544, 58800, 59048, 59288,
415 59519, 59743, 59959, 60168, 60370, 60565, 60753, 60935, 61110, 61279, 61441,
416 61599, 61750, 61896, 62036, 62172, 62302, 62428, 62549, 62666, 62778, 62886,
417 62990, 63090, 63186, 63279, 63368, 63454, 63536, 63615, 63691, 63765, 63835,
418 63903, 63968, 64030, 64090, 64148, 64204, 64257, 64308, 64357, 64405, 64450,
419 64494, 64536, 64576, 64614, 64652, 64687, 64721, 64754, 64786, 64816, 64845,
420 64873, 64900, 64926, 64950, 64974, 64997, 65019, 65039, 65060, 65079, 65097,
421 65115, 65132, 65149, 65164, 65179, 65194, 65208, 65221, 65234, 65246, 65258,
422 65269, 65280, 65291, 65301, 65310, 65319, 65328, 65337, 65345, 65352, 65360,
423 65367, 65374, 65381, 65387, 65393, 65399, 65404, 65410, 65415, 65420, 65425,
424 65429, 65433, 65438, 65442, 65445, 65449, 65453, 65456, 65459, 65462, 65465,
425 65468, 65471, 65474, 65476, 65479, 65481, 65483, 65485, 65488, 65489, 65491,
426 65493, 65495, 65497, 65498, 65500, 65501, 65503, 65504, 65505, 65507, 65508,
427 65509, 65510, 65511, 65512, 65513, 65514, 65515, 65516, 65517, 65517, 65518,
428 65519, 65520, 65520, 65521, 65522, 65522, 65523, 65523, 65524, 65524, 65525,
429 65525, 65526, 65526, 65526, 65527, 65527, 65528, 65528, 65528, 65529, 65529,
430 65529, 65529, 65530, 65530, 65530, 65530, 65531, 65531, 65531, 65531, 65531,
431 65532, 65532, 65532, 65532, 65532, 65532, 65533, 65533, 65533, 65533, 65533,
432 65533, 65533, 65533, 65534, 65534, 65534, 65534, 65534, 65534, 65534, 65534,
433 65534, 65534, 65535};
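// Illustrative reading of the table (an observation, not upstream
// documentation): entry i is approximately sigmoid(i / 24.0) * 65536,
// saturated to uint16_t. For instance, entry 0 is sigmoid(0) * 65536 = 32768
// and entry 24 is roughly sigmoid(1.0) * 65536 = 47911.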
434
435 // TODO(b/77858996): Add these to gemmlowp.
436 template <typename IntegerType>
437 IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
438 static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
439 return a;
440 }
441
442 template <>
443 inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
444 std::int64_t a64 = a;
445 std::int64_t b64 = b;
446 std::int64_t sum = a64 + b64;
447 return static_cast<std::int32_t>(std::min(
448 static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
449 std::max(
450 static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
451 sum)));
452 }
453
454 template <typename tRawType, int tIntegerBits>
455 gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
456 gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
457 gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
458 return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
459 SaturatingAddNonGemmlowp(a.raw(), b.raw()));
460 }
461
462 template <typename IntegerType>
463 IntegerType SaturatingSub(IntegerType a, IntegerType b) {
464 static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
465 return a;
466 }
467
468 template <>
469 inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
470 std::int32_t a32 = a;
471 std::int32_t b32 = b;
472 std::int32_t diff = a32 - b32;
473 return static_cast<std::int16_t>(
474 std::min(static_cast<int32_t>(32767),
475 std::max(static_cast<int32_t>(-32768), diff)));
476 }
477
478 template <>
479 inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
480 std::int64_t a64 = a;
481 std::int64_t b64 = b;
482 std::int64_t diff = a64 - b64;
483 return static_cast<std::int32_t>(std::min(
484 static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
485 std::max(
486 static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
487 diff)));
488 }
489
490 template <typename tRawType, int tIntegerBits>
491 gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
492 gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
493 gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
494 return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
495 SaturatingSub(a.raw(), b.raw()));
496 }
497 // End section to be moved to gemmlowp.
498
499 template <typename IntegerType>
500 IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
501 if (exponent == 0) {
502 return x;
503 }
504 using ScalarIntegerType =
505 typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
506 const IntegerType min =
507 gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
508 const IntegerType max =
509 gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
510 const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
511
512 const std::int32_t threshold =
513 ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
514 const IntegerType positive_mask =
515 gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
516 const IntegerType negative_mask =
517 gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
518
519 IntegerType result = gemmlowp::ShiftLeft(x, exponent);
520 result = gemmlowp::SelectUsingMask(positive_mask, max, result);
521 result = gemmlowp::SelectUsingMask(negative_mask, min, result);
522 return result;
523 }
524
525 // If we want to leave IntegerBits fixed, then multiplication
526 // by a power of two has to be saturating/rounding, not exact anymore.
527 template <typename tRawType, int tIntegerBits>
528 gemmlowp::FixedPoint<tRawType, tIntegerBits>
529 SaturatingRoundingMultiplyByPOTParam(
530 gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
531 return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
532 SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
533 }
534
535 // Convert int32_t multiplier to int16_t with rounding.
536 inline void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32_t,
537 int16_t* multiplier_int16_t) {
538 TFLITE_DCHECK_GE(multiplier_int32_t, 0);
539 static constexpr int32_t kRoundingOffset = 1 << 15;
540 if (multiplier_int32_t >=
541 std::numeric_limits<int32_t>::max() - kRoundingOffset) {
542 *multiplier_int16_t = std::numeric_limits<int16_t>::max();
543 return;
544 }
545 const int32_t result = (multiplier_int32_t + kRoundingOffset) >> 16;
546 TFLITE_DCHECK_LE(result << 16, multiplier_int32_t + kRoundingOffset);
547 TFLITE_DCHECK_GT(result << 16, multiplier_int32_t - kRoundingOffset);
548 *multiplier_int16_t = result;
549 TFLITE_DCHECK_EQ(*multiplier_int16_t, result);
550 }
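// Illustrative example (not upstream documentation): a Q0.31 multiplier of
// 1 << 30 (i.e. 0.5) becomes the Q0.15 value 1 << 14, and inputs within
// 1 << 15 of INT32_MAX saturate to INT16_MAX.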
551
552 // Minimum output bits to accommodate log of maximum input range. It actually
553 // does not matter if one considers, say, [-64,64] or [-64,64).
554 //
555 // For example, run this through Octave:
556 // [0:127; ...
557 // ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
558 // ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
559 constexpr int min_log_x_output_bits(int input_bits) {
560 return input_bits > 90 ? 7
561 : input_bits > 44 ? 6
562 : input_bits > 21 ? 5
563 : input_bits > 10 ? 4
564 : input_bits > 4 ? 3
565 : input_bits > 1 ? 2
566 : 1;
567 }
568
569 // Although currently the name of this function says that it cannot handle
570 // values less than 1, in practice it can handle values as low as 1/x_max, where
571 // x_max is the largest representable input. In other words, the output range
572 // is symmetric.
573 template <int OutputIntegerBits, int InputIntegerBits>
574 inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
575 log_x_for_x_greater_than_or_equal_to_1_impl(
576 gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
577 // assert(__builtin_clz(0u) >= std::numeric_limits<uint32_t>::digits - 1);
578 // assert(__builtin_clz(0u) <= std::numeric_limits<uint32_t>::digits);
579 using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
580 // The reason for accumulating the result with an extra bit of headroom is
581 // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
582 // recip_denom will otherwise introduce an error.
583 static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
584 using FixedPointAccum = gemmlowp::FixedPoint<int32_t, kAccumIntegerBits>;
585
586 const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
587 FixedPoint0, 1488522236, std::log(2.0));
588 const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
589 FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
590 const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
591 FixedPoint0, 1518500250, std::sqrt(0.5));
592 const FixedPoint0 one_quarter =
593 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
594
595 const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
596 FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
597 const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
598 FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
599 const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
600 FixedPoint0, 1057819769,
601 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
602 const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
603 FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
604
605 const FixedPointAccum shifted_quarter =
606 gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
607
608 // Reinterpret the input value as Q0.31, because we will figure out the
609 // required shift "ourselves" instead of using, say, Rescale.
610 FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
611 // z_a_pow_2 = input_integer_bits - z_a_headroom;
612 int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32_t>(z_a.raw()));
613 FixedPoint0 r_a_tmp =
614 SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
615 const int32_t r_a_raw =
616 SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
617 // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
618 // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
619 // InputIntegerBits - z_b_headroom - 0.25);
620 const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
621 FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
622 static_cast<int32_t>(InputIntegerBits - z_a_headroom_plus_1),
623 31 - kAccumIntegerBits)),
624 shifted_quarter);
625
626 // z_b is treated like z_a, but premultiplying by sqrt(0.5).
627 FixedPoint0 z_b = z_a * sqrt_half;
628 int z_b_headroom = CountLeadingZeros(static_cast<uint32_t>(z_b.raw())) - 1;
629 const int32_t r_b_raw =
630 SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
631 const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
632 FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
633 static_cast<int32_t>(InputIntegerBits - z_b_headroom),
634 31 - kAccumIntegerBits)),
635 shifted_quarter);
636
637 const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
638 const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
639 std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
640
641 const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
642 FixedPoint0 q = r - sqrt_sqrt_half;
643 q = q + q;
644
645 const FixedPoint0 common_sq = q * q;
646 const FixedPoint0 num = q * r + q * common_sq * alpha_n;
647 const FixedPoint0 denom_minus_one_0 =
648 p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
649 const FixedPoint0 recip_denom =
650 one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
651
652 const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
653 return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
654 num_scaled * recip_denom);
655 }
656
657 template <int OutputIntegerBits, int InputIntegerBits>
658 inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
659 log_x_for_x_greater_than_or_equal_to_1(
660 gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
661 static_assert(
662 OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
663 "Output integer bits must be sufficient to accommodate logs of inputs.");
664 return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
665 InputIntegerBits>(
666 input_val);
667 }
668
669 inline int32_t GetReciprocal(int32_t x, int x_integer_digits,
670 int* num_bits_over_unit) {
671 int headroom_plus_one = CountLeadingZeros(static_cast<uint32_t>(x));
672 // This is the number of bits to the left of the binary point above 1.0.
673 // Consider x=1.25. In that case shifted_scale=0.8 and
674 // no later adjustment will be needed.
675 *num_bits_over_unit = x_integer_digits - headroom_plus_one;
676 const int32_t shifted_sum_minus_one =
677 static_cast<int32_t>((static_cast<uint32_t>(x) << headroom_plus_one) -
678 (static_cast<uint32_t>(1) << 31));
679
680 gemmlowp::FixedPoint<int32_t, 0> shifted_scale =
681 gemmlowp::one_over_one_plus_x_for_x_in_0_1(
682 gemmlowp::FixedPoint<int32_t, 0>::FromRaw(shifted_sum_minus_one));
683 return shifted_scale.raw();
684 }
685
686 inline void GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift,
687 int32_t* output_inv_sqrt,
688 int* output_shift) {
689 TFLITE_DCHECK_GE(input, 0);
690 if (input <= 1) {
691 // Handle the input value 1 separately to avoid overflow in that case
692 // in the general computation below (b/143972021). Also handle 0 as if it
693 // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
694 // but rare/unrealistic input value. We can expect both to occur in some
695 // incompletely trained models, but probably not in fully trained models.
696 *output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
697 *output_shift = 0;
698 return;
699 }
700 TFLITE_DCHECK_GT(input, 1);
701 *output_shift = 11;
702 while (input >= (1 << 29)) {
703 input /= 4;
704 ++*output_shift;
705 }
706 const unsigned max_left_shift_bits =
707 CountLeadingZeros(static_cast<uint32_t>(input)) - 1;
708 const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
709 const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
710 *output_shift -= left_shift_bit_pairs;
711 input <<= 2 * left_shift_bit_pairs;
712 TFLITE_DCHECK_GE(input, (1 << 27));
713 TFLITE_DCHECK_LT(input, (1 << 29));
714 using gemmlowp::FixedPoint;
715 using gemmlowp::Rescale;
716 using gemmlowp::SaturatingRoundingMultiplyByPOT;
717 // Using 3 integer bits gives us enough room for the internal arithmetic in
718 // this Newton-Raphson iteration.
719 using F3 = FixedPoint<int32_t, 3>;
720 using F0 = FixedPoint<int32_t, 0>;
721 const F3 fixedpoint_input = F3::FromRaw(input >> 1);
722 const F3 fixedpoint_half_input =
723 SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
724 const F3 fixedpoint_half_three =
725 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
726 // Newton-Raphson iteration
727 // Naive unoptimized starting guess: x = 1
728 F3 x = F3::One();
729 // Naive unoptimized number of iterations: 5
730 for (int i = 0; i < 5; i++) {
731 const F3 x3 = Rescale<3>(x * x * x);
732 x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
733 }
734 const F0 fixedpoint_half_sqrt_2 =
735 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
736 x = x * fixedpoint_half_sqrt_2;
737 *output_inv_sqrt = x.raw();
738 if (*output_shift < 0) {
739 *output_inv_sqrt <<= -*output_shift;
740 *output_shift = 0;
741 }
742 // Convert right shift (right is positive) to left shift.
743 *output_shift *= reverse_shift;
744 }
745
746 // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
747 // BROADCASTING.
748 //
749 // NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
750 // rectangular array of numbers.
751 //
752 // NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
753 // However, as Dims<N> is to be deprecated, this class exists as an adaptor
754 // to enable simple unoptimized implementations of element-wise broadcasting
755 // operations.
756 template <int N>
757 struct NdArrayDesc {
758 // The "extent" of each dimension. Indices along dimension d must be in the
759 // half-open interval [0, extents[d]).
760 int extents[N];
761
762 // The number of *elements* (not bytes) between consecutive indices of each
763 // dimension.
764 int strides[N];
765 };
766
767 // DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
768 // BROADCASTING.
769 //
770 // Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
771 inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
772 int i3) {
773 TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
774 TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
775 TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
776 TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
777 return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
778 i3 * desc.strides[3];
779 }
780
781 inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) {
782 return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
783 indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
784 indexes[4] * desc.strides[4];
785 }
786
787 inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) {
788 return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
789 indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
790 indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] +
791 indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7];
792 }
793
794 // Given the dimensions of the operands for an element-wise binary broadcast,
795 // adjusts them so that they can be directly iterated over with simple loops.
796 // Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
797 // 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
798 //
799 // This function assumes that the two input shapes are compatible up to
800 // broadcasting and the shorter one has already been prepended with 1s to be the
801 // same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
802 // shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
803 // Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
804 // (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
805 //
806 // When two shapes are compatible up to broadcasting, for each dimension d,
807 // the input extents are either equal, or one of them is 1.
808 //
809 // This function performs the following for each dimension d:
810 // - If the extents are equal, then do nothing since the loop that walks over
811 // both of the input arrays is correct.
812 // - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
813 // and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
814 // array0 to be referenced *at any index* in dimension d and still access the
815 // same slice.
816 template <int N>
817 inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
818 const Dims<N>& input1_dims,
819 NdArrayDesc<N>* desc0_out,
820 NdArrayDesc<N>* desc1_out) {
821 TFLITE_DCHECK(desc0_out != nullptr);
822 TFLITE_DCHECK(desc1_out != nullptr);
823
824 // Copy dims to desc.
825 for (int i = 0; i < N; ++i) {
826 desc0_out->extents[i] = input0_dims.sizes[i];
827 desc0_out->strides[i] = input0_dims.strides[i];
828 desc1_out->extents[i] = input1_dims.sizes[i];
829 desc1_out->strides[i] = input1_dims.strides[i];
830 }
831
832 // Walk over each dimension. If the extents are equal do nothing.
833 // Otherwise, set the desc with extent 1 to have extent equal to the other and
834 // stride 0.
835 for (int i = 0; i < N; ++i) {
836 const int extent0 = ArraySize(input0_dims, i);
837 const int extent1 = ArraySize(input1_dims, i);
838 if (extent0 != extent1) {
839 if (extent0 == 1) {
840 desc0_out->strides[i] = 0;
841 desc0_out->extents[i] = extent1;
842 } else {
843 TFLITE_DCHECK_EQ(extent1, 1);
844 desc1_out->strides[i] = 0;
845 desc1_out->extents[i] = extent0;
846 }
847 }
848 }
849 }
850
851 // Copies dims to desc, calculating strides.
852 template <int N>
853 inline void CopyDimsToDesc(const RuntimeShape& input_shape,
854 NdArrayDesc<N>* desc_out) {
855 int desc_stride = 1;
856 for (int i = N - 1; i >= 0; --i) {
857 desc_out->extents[i] = input_shape.Dims(i);
858 desc_out->strides[i] = desc_stride;
859 desc_stride *= input_shape.Dims(i);
860 }
861 }
862
863 template <int N>
864 inline void NdArrayDescsForElementwiseBroadcast(
865 const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
866 NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
867 TFLITE_DCHECK(desc0_out != nullptr);
868 TFLITE_DCHECK(desc1_out != nullptr);
869
870 auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
871 auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
872
873 // Copy dims to desc, calculating strides.
874 CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
875 CopyDimsToDesc<N>(extended_input1_shape, desc1_out);
876
877 // Walk over each dimension. If the extents are equal do nothing.
878 // Otherwise, set the desc with extent 1 to have extent equal to the other and
879 // stride 0.
880 for (int i = 0; i < N; ++i) {
881 const int extent0 = extended_input0_shape.Dims(i);
882 const int extent1 = extended_input1_shape.Dims(i);
883 if (extent0 != extent1) {
884 if (extent0 == 1) {
885 desc0_out->strides[i] = 0;
886 desc0_out->extents[i] = extent1;
887 } else {
888 TFLITE_DCHECK_EQ(extent1, 1);
889 desc1_out->strides[i] = 0;
890 desc1_out->extents[i] = extent0;
891 }
892 }
893 }
894 }
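// Illustrative usage (a sketch; the shape and variable names are
// hypothetical): broadcasting a bias of shape (1, 1, 1, 64) against an
// activation of shape (1, 16, 16, 64):
//
//   NdArrayDesc<4> desc_act, desc_bias;
//   NdArrayDescsForElementwiseBroadcast(act_shape, bias_shape, &desc_act,
//                                       &desc_bias);
//   // desc_bias now has strides[1] == strides[2] == 0, so
//   // SubscriptToIndex(desc_bias, b, y, x, c) returns the same element for
//   // every (y, x), which is exactly the broadcast behaviour.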
895
896 template <int N>
897 inline void NdArrayDescsForElementwiseBroadcast(
898 const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
899 const RuntimeShape& input2_shape, NdArrayDesc<N>* desc0_out,
900 NdArrayDesc<N>* desc1_out, NdArrayDesc<N>* desc2_out) {
901 TFLITE_DCHECK(desc0_out != nullptr);
902 TFLITE_DCHECK(desc1_out != nullptr);
903 TFLITE_DCHECK(desc2_out != nullptr);
904
905 auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
906 auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
907 auto extended_input2_shape = RuntimeShape::ExtendedShape(N, input2_shape);
908
909 // Copy dims to desc, calculating strides.
910 CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
911 CopyDimsToDesc<N>(extended_input1_shape, desc1_out);
912 CopyDimsToDesc<N>(extended_input2_shape, desc2_out);
913
914 // Walk over each dimension. If the extents are equal do nothing.
915 // Otherwise, set the desc with extent 1 to have extent equal to the other and
916 // stride 0.
917 for (int i = 0; i < N; ++i) {
918 const int extent0 = extended_input0_shape.Dims(i);
919 const int extent1 = extended_input1_shape.Dims(i);
920 const int extent2 = extended_input2_shape.Dims(i);
921
922 int extent = extent0;
923 if (extent1 != 1) extent = extent1;
924 if (extent2 != 1) extent = extent2;
925
926 TFLITE_DCHECK(extent0 == 1 || extent0 == extent);
927 TFLITE_DCHECK(extent1 == 1 || extent1 == extent);
928 TFLITE_DCHECK(extent2 == 1 || extent2 == extent);
929
930 if (!(extent0 == extent1 && extent1 == extent2)) {
931 if (extent0 == 1) {
932 desc0_out->strides[i] = 0;
933 desc0_out->extents[i] = extent;
934 }
935 if (extent1 == 1) {
936 desc1_out->strides[i] = 0;
937 desc1_out->extents[i] = extent;
938 }
939 if (extent2 == 1) {
940 desc2_out->strides[i] = 0;
941 desc2_out->extents[i] = extent;
942 }
943 }
944 }
945 }
946
947 // Detailed implementation of NDOpsHelper; the indexes must be a zero-initialized
948 // array. This implementation is equivalent to N nested loops. For example, if
949 // N=4, it can be re-written as:
950 // for (int b = 0; b < output.extents[0]; ++b) {
951 // for (int y = 0; y < output.extents[1]; ++y) {
952 // for (int x = 0; x < output.extents[2]; ++x) {
953 // for (int c = 0; c < output.extents[3]; ++c) {
954 // calc({b,y,x,c});
955 // }
956 // }
957 // }
958 // }
959 template <int N, int DIM, typename Calc>
960 typename std::enable_if<DIM != N - 1, void>::type NDOpsHelperImpl(
961 const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
962 for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
963 NDOpsHelperImpl<N, DIM + 1, Calc>(output, calc, indexes);
964 }
965 }
966
967 template <int N, int DIM, typename Calc>
968 typename std::enable_if<DIM == N - 1, void>::type NDOpsHelperImpl(
969 const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
970 for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
971 calc(indexes);
972 }
973 }
974
975 // Execute the calc function in the innermost iteration based on the shape of
976 // the output. The calc function should take a single argument of type int[N].
977 template <int N, typename Calc>
978 inline void NDOpsHelper(const NdArrayDesc<N>& output, const Calc& calc) {
979 int indexes[N] = {0};
980 NDOpsHelperImpl<N, 0, Calc>(output, calc, indexes);
981 }
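// Illustrative usage (a sketch with hypothetical buffers): filling a 4-D
// output where each element is computed from two broadcast inputs.
//
//   NdArrayDesc<4> output_desc, desc0, desc1;
//   // ... descs initialized via CopyDimsToDesc() and
//   // NdArrayDescsForElementwiseBroadcast() ...
//   auto calc = [&](int indexes[4]) {
//     output_data[SubscriptToIndex(output_desc, indexes[0], indexes[1],
//                                  indexes[2], indexes[3])] =
//         input0_data[SubscriptToIndex(desc0, indexes[0], indexes[1],
//                                      indexes[2], indexes[3])] +
//         input1_data[SubscriptToIndex(desc1, indexes[0], indexes[1],
//                                      indexes[2], indexes[3])];
//   };
//   NDOpsHelper<4>(output_desc, calc);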
982 // Copied from gemmlowp::RoundDown when we dropped direct dependency on
983 // gemmlowp.
984 //
985 // Returns the runtime argument rounded down to the nearest multiple of
986 // the fixed Modulus.
987 template <unsigned Modulus, typename Integer>
988 Integer RoundDown(Integer i) {
989 return i - (i % Modulus);
990 }
991
992 // Copied from gemmlowp::RoundUp when we dropped direct dependency on
993 // gemmlowp.
994 //
995 // Returns the runtime argument rounded up to the nearest multiple of
996 // the fixed Modulus.
997 template <unsigned Modulus, typename Integer>
998 Integer RoundUp(Integer i) {
999 return RoundDown<Modulus>(i + Modulus - 1);
1000 }
1001
1002 // Copied from gemmlowp::CeilQuotient when we dropped direct dependency on
1003 // gemmlowp.
1004 //
1005 // Returns the quotient a / b rounded up ('ceil') to the nearest integer.
1006 template <typename Integer>
1007 Integer CeilQuotient(Integer a, Integer b) {
1008 return (a + b - 1) / b;
1009 }
1010
1011 // This function is a copy of gemmlowp::HowManyThreads, copied when we dropped
1012 // the direct dependency of internal/optimized/ on gemmlowp.
1013 //
1014 // It computes a reasonable number of threads to use for a GEMM of shape
1015 // (rows, cols, depth).
1016 //
1017 // TODO(b/131910176): get rid of this function by switching each call site
1018 // to its own more sensible logic for its own workload.
1019 template <int KernelRows>
1020 inline int LegacyHowManyThreads(int max_num_threads, int rows, int cols,
1021 int depth) {
1022 // Early-exit in the default case where multi-threading is disabled.
1023 if (max_num_threads == 1) {
1024 return 1;
1025 }
1026
1027 // Ensure that each thread has KernelRows rows to process, if at all possible.
1028 int thread_count = std::min(max_num_threads, rows / KernelRows);
1029
1030 // Limit the number of threads according to the overall size of the problem.
1031 if (thread_count > 1) {
1032 // Empirically determined value.
1033 static constexpr std::uint64_t min_cubic_size_per_thread = 64 * 1024;
1034
1035 // We can only multiply two out of three sizes without risking overflow
1036 const std::uint64_t cubic_size =
1037 std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth);
1038
1039 thread_count = std::min(
1040 thread_count, static_cast<int>(cubic_size / min_cubic_size_per_thread));
1041 }
1042
1043 if (thread_count < 1) {
1044 thread_count = 1;
1045 }
1046
1047 assert(thread_count > 0 && thread_count <= max_num_threads);
1048 return thread_count;
1049 }
1050
1051 template <typename T>
1052 void optimized_ops_preload_l1_stream(const T* ptr) {
1053 #ifdef __GNUC__
1054 // builtin offered by GCC-compatible compilers including clang
1055 __builtin_prefetch(ptr, /* 0 means read */ 0, /* 0 means no locality */ 0);
1056 #else
1057 (void)ptr;
1058 #endif
1059 }
1060
1061 template <typename T>
1062 void optimized_ops_preload_l1_keep(const T* ptr) {
1063 #ifdef __GNUC__
1064 // builtin offered by GCC-compatible compilers including clang
1065 __builtin_prefetch(ptr, /* 0 means read */ 0, /* 3 means high locality */ 3);
1066 #else
1067 (void)ptr;
1068 #endif
1069 }
1070
1071 template <typename T>
1072 void optimized_ops_prefetch_write_l1_keep(const T* ptr) {
1073 #ifdef __GNUC__
1074 // builtin offered by GCC-compatible compilers including clang
1075 __builtin_prefetch(ptr, /* 1 means write */ 1, /* 3 means high locality */ 3);
1076 #else
1077 (void)ptr;
1078 #endif
1079 }
1080
1081 } // namespace tflite
1082
1083 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
1084