/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_

#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#endif
#endif

#include <functional>

#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

constexpr int kReverseShift = -1;

inline void GetActivationMinMax(FusedActivationFunctionType ac,
                                float* output_activation_min,
                                float* output_activation_max) {
  switch (ac) {
    case FusedActivationFunctionType::kNone:
      *output_activation_min = std::numeric_limits<float>::lowest();
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu:
      *output_activation_min = 0.f;
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu1:
      *output_activation_min = -1.f;
      *output_activation_max = 1.f;
      break;
    case FusedActivationFunctionType::kRelu6:
      *output_activation_min = 0.f;
      *output_activation_max = 6.f;
      break;
  }
}

template <typename T>
inline T ActivationFunctionWithMinMax(T x, T output_activation_min,
                                      T output_activation_max) {
  using std::max;
  using std::min;
  return min(max(x, output_activation_min), output_activation_max);
}
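
// Example (illustrative): a fused ReLU6 clamps into [0.f, 6.f]:
//   float act_min, act_max;
//   GetActivationMinMax(FusedActivationFunctionType::kRelu6, &act_min,
//                       &act_max);
//   float y = ActivationFunctionWithMinMax(7.5f, act_min, act_max);  // 6.f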

// Legacy function, left for compatibility only.
template <FusedActivationFunctionType Ac>
float ActivationFunction(float x) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  return ActivationFunctionWithMinMax(x, output_activation_min,
                                      output_activation_max);
}

inline void BiasAndClamp(float clamp_min, float clamp_max, int bias_size,
                         const float* bias_data, int array_size,
                         float* array_data) {
  // Note: see b/132215220: in May 2019 we thought it would be OK to replace
  // this with the Eigen one-liner:
  //   return (array.colwise() + bias).cwiseMax(clamp_min).cwiseMin(clamp_max).
  // This turned out to severely regress performance: +4ms (i.e. 8%) on
  // MobileNet v2 / 1.0 / 224. So we keep custom NEON code for now.
  TFLITE_DCHECK_EQ((array_size % bias_size), 0);
#ifdef USE_NEON
  float* array_ptr = array_data;
  float* array_end_ptr = array_ptr + array_size;
  const auto clamp_min_vec = vdupq_n_f32(clamp_min);
  const auto clamp_max_vec = vdupq_n_f32(clamp_max);
  for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
    int i = 0;
    for (; i <= bias_size - 16; i += 16) {
      auto b0 = vld1q_f32(bias_data + i);
      auto b1 = vld1q_f32(bias_data + i + 4);
      auto b2 = vld1q_f32(bias_data + i + 8);
      auto b3 = vld1q_f32(bias_data + i + 12);
      auto a0 = vld1q_f32(array_ptr + i);
      auto a1 = vld1q_f32(array_ptr + i + 4);
      auto a2 = vld1q_f32(array_ptr + i + 8);
      auto a3 = vld1q_f32(array_ptr + i + 12);
      auto x0 = vaddq_f32(a0, b0);
      auto x1 = vaddq_f32(a1, b1);
      auto x2 = vaddq_f32(a2, b2);
      auto x3 = vaddq_f32(a3, b3);
      x0 = vmaxq_f32(clamp_min_vec, x0);
      x1 = vmaxq_f32(clamp_min_vec, x1);
      x2 = vmaxq_f32(clamp_min_vec, x2);
      x3 = vmaxq_f32(clamp_min_vec, x3);
      x0 = vminq_f32(clamp_max_vec, x0);
      x1 = vminq_f32(clamp_max_vec, x1);
      x2 = vminq_f32(clamp_max_vec, x2);
      x3 = vminq_f32(clamp_max_vec, x3);
      vst1q_f32(array_ptr + i, x0);
      vst1q_f32(array_ptr + i + 4, x1);
      vst1q_f32(array_ptr + i + 8, x2);
      vst1q_f32(array_ptr + i + 12, x3);
    }
    for (; i <= bias_size - 4; i += 4) {
      auto b = vld1q_f32(bias_data + i);
      auto a = vld1q_f32(array_ptr + i);
      auto x = vaddq_f32(a, b);
      x = vmaxq_f32(clamp_min_vec, x);
      x = vminq_f32(clamp_max_vec, x);
      vst1q_f32(array_ptr + i, x);
    }
    for (; i < bias_size; i++) {
      array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
                                                  clamp_min, clamp_max);
    }
  }
#else  // not NEON
  for (int array_offset = 0; array_offset < array_size;
       array_offset += bias_size) {
    for (int i = 0; i < bias_size; i++) {
      array_data[array_offset + i] = ActivationFunctionWithMinMax(
          array_data[array_offset + i] + bias_data[i], clamp_min, clamp_max);
    }
  }
#endif
}

inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(
    int32_t x, int32_t quantized_multiplier, int left_shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift);
}

inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(
    int32_t x, int32_t quantized_multiplier, int left_shift) {
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
                                           quantized_multiplier);
}

inline int32_t MultiplyByQuantizedMultiplier(int32_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  int left_shift = shift > 0 ? shift : 0;
  int right_shift = shift > 0 ? 0 : -shift;
  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                 x * (1 << left_shift), quantized_multiplier),
                             right_shift);
}
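
// Worked example (illustrative): a real multiplier of 0.375 can be encoded as
// quantized_multiplier = 1610612736 (0.75 * 2^31) with shift = -1. Then
// MultiplyByQuantizedMultiplier(1000, 1610612736, -1) first computes the
// doubling high multiply round(1000 * 0.75) = 750, then applies a rounding
// right shift by 1, giving 375 = round(1000 * 0.375).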

inline int32_t MultiplyByQuantizedMultiplier(int64_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  // Inputs:
  // - quantized_multiplier has fixed point at bit 31
  // - shift is -31 to +7 (negative for right shift)
  //
  // Assumptions: The following input ranges are assumed
  // - quantized_multiplier >= 0  (the usual range is (1<<30) to (1<<31)-1)
  // - scaling is chosen so final scaled result fits in int32_t
  // - input x is in the range -(1<<47) <= x < (1<<47)
  assert(quantized_multiplier >= 0);
  assert(shift >= -31 && shift < 8);
  assert(x >= -(static_cast<int64_t>(1) << 47) &&
         x < (static_cast<int64_t>(1) << 47));

  int32_t reduced_multiplier = (quantized_multiplier < 0x7FFF0000)
                                   ? ((quantized_multiplier + (1 << 15)) >> 16)
                                   : 0x7FFF;
  int total_shift = 15 - shift;
  x = (x * (int64_t)reduced_multiplier) + ((int64_t)1 << (total_shift - 1));
  int32_t result = x >> total_shift;
  return result;
}

#ifdef USE_NEON
// Round uses ARM's rounding shift right.
inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
    int32x4x4_t input_val, int32_t quantized_multiplier, int shift) {
  const int left_shift = std::max(shift, 0);
  const int right_shift = std::min(shift, 0);
  int32x4x4_t result;

  int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
  int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
  int32x4_t right_shift_dup = vdupq_n_s32(right_shift);

  result.val[0] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[1] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[2] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[3] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  return result;
}
#endif

template <typename T>
int CountLeadingZeros(T integer_input) {
  static_assert(std::is_unsigned<T>::value,
                "Only unsigned integer types handled.");
#if defined(__GNUC__)
  return integer_input ? __builtin_clz(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  if (integer_input == 0) {
    return std::numeric_limits<T>::digits;
  }

  const T one_in_leading_positive = static_cast<T>(1)
                                    << (std::numeric_limits<T>::digits - 1);
  int leading_zeros = 0;
  while (integer_input < one_in_leading_positive) {
    integer_input <<= 1;
    ++leading_zeros;
  }
  return leading_zeros;
#endif
}

template <typename T>
inline int CountLeadingSignBits(T integer_input) {
  static_assert(std::is_signed<T>::value, "Only signed integer types handled.");
#if defined(__GNUC__) && !defined(__clang__)
  return integer_input ? __builtin_clrsb(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  using U = typename std::make_unsigned<T>::type;
  return integer_input >= 0
             ? CountLeadingZeros(static_cast<U>(integer_input)) - 1
             : integer_input != std::numeric_limits<T>::min()
                   ? CountLeadingZeros(2 * static_cast<U>(-integer_input) - 1)
                   : 0;
#endif
}

// Use "count leading zeros" helper functions to do a fast Floor(log_2(x)).
template <typename Integer>
inline Integer FloorLog2(Integer n) {
  static_assert(std::is_integral<Integer>::value, "");
  static_assert(std::is_signed<Integer>::value, "");
  static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, "");
  TFLITE_CHECK_GT(n, 0);
  if (sizeof(Integer) == 4) {
    return 30 - CountLeadingSignBits(n);
  } else {
    return 62 - CountLeadingSignBits(n);
  }
}
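
// Examples (illustrative): for int32_t inputs, CountLeadingSignBits(1) == 30,
// so FloorLog2(1) == 0; CountLeadingSignBits(16) == 26, so FloorLog2(16) == 4,
// and FloorLog2(17) == 4 as well, since the result is floor(log2(n)).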

// Generates an INT16 LUT for a function, e.g., tables for exp(x) and 1/(1+x)
// as used in softmax.
//   func - the function to build the LUT for (e.g. exp(x))
//   min, max - table limits
//   table - pointer to the output buffer
//   num - number of elements in the LUT
inline void gen_lut(double (*func)(double), double min, double max,
                    int16_t* table, const int num) {
  // The table holds `num` entries; the last entry is used only for slope
  // calculation by the interpolating lookup.
  double step = (max - min) / (num - 1);
  double half_step = step / 2.0;
  for (int i = 0; i < num - 1; i++) {
    double sample_val = TfLiteRound(func(min + i * step) * 32768.0);
    double midpoint_interp_val =
        TfLiteRound((func(min + (i + 1) * step) * 32768.0 +
                     TfLiteRound(func(min + i * step) * 32768.0)) /
                    2.0);
    double midpoint_val =
        TfLiteRound(func(min + i * step + half_step) * 32768.0);
    double midpoint_err = midpoint_interp_val - midpoint_val;
    double bias = TfLiteRound(midpoint_err / 2.0);
    table[i] = std::min<double>(std::max<double>(sample_val - bias, -32768.0),
                                32767.0);
  }
  table[num - 1] = std::min<double>(
      std::max<double>(TfLiteRound(func(max) * 32768.0), -32768.0), 32767.0);
}

// Generates an INT16 LUT for a function, e.g., tables for exp(x) and 1/(1+x)
// as used in softmax.
//   func - the function to build the LUT for (e.g. exp(x))
//   min, max - table limits
//   table - pointer to the output buffer
//   num - number of elements in the LUT
inline void gen_lut(float (*func)(float), float min, float max, int16_t* table,
                    const int num) {
  // The table holds `num` entries; the last entry is used only for slope
  // calculation by the interpolating lookup.
  float step = (max - min) / (num - 1);
  float half_step = step / 2.0f;
  for (int i = 0; i < num - 1; i++) {
    float sample_val = TfLiteRound(func(min + i * step) * 32768.0f);
    float midpoint_interp_val =
        TfLiteRound((func(min + (i + 1) * step) * 32768.0f +
                     TfLiteRound(func(min + i * step) * 32768.0f)) /
                    2.0f);
    float midpoint_val =
        TfLiteRound(func(min + i * step + half_step) * 32768.0f);
    float midpoint_err = midpoint_interp_val - midpoint_val;
    float bias = TfLiteRound(midpoint_err / 2.0f);
    table[i] = std::min<float>(std::max<float>(sample_val - bias, -32768.0f),
                               32767.0f);
  }
  table[num - 1] = std::min<float>(
      std::max<float>(TfLiteRound(func(max) * 32768.0f), -32768.0f), 32767.0f);
}

// int16_t func table lookup, e.g., lookup exp() and 1/(1+x) used in softmax.
inline int16_t generic_int16_table_lookup(int16_t value, const int16_t* lut) {
  // 512 base values; one extra entry is read at lut[index + 1] for the slope.
  uint16_t index = static_cast<uint16_t>(256 + (value >> 7));
  assert(index < 512 && "LUT index out of range.");
  int16_t offset = value & 0x7f;

  // base and slope are Q0.15
  int16_t base = lut[index];
  int16_t slope = lut[index + 1] - lut[index];

  // Q0.15 * Q0.7 = Q0.22
  // Round and convert from Q0.22 to Q0.15
  int32_t delta = (static_cast<int32_t>(slope) * offset + 64) >> 7;

  // Q0.15 + Q0.15
  return base + delta;
}
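
// Worked example (illustrative): for value = 300, index = 256 + (300 >> 7)
// = 258 and offset = 300 & 0x7f = 44, so the result is
// lut[258] + (((lut[259] - lut[258]) * 44 + 64) >> 7), i.e. rounded linear
// interpolation 44/128 of the way between the two neighboring entries.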

// Table of sigmoid(i/24) at 0.16 format - 256 elements.

// We use a combined sigmoid and tanh look-up table, since
// tanh(x) = 2*sigmoid(2*x) - 1.
// Both functions are symmetric, so the LUT is only needed
// for the absolute value of the input.
static const uint16_t sigmoid_table_uint16[256] = {
    32768, 33451, 34133, 34813, 35493, 36169, 36843, 37513, 38180, 38841, 39498,
    40149, 40794, 41432, 42064, 42688, 43304, 43912, 44511, 45102, 45683, 46255,
    46817, 47369, 47911, 48443, 48964, 49475, 49975, 50464, 50942, 51409, 51865,
    52311, 52745, 53169, 53581, 53983, 54374, 54755, 55125, 55485, 55834, 56174,
    56503, 56823, 57133, 57433, 57724, 58007, 58280, 58544, 58800, 59048, 59288,
    59519, 59743, 59959, 60168, 60370, 60565, 60753, 60935, 61110, 61279, 61441,
    61599, 61750, 61896, 62036, 62172, 62302, 62428, 62549, 62666, 62778, 62886,
    62990, 63090, 63186, 63279, 63368, 63454, 63536, 63615, 63691, 63765, 63835,
    63903, 63968, 64030, 64090, 64148, 64204, 64257, 64308, 64357, 64405, 64450,
    64494, 64536, 64576, 64614, 64652, 64687, 64721, 64754, 64786, 64816, 64845,
    64873, 64900, 64926, 64950, 64974, 64997, 65019, 65039, 65060, 65079, 65097,
    65115, 65132, 65149, 65164, 65179, 65194, 65208, 65221, 65234, 65246, 65258,
    65269, 65280, 65291, 65301, 65310, 65319, 65328, 65337, 65345, 65352, 65360,
    65367, 65374, 65381, 65387, 65393, 65399, 65404, 65410, 65415, 65420, 65425,
    65429, 65433, 65438, 65442, 65445, 65449, 65453, 65456, 65459, 65462, 65465,
    65468, 65471, 65474, 65476, 65479, 65481, 65483, 65485, 65488, 65489, 65491,
    65493, 65495, 65497, 65498, 65500, 65501, 65503, 65504, 65505, 65507, 65508,
    65509, 65510, 65511, 65512, 65513, 65514, 65515, 65516, 65517, 65517, 65518,
    65519, 65520, 65520, 65521, 65522, 65522, 65523, 65523, 65524, 65524, 65525,
    65525, 65526, 65526, 65526, 65527, 65527, 65528, 65528, 65528, 65529, 65529,
    65529, 65529, 65530, 65530, 65530, 65530, 65531, 65531, 65531, 65531, 65531,
    65532, 65532, 65532, 65532, 65532, 65532, 65533, 65533, 65533, 65533, 65533,
    65533, 65533, 65533, 65534, 65534, 65534, 65534, 65534, 65534, 65534, 65534,
    65534, 65534, 65535};

// TODO(b/77858996): Add these to gemmlowp.
template <typename IntegerType>
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t sum = a64 + b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          sum)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingAddNonGemmlowp(a.raw(), b.raw()));
}

template <typename IntegerType>
IntegerType SaturatingSub(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
  std::int32_t a32 = a;
  std::int32_t b32 = b;
  std::int32_t diff = a32 - b32;
  return static_cast<std::int16_t>(
      std::min(static_cast<int32_t>(32767),
               std::max(static_cast<int32_t>(-32768), diff)));
}

template <>
inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t diff = a64 - b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          diff)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingSub(a.raw(), b.raw()));
}
// End section to be moved to gemmlowp.
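
// Example (illustrative): SaturatingSub(static_cast<std::int16_t>(-30000),
// static_cast<std::int16_t>(10000)) would be -40000 in wider arithmetic, so
// it saturates to -32768; likewise, the int32_t overload clamps to the
// int32_t range instead of wrapping around.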

template <typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
  if (exponent == 0) {
    return x;
  }
  using ScalarIntegerType =
      typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  const IntegerType min =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
  const IntegerType max =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
  const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);

  const std::int32_t threshold =
      ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
  const IntegerType positive_mask =
      gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
  const IntegerType negative_mask =
      gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));

  IntegerType result = gemmlowp::ShiftLeft(x, exponent);
  result = gemmlowp::SelectUsingMask(positive_mask, max, result);
  result = gemmlowp::SelectUsingMask(negative_mask, min, result);
  return result;
}

// If we want to leave IntegerBits fixed, then multiplication
// by a power of two has to be saturating/rounding, not exact anymore.
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits>
SaturatingRoundingMultiplyByPOTParam(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}

// Convert int32_t multiplier to int16_t with rounding.
inline void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32_t,
                                            int16_t* multiplier_int16_t) {
  TFLITE_DCHECK_GE(multiplier_int32_t, 0);
  static constexpr int32_t kRoundingOffset = 1 << 15;
  if (multiplier_int32_t >=
      std::numeric_limits<int32_t>::max() - kRoundingOffset) {
    *multiplier_int16_t = std::numeric_limits<int16_t>::max();
    return;
  }
  const int32_t result = (multiplier_int32_t + kRoundingOffset) >> 16;
  TFLITE_DCHECK_LE(result << 16, multiplier_int32_t + kRoundingOffset);
  TFLITE_DCHECK_GT(result << 16, multiplier_int32_t - kRoundingOffset);
  *multiplier_int16_t = result;
  TFLITE_DCHECK_EQ(*multiplier_int16_t, result);
}
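
// Example (illustrative): DownScaleInt32ToInt16Multiplier keeps the upper 16
// bits with rounding, so an input of 0x40008000 becomes 0x4001 while
// 0x40007FFF becomes 0x4000.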

// Minimum output bits to accommodate log of maximum input range. It actually
// does not matter if one considers, say, [-64,64] or [-64,64).
//
// For example, run this through Octave:
// [0:127; ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
constexpr int min_log_x_output_bits(int input_bits) {
  return input_bits > 90   ? 7
         : input_bits > 44 ? 6
         : input_bits > 21 ? 5
         : input_bits > 10 ? 4
         : input_bits > 4  ? 3
         : input_bits > 1  ? 2
                           : 1;
}
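
// Example (illustrative): min_log_x_output_bits(4) == 2 and
// min_log_x_output_bits(5) == 3, matching the Octave expression above, which
// evaluates ceil(log2(|ln(2^n) + 1|)) for each input bit count n.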

// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input. In other words, the output range
// is symmetric.
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
    gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
  // assert(__builtin_clz(0u) >= std::numeric_limits<uint32_t>::digits - 1);
  // assert(__builtin_clz(0u) <= std::numeric_limits<uint32_t>::digits);
  using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
  // The reason for accumulating the result with an extra bit of headroom is
  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
  // recip_denom will otherwise introduce an error.
  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
  using FixedPointAccum = gemmlowp::FixedPoint<int32_t, kAccumIntegerBits>;

  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1488522236, std::log(2.0));
  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1518500250, std::sqrt(0.5));
  const FixedPoint0 one_quarter =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);

  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1057819769,
      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));

  const FixedPointAccum shifted_quarter =
      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);

  // Reinterpret the input value as Q0.31, because we will figure out the
  // required shift "ourselves" instead of using, say, Rescale.
  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
  // z_a_pow_2 = input_integer_bits - z_a_headroom;
  int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32_t>(z_a.raw()));
  FixedPoint0 r_a_tmp =
      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
  const int32_t r_a_raw =
      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
  //                   InputIntegerBits - z_b_headroom - 0.25);
  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
      shifted_quarter);

  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
  FixedPoint0 z_b = z_a * sqrt_half;
  int z_b_headroom = CountLeadingZeros(static_cast<uint32_t>(z_b.raw())) - 1;
  const int32_t r_b_raw =
      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
      shifted_quarter);

  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));

  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
  FixedPoint0 q = r - sqrt_sqrt_half;
  q = q + q;

  const FixedPoint0 common_sq = q * q;
  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
  const FixedPoint0 denom_minus_one_0 =
      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
  const FixedPoint0 recip_denom =
      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);

  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
                                              num_scaled * recip_denom);
}

template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1(
    gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
  static_assert(
      OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
      "Output integer bits must be sufficient to accommodate logs of inputs.");
  return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
                                                     InputIntegerBits>(
      input_val);
}

inline int32_t GetReciprocal(int32_t x, int x_integer_digits,
                             int* num_bits_over_unit) {
  int headroom_plus_one = CountLeadingZeros(static_cast<uint32_t>(x));
  // This is the number of bits to the left of the binary point above 1.0.
  // Consider x=1.25. In that case shifted_scale=0.8 and
  // no later adjustment will be needed.
  *num_bits_over_unit = x_integer_digits - headroom_plus_one;
  const int32_t shifted_sum_minus_one =
      static_cast<int32_t>((static_cast<uint32_t>(x) << headroom_plus_one) -
                           (static_cast<uint32_t>(1) << 31));

  gemmlowp::FixedPoint<int32_t, 0> shifted_scale =
      gemmlowp::one_over_one_plus_x_for_x_in_0_1(
          gemmlowp::FixedPoint<int32_t, 0>::FromRaw(shifted_sum_minus_one));
  return shifted_scale.raw();
}
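
// Example (illustrative): if x is the raw value of a
// gemmlowp::FixedPoint<int32_t, 12> holding 5.0 (raw = 5 * 2^19), the
// normalized mantissa is 1.25, so shifted_scale = 1 / (1 + 0.25) = 0.8,
// *num_bits_over_unit = 2, and indeed 1/5 = 0.8 * 2^-2.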

inline void GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift,
                                             int32_t* output_inv_sqrt,
                                             int* output_shift) {
  TFLITE_DCHECK_GE(input, 0);
  if (input <= 1) {
    // Handle the input value 1 separately to avoid overflow in that case
    // in the general computation below (b/143972021). Also handle 0 as if it
    // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
    // but rare/unrealistic input value. We can expect both to occur in some
    // incompletely trained models, but probably not in fully trained models.
    *output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
    *output_shift = 0;
    return;
  }
  TFLITE_DCHECK_GT(input, 1);
  *output_shift = 11;
  while (input >= (1 << 29)) {
    input /= 4;
    ++*output_shift;
  }
  const unsigned max_left_shift_bits =
      CountLeadingZeros(static_cast<uint32_t>(input)) - 1;
  const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
  const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
  *output_shift -= left_shift_bit_pairs;
  input <<= 2 * left_shift_bit_pairs;
  TFLITE_DCHECK_GE(input, (1 << 27));
  TFLITE_DCHECK_LT(input, (1 << 29));
  using gemmlowp::FixedPoint;
  using gemmlowp::Rescale;
  using gemmlowp::SaturatingRoundingMultiplyByPOT;
  // Using 3 integer bits gives us enough room for the internal arithmetic in
  // this Newton-Raphson iteration.
  using F3 = FixedPoint<int32_t, 3>;
  using F0 = FixedPoint<int32_t, 0>;
  const F3 fixedpoint_input = F3::FromRaw(input >> 1);
  const F3 fixedpoint_half_input =
      SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
  const F3 fixedpoint_half_three =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
  // Newton-Raphson iteration
  // Naive unoptimized starting guess: x = 1
  F3 x = F3::One();
  // Naive unoptimized number of iterations: 5
  for (int i = 0; i < 5; i++) {
    const F3 x3 = Rescale<3>(x * x * x);
    x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
  }
  const F0 fixedpoint_half_sqrt_2 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
  x = x * fixedpoint_half_sqrt_2;
  *output_inv_sqrt = x.raw();
  if (*output_shift < 0) {
    *output_inv_sqrt <<= -*output_shift;
    *output_shift = 0;
  }
  // Convert right shift (right is positive) to left shift.
  *output_shift *= reverse_shift;
}
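
// Example (illustrative): GetInvSqrtQuantizedMultiplierExp(16, kReverseShift,
// &inv_sqrt, &shift) produces inv_sqrt close to 1 << 29 with shift == 0,
// i.e. the Q0.31 value 0.25 == 1/sqrt(16). Internally the loop runs the
// Newton-Raphson recurrence x <- x * (3 - d * x^2) / 2 on the input
// normalized into [2^27, 2^29).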

// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
// rectangular array of numbers.
//
// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
// However, as Dims<N> is to be deprecated, this class exists as an adaptor
// to enable simple unoptimized implementations of element-wise broadcasting
// operations.
template <int N>
struct NdArrayDesc {
  // The "extent" of each dimension. Indices along dimension d must be in the
  // half-open interval [0, extents[d]).
  int extents[N];

  // The number of *elements* (not bytes) between consecutive indices of each
  // dimension.
  int strides[N];
};

// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// Same as Offset(), except it takes an NdArrayDesc<N> instead of Dims<N>.
inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
                            int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
  return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
         i3 * desc.strides[3];
}
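
// Example (illustrative): for a dense NHWC array of shape [2, 3, 4, 5] the
// strides are {60, 20, 5, 1}, so SubscriptToIndex(desc, 1, 2, 3, 4) ==
// 1*60 + 2*20 + 3*5 + 4*1 == 119, the last element.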

inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) {
  return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
         indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
         indexes[4] * desc.strides[4];
}

inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) {
  return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
         indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
         indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] +
         indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7];
}

// Given the dimensions of the operands for an element-wise binary broadcast,
// adjusts them so that they can be directly iterated over with simple loops.
// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
//
// This function assumes that the two input shapes are compatible up to
// broadcasting and the shorter one has already been prepended with 1s to be the
// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
//
// When two shapes are compatible up to broadcasting, for each dimension d,
// the input extents are either equal, or one of them is 1.
//
// This function performs the following for each dimension d:
// - If the extents are equal, then do nothing since the loop that walks over
//   both of the input arrays is correct.
// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
//   and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
//   array0 to be referenced *at any index* in dimension d and still access the
//   same slice.
template <int N>
inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
                                                const Dims<N>& input1_dims,
                                                NdArrayDesc<N>* desc0_out,
                                                NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  // Copy dims to desc.
  for (int i = 0; i < N; ++i) {
    desc0_out->extents[i] = input0_dims.sizes[i];
    desc0_out->strides[i] = input0_dims.strides[i];
    desc1_out->extents[i] = input1_dims.sizes[i];
    desc1_out->strides[i] = input1_dims.strides[i];
  }

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = ArraySize(input0_dims, i);
    const int extent1 = ArraySize(input1_dims, i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}

// Copies dims to desc, calculating strides.
template <int N>
inline void CopyDimsToDesc(const RuntimeShape& input_shape,
                           NdArrayDesc<N>* desc_out) {
  int desc_stride = 1;
  for (int i = N - 1; i >= 0; --i) {
    desc_out->extents[i] = input_shape.Dims(i);
    desc_out->strides[i] = desc_stride;
    desc_stride *= input_shape.Dims(i);
  }
}

template <int N>
inline void NdArrayDescsForElementwiseBroadcast(
    const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
    NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
  auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);

  // Copy dims to desc, calculating strides.
  CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
  CopyDimsToDesc<N>(extended_input1_shape, desc1_out);

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = extended_input0_shape.Dims(i);
    const int extent1 = extended_input1_shape.Dims(i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}
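
// Example (illustrative): broadcasting a bias of shape (64) against an
// activation of shape (1, 16, 16, 64) with N = 4 first extends the bias shape
// to (1, 1, 1, 64); the activation's desc stays dense while the bias desc
// gets extents {1, 16, 16, 64} with stride 0 in the broadcast (height, width)
// dimensions, so the same 64 bias values are re-read at every spatial
// position.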

template <int N>
inline void NdArrayDescsForElementwiseBroadcast(
    const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
    const RuntimeShape& input2_shape, NdArrayDesc<N>* desc0_out,
    NdArrayDesc<N>* desc1_out, NdArrayDesc<N>* desc2_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);
  TFLITE_DCHECK(desc2_out != nullptr);

  auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
  auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
  auto extended_input2_shape = RuntimeShape::ExtendedShape(N, input2_shape);

  // Copy dims to desc, calculating strides.
  CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
  CopyDimsToDesc<N>(extended_input1_shape, desc1_out);
  CopyDimsToDesc<N>(extended_input2_shape, desc2_out);

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = extended_input0_shape.Dims(i);
    const int extent1 = extended_input1_shape.Dims(i);
    const int extent2 = extended_input2_shape.Dims(i);

    int extent = extent0;
    if (extent1 != 1) extent = extent1;
    if (extent2 != 1) extent = extent2;

    TFLITE_DCHECK(extent0 == 1 || extent0 == extent);
    TFLITE_DCHECK(extent1 == 1 || extent1 == extent);
    TFLITE_DCHECK(extent2 == 1 || extent2 == extent);

    if (!(extent0 == extent1 && extent1 == extent2)) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent;
      }
      if (extent1 == 1) {
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent;
      }
      if (extent2 == 1) {
        desc2_out->strides[i] = 0;
        desc2_out->extents[i] = extent;
      }
    }
  }
}

// Detailed implementation of NDOpsHelper; the indexes must be a zero array.
// This implementation is equivalent to N nested loops. E.g., if N=4, it can be
// rewritten as:
//   for (int b = 0; b < output.extents[0]; ++b) {
//     for (int y = 0; y < output.extents[1]; ++y) {
//       for (int x = 0; x < output.extents[2]; ++x) {
//         for (int c = 0; c < output.extents[3]; ++c) {
//           calc({b,y,x,c});
//         }
//       }
//     }
//   }
template <int N, int DIM, typename Calc>
typename std::enable_if<DIM != N - 1, void>::type NDOpsHelperImpl(
    const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
  for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
    NDOpsHelperImpl<N, DIM + 1, Calc>(output, calc, indexes);
  }
}

template <int N, int DIM, typename Calc>
typename std::enable_if<DIM == N - 1, void>::type NDOpsHelperImpl(
    const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
  for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
    calc(indexes);
  }
}

// Execute the calc function in the innermost iteration based on the shape of
// the output. The calc function should take a single argument of type int[N].
template <int N, typename Calc>
inline void NDOpsHelper(const NdArrayDesc<N>& output, const Calc& calc) {
  int indexes[N] = {0};
  NDOpsHelperImpl<N, 0, Calc>(output, calc, indexes);
}
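
// Example (illustrative): iterating a 4-D output with a lambda, assuming an
// NdArrayDesc<4> `output_desc` and float buffers `in_data` / `out_data` are
// defined elsewhere:
//   NDOpsHelper<4>(output_desc, [&](int indexes[4]) {
//     const int offset = SubscriptToIndex(output_desc, indexes[0], indexes[1],
//                                         indexes[2], indexes[3]);
//     out_data[offset] = in_data[offset] * 2.0f;
//   });
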
// Copied from gemmlowp::RoundDown when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded down to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundDown(Integer i) {
  return i - (i % Modulus);
}

// Copied from gemmlowp::RoundUp when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded up to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundUp(Integer i) {
  return RoundDown<Modulus>(i + Modulus - 1);
}

// Copied from gemmlowp::CeilQuotient when we dropped direct dependency on
// gemmlowp.
//
// Returns the quotient a / b rounded up ('ceil') to the nearest integer.
template <typename Integer>
Integer CeilQuotient(Integer a, Integer b) {
  return (a + b - 1) / b;
}
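
// Examples (illustrative): RoundDown<8>(13) == 8, RoundUp<8>(13) == 16,
// RoundUp<8>(16) == 16, and CeilQuotient(13, 8) == 2.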

// This function is a copy of gemmlowp::HowManyThreads, copied when we dropped
// the direct dependency of internal/optimized/ on gemmlowp.
//
// It computes a reasonable number of threads to use for a GEMM of shape
// (rows, cols, depth).
//
// TODO(b/131910176): get rid of this function by switching each call site
// to its own more sensible logic for its own workload.
template <int KernelRows>
inline int LegacyHowManyThreads(int max_num_threads, int rows, int cols,
                                int depth) {
  // Early-exit in the default case where multi-threading is disabled.
  if (max_num_threads == 1) {
    return 1;
  }

  // Ensure that each thread has KernelRows rows to process, if at all possible.
  int thread_count = std::min(max_num_threads, rows / KernelRows);

  // Limit the number of threads according to the overall size of the problem.
  if (thread_count > 1) {
    // Empirically determined value.
    static constexpr std::uint64_t min_cubic_size_per_thread = 64 * 1024;

    // We can only multiply two out of three sizes without risking overflow.
    const std::uint64_t cubic_size =
        std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth);

    thread_count = std::min(
        thread_count, static_cast<int>(cubic_size / min_cubic_size_per_thread));
  }

  if (thread_count < 1) {
    thread_count = 1;
  }

  assert(thread_count > 0 && thread_count <= max_num_threads);
  return thread_count;
}

template <typename T>
void optimized_ops_preload_l1_stream(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 0 means no locality */ 0);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_preload_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_prefetch_write_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 1 means write */ 1, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_