/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_

#include <algorithm>
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#endif
#endif

#include <functional>

#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

constexpr int kReverseShift = -1;

inline void GetActivationMinMax(FusedActivationFunctionType ac,
                                float* output_activation_min,
                                float* output_activation_max) {
  switch (ac) {
    case FusedActivationFunctionType::kNone:
      *output_activation_min = std::numeric_limits<float>::lowest();
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu:
      *output_activation_min = 0.f;
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu1:
      *output_activation_min = -1.f;
      *output_activation_max = 1.f;
      break;
    case FusedActivationFunctionType::kRelu6:
      *output_activation_min = 0.f;
      *output_activation_max = 6.f;
      break;
  }
}

template <typename T>
inline T ActivationFunctionWithMinMax(T x, T output_activation_min,
                                      T output_activation_max) {
  using std::max;
  using std::min;
  return min(max(x, output_activation_min), output_activation_max);
}

// Legacy function, left for compatibility only.
template <FusedActivationFunctionType Ac>
float ActivationFunction(float x) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  return ActivationFunctionWithMinMax(x, output_activation_min,
                                      output_activation_max);
}
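
// Example (illustrative sketch, not part of the original header): resolving a
// fused ReLU6 activation to a clamp range and applying it to one value.
//
//   float act_min, act_max;
//   GetActivationMinMax(FusedActivationFunctionType::kRelu6, &act_min,
//                       &act_max);
//   // act_min == 0.f, act_max == 6.f
//   const float y = ActivationFunctionWithMinMax(7.5f, act_min, act_max);
//   // y == 6.f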

inline void BiasAndClamp(float clamp_min, float clamp_max, int bias_size,
                         const float* bias_data, int array_size,
                         float* array_data) {
  if (bias_size == 0) return;
  // Note: see b/132215220: in May 2019 we thought it would be OK to replace
  // this with the Eigen one-liner:
  //   return (array.colwise() + bias).cwiseMax(clamp_min).cwiseMin(clamp_max).
  // This turned out to severely regress performance: +4ms (i.e. 8%) on
  // MobileNet v2 / 1.0 / 224. So we keep custom NEON code for now.
  TFLITE_DCHECK_EQ((array_size % bias_size), 0);
#ifdef USE_NEON
  float* array_ptr = array_data;
  float* array_end_ptr = array_ptr + array_size;
  const auto clamp_min_vec = vdupq_n_f32(clamp_min);
  const auto clamp_max_vec = vdupq_n_f32(clamp_max);
  for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
    int i = 0;
    for (; i <= bias_size - 16; i += 16) {
      auto b0 = vld1q_f32(bias_data + i);
      auto b1 = vld1q_f32(bias_data + i + 4);
      auto b2 = vld1q_f32(bias_data + i + 8);
      auto b3 = vld1q_f32(bias_data + i + 12);
      auto a0 = vld1q_f32(array_ptr + i);
      auto a1 = vld1q_f32(array_ptr + i + 4);
      auto a2 = vld1q_f32(array_ptr + i + 8);
      auto a3 = vld1q_f32(array_ptr + i + 12);
      auto x0 = vaddq_f32(a0, b0);
      auto x1 = vaddq_f32(a1, b1);
      auto x2 = vaddq_f32(a2, b2);
      auto x3 = vaddq_f32(a3, b3);
      x0 = vmaxq_f32(clamp_min_vec, x0);
      x1 = vmaxq_f32(clamp_min_vec, x1);
      x2 = vmaxq_f32(clamp_min_vec, x2);
      x3 = vmaxq_f32(clamp_min_vec, x3);
      x0 = vminq_f32(clamp_max_vec, x0);
      x1 = vminq_f32(clamp_max_vec, x1);
      x2 = vminq_f32(clamp_max_vec, x2);
      x3 = vminq_f32(clamp_max_vec, x3);
      vst1q_f32(array_ptr + i, x0);
      vst1q_f32(array_ptr + i + 4, x1);
      vst1q_f32(array_ptr + i + 8, x2);
      vst1q_f32(array_ptr + i + 12, x3);
    }
    for (; i <= bias_size - 4; i += 4) {
      auto b = vld1q_f32(bias_data + i);
      auto a = vld1q_f32(array_ptr + i);
      auto x = vaddq_f32(a, b);
      x = vmaxq_f32(clamp_min_vec, x);
      x = vminq_f32(clamp_max_vec, x);
      vst1q_f32(array_ptr + i, x);
    }
    for (; i < bias_size; i++) {
      array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
                                                  clamp_min, clamp_max);
    }
  }
#else  // not NEON
  for (int array_offset = 0; array_offset < array_size;
       array_offset += bias_size) {
    for (int i = 0; i < bias_size; i++) {
      array_data[array_offset + i] = ActivationFunctionWithMinMax(
          array_data[array_offset + i] + bias_data[i], clamp_min, clamp_max);
    }
  }
#endif
}
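
// Example (illustrative sketch, not part of the original header): adding a
// per-channel bias to a row-major [rows x channels] activation buffer and
// clamping to a fused ReLU range. 'rows', 'channels', 'bias_data' and
// 'array_data' are assumed to be defined by the caller.
//
//   // array_data holds rows * channels floats, bias_data holds channels
//   // floats; array_size must be a multiple of bias_size.
//   BiasAndClamp(/*clamp_min=*/0.f, /*clamp_max=*/6.f,
//                /*bias_size=*/channels, bias_data,
//                /*array_size=*/rows * channels, array_data);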

// Single-rounding MultiplyByQuantizedMultiplier
#if TFLITE_SINGLE_ROUNDING
inline int32_t MultiplyByQuantizedMultiplier(int32_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  TFLITE_DCHECK(quantized_multiplier >= 0);
  TFLITE_DCHECK(shift >= -31 && shift <= 30);

  const int64_t total_shift = 31 - shift;
  const int64_t round = static_cast<int64_t>(1) << (total_shift - 1);
  int64_t result = x * static_cast<int64_t>(quantized_multiplier) + round;
  result = result >> total_shift;

  TFLITE_DCHECK(result >= std::numeric_limits<int32_t>::min() &&
                result <= std::numeric_limits<int32_t>::max());
  return static_cast<int32_t>(result);
}

inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(
    int32_t x, int32_t quantized_multiplier, int shift) {
  TFLITE_DCHECK_LE(shift, 0);
  return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift);
}

inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(
    int32_t x, int32_t quantized_multiplier, int shift) {
  TFLITE_DCHECK_GE(shift, 0);
  return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift);
}

inline int32_t MultiplyByQuantizedMultiplier(int64_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  // Inputs:
  // - quantized_multiplier has fixed point at bit 31
  // - shift is -31 to +7 (negative for right shift)
  //
  // Assumptions: The following input ranges are assumed
  // - quantized_multiplier >= 0 (the usual range is (1 << 30) to (1 << 31) - 1)
  // - scaling is chosen so final scaled result fits in int32_t
  // - input x is in the range -(1 << 47) <= x < (1 << 47)
  TFLITE_DCHECK(quantized_multiplier >= 0);
  TFLITE_DCHECK(shift >= -31 && shift < 8);
  TFLITE_DCHECK(x >= -(static_cast<int64_t>(1) << 47) &&
                x < (static_cast<int64_t>(1) << 47));

  const int32_t reduced_multiplier =
      (quantized_multiplier < 0x7FFF0000)
          ? ((quantized_multiplier + (1 << 15)) >> 16)
          : 0x7FFF;
  const int64_t total_shift = 15 - shift;
  const int64_t round = static_cast<int64_t>(1) << (total_shift - 1);
  int64_t result = x * static_cast<int64_t>(reduced_multiplier) + round;
  result = result >> total_shift;

  TFLITE_DCHECK(result >= std::numeric_limits<int32_t>::min() &&
                result <= std::numeric_limits<int32_t>::max());
  return static_cast<int32_t>(result);
}

#ifdef USE_NEON
inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
    int32x4x4_t input_val, int32_t quantized_multiplier, int shift) {
  TFLITE_DCHECK(quantized_multiplier >= 0);

  const int right_shift = std::min(-1, shift);
  const int left_shift = shift - right_shift;

  const int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
  const int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
  const int32x4_t right_shift_dup = vdupq_n_s32(right_shift);

  int32x4x4_t result;
  result.val[0] = vrshlq_s32(
      vqdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup), multiplier_dup),
      right_shift_dup);

  result.val[1] = vrshlq_s32(
      vqdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup), multiplier_dup),
      right_shift_dup);

  result.val[2] = vrshlq_s32(
      vqdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup), multiplier_dup),
      right_shift_dup);

  result.val[3] = vrshlq_s32(
      vqdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup), multiplier_dup),
      right_shift_dup);

  return result;
}
#endif  // USE_NEON
// Double-rounding MultiplyByQuantizedMultiplier
#else
inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(
    int32_t x, int32_t quantized_multiplier, int left_shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift);
}

inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(
    int32_t x, int32_t quantized_multiplier, int left_shift) {
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
                                           quantized_multiplier);
}

inline int32_t MultiplyByQuantizedMultiplier(int32_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  int left_shift = shift > 0 ? shift : 0;
  int right_shift = shift > 0 ? 0 : -shift;
  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                 x * (1 << left_shift), quantized_multiplier),
                             right_shift);
}

inline int32_t MultiplyByQuantizedMultiplier(int64_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  // Inputs:
  // - quantized_multiplier has fixed point at bit 31
  // - shift is -31 to +7 (negative for right shift)
  //
  // Assumptions: The following input ranges are assumed
  // - quantized_multiplier >= 0 (the usual range is (1 << 30) to (1 << 31) - 1)
  // - scaling is chosen so final scaled result fits in int32_t
  // - input x is in the range -(1 << 47) <= x < (1 << 47)
  assert(quantized_multiplier >= 0);
  assert(shift >= -31 && shift < 8);
  assert(x >= -(static_cast<int64_t>(1) << 47) &&
         x < (static_cast<int64_t>(1) << 47));

  int32_t reduced_multiplier = (quantized_multiplier < 0x7FFF0000)
                                   ? ((quantized_multiplier + (1 << 15)) >> 16)
                                   : 0x7FFF;
  int total_shift = 15 - shift;
  x = (x * (int64_t)reduced_multiplier) + ((int64_t)1 << (total_shift - 1));
  int32_t result = x >> total_shift;
  return result;
}

#ifdef USE_NEON
// Round uses ARM's rounding shift right.
inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
    int32x4x4_t input_val, int32_t quantized_multiplier, int shift) {
  const int left_shift = std::max(shift, 0);
  const int right_shift = std::min(shift, 0);
  int32x4x4_t result;

  int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
  int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
  int32x4_t right_shift_dup = vdupq_n_s32(right_shift);

  result.val[0] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[1] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[2] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[3] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  return result;
}
#endif  // USE_NEON
#endif  // TFLITE_SINGLE_ROUNDING
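
// Example (illustrative sketch, not part of the original header): both the
// single- and double-rounding variants above compute, up to rounding,
// x * quantized_multiplier * 2^shift / 2^31. A typical requantization of an
// int32 accumulator to int8 therefore looks like:
//
//   // quantized_multiplier / 2^31 * 2^shift approximates the effective real
//   // scale input_scale * filter_scale / output_scale (the pair is typically
//   // produced by QuantizeMultiplier() in quantization_util.h).
//   const int32_t scaled =
//       MultiplyByQuantizedMultiplier(acc, quantized_multiplier, shift);
//   const int32_t clamped =
//       std::min(127, std::max(-128, scaled + output_zero_point));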

template <typename T>
int CountLeadingZeros(T integer_input) {
  static_assert(std::is_unsigned<T>::value,
                "Only unsigned integer types handled.");
#if defined(__GNUC__)
  return integer_input ? __builtin_clz(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  if (integer_input == 0) {
    return std::numeric_limits<T>::digits;
  }

  const T one_in_leading_positive = static_cast<T>(1)
                                    << (std::numeric_limits<T>::digits - 1);
  int leading_zeros = 0;
  while (integer_input < one_in_leading_positive) {
    integer_input <<= 1;
    ++leading_zeros;
  }
  return leading_zeros;
#endif
}

template <typename T>
inline int CountLeadingSignBits(T integer_input) {
  static_assert(std::is_signed<T>::value, "Only signed integer types handled.");
#if defined(__GNUC__) && !defined(__clang__)
  return integer_input ? __builtin_clrsb(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  using U = typename std::make_unsigned<T>::type;
  return integer_input >= 0
             ? CountLeadingZeros(static_cast<U>(integer_input)) - 1
             : integer_input != std::numeric_limits<T>::min()
                   ? CountLeadingZeros(2 * static_cast<U>(-integer_input) - 1)
                   : 0;
#endif
}

// Use "count leading zeros" helper functions to do a fast Floor(log_2(x)).
template <typename Integer>
inline Integer FloorLog2(Integer n) {
  static_assert(std::is_integral<Integer>::value, "");
  static_assert(std::is_signed<Integer>::value, "");
  static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, "");
  TFLITE_CHECK_GT(n, 0);
  if (sizeof(Integer) == 4) {
    return 30 - CountLeadingSignBits(n);
  } else {
    return 62 - CountLeadingSignBits(n);
  }
}
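
// Example (illustrative, not part of the original header): these helpers are
// the building blocks used below for normalizing fixed-point values.
//
//   CountLeadingZeros(uint32_t{1});    // == 31
//   CountLeadingSignBits(int32_t{8});  // == 27 (redundant sign bits)
//   FloorLog2(int32_t{1});             // == 0
//   FloorLog2(int32_t{10});            // == 3, i.e. floor(log2(10))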

// The size of the LUT depends on the type of input. For int8 inputs a simple
// 256-entry LUT is used. For int16 inputs the high 9 bits are used for
// indexing and the 7 remaining bits are used for interpolation. We thus use a
// 513-entry LUT for int16 cases, 512 for the 9-bit indexing and 1 extra entry
// to interpolate the last value.
template <typename LutInT>
constexpr int lut_size() {
  static_assert(std::is_same<LutInT, int8_t>::value ||
                    std::is_same<LutInT, int16_t>::value,
                "Only LUTs with int8 or int16 inputs are supported.");
  return std::is_same<LutInT, int8_t>::value ? 256 : 513;
}

// Generate a LUT for 'func' which can be used to approximate functions like
// exp, log, ...
//
// - func: the function to build the LUT for (e.g. exp(x))
// - input_min, input_max: range of the func inputs
// - output_min, output_max: range of the func outputs
// - lut: pointer to the LUT table to fill, the table must be of size
//   lut_size<LutInT>()
template <typename FloatT, typename LutInT, typename LutOutT>
inline void gen_lut(FloatT (*func)(FloatT), FloatT input_min, FloatT input_max,
                    FloatT output_min, FloatT output_max, LutOutT* lut) {
  static_assert(std::is_same<LutInT, int8_t>::value ||
                    std::is_same<LutInT, int16_t>::value,
                "Only LUTs with int8 or int16 inputs are supported.");
  static_assert(std::is_same<LutOutT, int8_t>::value ||
                    std::is_same<LutOutT, int16_t>::value,
                "Only LUTs with int8 or int16 outputs are supported.");
  static_assert(std::is_floating_point<FloatT>::value,
                "FloatT must be a floating-point type.");

  const int nb_steps = std::is_same<LutInT, int8_t>::value ? 256 : 512;
  const FloatT step = (input_max - input_min) / nb_steps;
  const FloatT half_step = step / 2;
  const FloatT output_scaling_inv =
      static_cast<FloatT>(std::numeric_limits<LutOutT>::max() -
                          std::numeric_limits<LutOutT>::min() + 1) /
      (output_max - output_min);
  const FloatT table_min =
      static_cast<FloatT>(std::numeric_limits<LutOutT>::min());
  const FloatT table_max =
      static_cast<FloatT>(std::numeric_limits<LutOutT>::max());

  for (int i = 0; i < nb_steps; i++) {
    const FloatT val = func(input_min + i * step);
    const FloatT val_midpoint = func(input_min + i * step + half_step);
    const FloatT val_next = func(input_min + (i + 1) * step);

    const FloatT sample_val = TfLiteRound(val * output_scaling_inv);
    const FloatT midpoint_interp_val =
        TfLiteRound((val_next * output_scaling_inv +
                     TfLiteRound(val * output_scaling_inv)) /
                    2);
    const FloatT midpoint_val = TfLiteRound(val_midpoint * output_scaling_inv);
    const FloatT midpoint_err = midpoint_interp_val - midpoint_val;
    const FloatT bias = TfLiteRound(midpoint_err / 2);

    lut[i] = static_cast<LutOutT>(std::min<FloatT>(
        std::max<FloatT>(sample_val - bias, table_min), table_max));
  }

  const bool with_extra_interpolation_value =
      std::is_same<LutInT, int16_t>::value;
  if (with_extra_interpolation_value) {
    lut[nb_steps] = static_cast<LutOutT>(std::min<FloatT>(
        std::max<FloatT>(TfLiteRound(func(input_max) * output_scaling_inv),
                         table_min),
        table_max));
  }
}

// LUT must have 513 values
template <typename LutOutT>
inline LutOutT lut_lookup_with_interpolation(int16_t value,
                                             const LutOutT* lut) {
  static_assert(std::is_same<LutOutT, int8_t>::value ||
                    std::is_same<LutOutT, int16_t>::value,
                "Only LUTs with int8 or int16 outputs are supported.");
  // 512 base values, lut[512] is only used to calculate the slope
  const uint16_t index = static_cast<uint16_t>(256 + (value >> 7));
  assert(index < 512 && "LUT index out of range.");
  const int16_t offset = value & 0x7f;

  // Base and slope are Q0.x
  const LutOutT base = lut[index];
  const LutOutT slope = lut[index + 1] - lut[index];

  // Q0.x * Q0.7 = Q0.(x + 7)
  // Round and convert from Q0.(x + 7) to Q0.x
  const int delta = (slope * offset + 64) >> 7;

  // Q0.x + Q0.x
  return static_cast<LutOutT>(base + delta);
}

// int16_t -> int16_t table lookup with interpolation
// LUT must have 513 values
inline int16_t lut_lookup(int16_t value, const int16_t* lut) {
  return lut_lookup_with_interpolation(value, lut);
}

// int16_t -> int8_t table lookup with interpolation
// LUT must have 513 values
inline int8_t lut_lookup(int16_t value, const int8_t* lut) {
  return lut_lookup_with_interpolation(value, lut);
}

// int8_t -> int8_t table lookup without interpolation
// LUT must have 256 values
inline int8_t lut_lookup(int8_t value, const int8_t* lut) {
  return lut[128 + value];
}

// int8_t -> int16_t table lookup without interpolation
// LUT must have 256 values
inline int16_t lut_lookup(int8_t value, const int16_t* lut) {
  return lut[128 + value];
}
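
// Example (illustrative sketch, not part of the original header): building an
// int16 -> int16 LUT for exp(x) on [-10, 0] with outputs in [0, 1], then
// querying it. 'ExpForLut' is a hypothetical helper introduced only so that a
// plain function pointer can be passed to gen_lut.
//
//   double ExpForLut(double x) { return std::exp(x); }
//
//   int16_t exp_lut[lut_size<int16_t>()];  // 513 entries
//   gen_lut<double, int16_t, int16_t>(ExpForLut, /*input_min=*/-10.0,
//                                     /*input_max=*/0.0, /*output_min=*/0.0,
//                                     /*output_max=*/1.0, exp_lut);
//   const int16_t y = lut_lookup(static_cast<int16_t>(-1234), exp_lut);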

// Table of sigmoid(i/24) at 0.16 format - 256 elements.

// We use combined sigmoid and tanh look-up table, since
// tanh(x) = 2*sigmoid(2*x) - 1.
// Both functions are symmetric, so the LUT table is only needed
// for the absolute value of the input.
static const uint16_t sigmoid_table_uint16[256] = {
    32768, 33451, 34133, 34813, 35493, 36169, 36843, 37513, 38180, 38841, 39498,
    40149, 40794, 41432, 42064, 42688, 43304, 43912, 44511, 45102, 45683, 46255,
    46817, 47369, 47911, 48443, 48964, 49475, 49975, 50464, 50942, 51409, 51865,
    52311, 52745, 53169, 53581, 53983, 54374, 54755, 55125, 55485, 55834, 56174,
    56503, 56823, 57133, 57433, 57724, 58007, 58280, 58544, 58800, 59048, 59288,
    59519, 59743, 59959, 60168, 60370, 60565, 60753, 60935, 61110, 61279, 61441,
    61599, 61750, 61896, 62036, 62172, 62302, 62428, 62549, 62666, 62778, 62886,
    62990, 63090, 63186, 63279, 63368, 63454, 63536, 63615, 63691, 63765, 63835,
    63903, 63968, 64030, 64090, 64148, 64204, 64257, 64308, 64357, 64405, 64450,
    64494, 64536, 64576, 64614, 64652, 64687, 64721, 64754, 64786, 64816, 64845,
    64873, 64900, 64926, 64950, 64974, 64997, 65019, 65039, 65060, 65079, 65097,
    65115, 65132, 65149, 65164, 65179, 65194, 65208, 65221, 65234, 65246, 65258,
    65269, 65280, 65291, 65301, 65310, 65319, 65328, 65337, 65345, 65352, 65360,
    65367, 65374, 65381, 65387, 65393, 65399, 65404, 65410, 65415, 65420, 65425,
    65429, 65433, 65438, 65442, 65445, 65449, 65453, 65456, 65459, 65462, 65465,
    65468, 65471, 65474, 65476, 65479, 65481, 65483, 65485, 65488, 65489, 65491,
    65493, 65495, 65497, 65498, 65500, 65501, 65503, 65504, 65505, 65507, 65508,
    65509, 65510, 65511, 65512, 65513, 65514, 65515, 65516, 65517, 65517, 65518,
    65519, 65520, 65520, 65521, 65522, 65522, 65523, 65523, 65524, 65524, 65525,
    65525, 65526, 65526, 65526, 65527, 65527, 65528, 65528, 65528, 65529, 65529,
    65529, 65529, 65530, 65530, 65530, 65530, 65531, 65531, 65531, 65531, 65531,
    65532, 65532, 65532, 65532, 65532, 65532, 65533, 65533, 65533, 65533, 65533,
    65533, 65533, 65533, 65534, 65534, 65534, 65534, 65534, 65534, 65534, 65534,
    65534, 65534, 65535};

// TODO(b/77858996): Add these to gemmlowp.
template <typename IntegerType>
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t sum = a64 + b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          sum)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingAddNonGemmlowp(a.raw(), b.raw()));
}

template <typename IntegerType>
IntegerType SaturatingSub(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
  std::int32_t a32 = a;
  std::int32_t b32 = b;
  std::int32_t diff = a32 - b32;
  return static_cast<std::int16_t>(
      std::min(static_cast<int32_t>(32767),
               std::max(static_cast<int32_t>(-32768), diff)));
}

template <>
inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t diff = a64 - b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          diff)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingSub(a.raw(), b.raw()));
}
// End section to be moved to gemmlowp.

template <typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
  if (exponent == 0) {
    return x;
  }
  using ScalarIntegerType =
      typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  const IntegerType min =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
  const IntegerType max =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
  const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);

  const std::int32_t threshold =
      ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
  const IntegerType positive_mask =
      gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
  const IntegerType negative_mask =
      gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));

  IntegerType result = gemmlowp::ShiftLeft(x, exponent);
  result = gemmlowp::SelectUsingMask(positive_mask, max, result);
  result = gemmlowp::SelectUsingMask(negative_mask, min, result);
  return result;
}

// If we want to leave IntegerBits fixed, then multiplication
// by a power of two has to be saturating/rounding, not exact anymore.
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits>
SaturatingRoundingMultiplyByPOTParam(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}

// Convert int32_t multiplier to int16_t with rounding.
inline void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32_t,
                                            int16_t* multiplier_int16_t) {
  TFLITE_DCHECK_GE(multiplier_int32_t, 0);
  static constexpr int32_t kRoundingOffset = 1 << 15;
  if (multiplier_int32_t >=
      std::numeric_limits<int32_t>::max() - kRoundingOffset) {
    *multiplier_int16_t = std::numeric_limits<int16_t>::max();
    return;
  }
  const int32_t result = (multiplier_int32_t + kRoundingOffset) >> 16;
  TFLITE_DCHECK_LE(result << 16, multiplier_int32_t + kRoundingOffset);
  TFLITE_DCHECK_GT(result << 16, multiplier_int32_t - kRoundingOffset);
  *multiplier_int16_t = result;
  TFLITE_DCHECK_EQ(*multiplier_int16_t, result);
}
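
// Example (illustrative, not part of the original header): a Q0.31 multiplier
// of 1 << 30 (i.e. 0.5) downscales to the Q0.15 value 1 << 14:
//
//   int16_t m16;
//   DownScaleInt32ToInt16Multiplier(1 << 30, &m16);  // m16 == 1 << 14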

// Minimum output bits to accommodate log of maximum input range. It actually
// does not matter if one considers, say, [-64,64] or [-64,64).
//
// For example, run this through Octave:
// [0:127; ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
constexpr int min_log_x_output_bits(int input_bits) {
  return input_bits > 90   ? 7
         : input_bits > 44 ? 6
         : input_bits > 21 ? 5
         : input_bits > 10 ? 4
         : input_bits > 4  ? 3
         : input_bits > 1  ? 2
                           : 1;
}

// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input. In other words, the output range
// is symmetric.
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
    gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
  // assert(__builtin_clz(0u) >= std::numeric_limits<uint32_t>::digits - 1);
  // assert(__builtin_clz(0u) <= std::numeric_limits<uint32_t>::digits);
  using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
  // The reason for accumulating the result with an extra bit of headroom is
  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
  // recip_denom will otherwise introduce an error.
  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
  using FixedPointAccum = gemmlowp::FixedPoint<int32_t, kAccumIntegerBits>;

  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1488522236, std::log(2.0));
  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1518500250, std::sqrt(0.5));
  const FixedPoint0 one_quarter =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);

  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1057819769,
      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));

  const FixedPointAccum shifted_quarter =
      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);

  // Reinterpret the input value as Q0.31, because we will figure out the
  // required shift "ourselves" instead of using, say, Rescale.
  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
  // z_a_pow_2 = input_integer_bits - z_a_headroom;
  int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32_t>(z_a.raw()));
  FixedPoint0 r_a_tmp =
      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
  const int32_t r_a_raw =
      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
  //                   InputIntegerBits - z_b_headroom - 0.25);
  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          static_cast<int32_t>(InputIntegerBits - z_a_headroom_plus_1),
          31 - kAccumIntegerBits)),
      shifted_quarter);

  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
  FixedPoint0 z_b = z_a * sqrt_half;
  int z_b_headroom = CountLeadingZeros(static_cast<uint32_t>(z_b.raw())) - 1;
  const int32_t r_b_raw =
      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          static_cast<int32_t>(InputIntegerBits - z_b_headroom),
          31 - kAccumIntegerBits)),
      shifted_quarter);

  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));

  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
  FixedPoint0 q = r - sqrt_sqrt_half;
  q = q + q;

  const FixedPoint0 common_sq = q * q;
  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
  const FixedPoint0 denom_minus_one_0 =
      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
  const FixedPoint0 recip_denom =
      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);

  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
                                              num_scaled * recip_denom);
}

template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1(
    gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
  static_assert(
      OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
      "Output integer bits must be sufficient to accommodate logs of inputs.");
  return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
                                                     InputIntegerBits>(
      input_val);
}

inline int32_t GetReciprocal(int32_t x, int x_integer_digits,
                             int* num_bits_over_unit) {
  int headroom_plus_one = CountLeadingZeros(static_cast<uint32_t>(x));
  // This is the number of bits to the left of the binary point above 1.0.
  // Consider x=1.25. In that case shifted_scale=0.8 and
  // no later adjustment will be needed.
  *num_bits_over_unit = x_integer_digits - headroom_plus_one;
  const int32_t shifted_sum_minus_one =
      static_cast<int32_t>((static_cast<uint32_t>(x) << headroom_plus_one) -
                           (static_cast<uint32_t>(1) << 31));

  gemmlowp::FixedPoint<int32_t, 0> shifted_scale =
      gemmlowp::one_over_one_plus_x_for_x_in_0_1(
          gemmlowp::FixedPoint<int32_t, 0>::FromRaw(shifted_sum_minus_one));
  return shifted_scale.raw();
}
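
// Example (illustrative, not part of the original header): if 'sum' is the raw
// value of a fixed-point number with 12 integer bits (Q12.19), its reciprocal
// can be recovered from the returned Q0.31 value and the extra exponent:
//
//   int num_bits_over_unit;
//   const int32_t recip_q0_31 =
//       GetReciprocal(sum, /*x_integer_digits=*/12, &num_bits_over_unit);
//   // 1 / (sum / 2^19) ~= (recip_q0_31 / 2^31) / 2^num_bits_over_unit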

inline void GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift,
                                             int32_t* output_inv_sqrt,
                                             int* output_shift) {
  TFLITE_DCHECK_GE(input, 0);
  if (input <= 1) {
    // Handle the input value 1 separately to avoid overflow in that case
    // in the general computation below (b/143972021). Also handle 0 as if it
    // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
    // but rare/unrealistic input value. We can expect both to occur in some
    // incompletely trained models, but probably not in fully trained models.
    *output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
    *output_shift = 0;
    return;
  }
  TFLITE_DCHECK_GT(input, 1);
  *output_shift = 11;
  while (input >= (1 << 29)) {
    input /= 4;
    ++*output_shift;
  }
  const unsigned max_left_shift_bits =
      CountLeadingZeros(static_cast<uint32_t>(input)) - 1;
  const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
  const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
  *output_shift -= left_shift_bit_pairs;
  input <<= 2 * left_shift_bit_pairs;
  TFLITE_DCHECK_GE(input, (1 << 27));
  TFLITE_DCHECK_LT(input, (1 << 29));
  using gemmlowp::FixedPoint;
  using gemmlowp::Rescale;
  using gemmlowp::SaturatingRoundingMultiplyByPOT;
  // Using 3 integer bits gives us enough room for the internal arithmetic in
  // this Newton-Raphson iteration.
  using F3 = FixedPoint<int32_t, 3>;
  using F0 = FixedPoint<int32_t, 0>;
  const F3 fixedpoint_input = F3::FromRaw(input >> 1);
  const F3 fixedpoint_half_input =
      SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
  const F3 fixedpoint_half_three =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
  // Newton-Raphson iteration
  // Naive unoptimized starting guess: x = 1
  F3 x = F3::One();
  // Naive unoptimized number of iterations: 5
  for (int i = 0; i < 5; i++) {
    const F3 x3 = Rescale<3>(x * x * x);
    x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
  }
  const F0 fixedpoint_half_sqrt_2 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
  x = x * fixedpoint_half_sqrt_2;
  *output_inv_sqrt = x.raw();
  if (*output_shift < 0) {
    *output_inv_sqrt <<= -*output_shift;
    *output_shift = 0;
  }
  // Convert right shift (right is positive) to left shift.
  *output_shift *= reverse_shift;
}

// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
// rectangular array of numbers.
//
// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
// However, as Dims<N> is to be deprecated, this class exists as an adaptor
// to enable simple unoptimized implementations of element-wise broadcasting
// operations.
template <int N>
struct NdArrayDesc {
  // The "extent" of each dimension. Indices along dimension d must be in the
  // half-open interval [0, extents[d]).
  int extents[N];

  // The number of *elements* (not bytes) between consecutive indices of each
  // dimension.
  int strides[N];
};

// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// Same as Offset(), except takes an NdArrayDesc<N> instead of Dims<N>.
inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
                            int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
  return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
         i3 * desc.strides[3];
}

inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) {
  return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
         indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
         indexes[4] * desc.strides[4];
}

inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) {
  return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
         indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
         indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] +
         indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7];
}
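
// Example (illustrative, not part of the original header): for a dense
// 1x2x2x3 (NHWC) tensor described by an NdArrayDesc<4> 'desc' with
// extents = {1, 2, 2, 3} and strides = {12, 6, 3, 1},
//
//   SubscriptToIndex(desc, /*i0=*/0, /*i1=*/1, /*i2=*/0, /*i3=*/2);  // == 8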

// Given the dimensions of the operands for an element-wise binary broadcast,
// adjusts them so that they can be directly iterated over with simple loops.
// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
//
// This function assumes that the two input shapes are compatible up to
// broadcasting and the shorter one has already been prepended with 1s to be the
// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
// Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
//
// When two shapes are compatible up to broadcasting, for each dimension d,
// the input extents are either equal, or one of them is 1.
//
// This function performs the following for each dimension d:
// - If the extents are equal, then do nothing since the loop that walks over
//   both of the input arrays is correct.
// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
//   and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
//   array0 to be referenced *at any index* in dimension d and still access the
//   same slice.
template <int N>
inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
                                                const Dims<N>& input1_dims,
                                                NdArrayDesc<N>* desc0_out,
                                                NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  // Copy dims to desc.
  for (int i = 0; i < N; ++i) {
    desc0_out->extents[i] = input0_dims.sizes[i];
    desc0_out->strides[i] = input0_dims.strides[i];
    desc1_out->extents[i] = input1_dims.sizes[i];
    desc1_out->strides[i] = input1_dims.strides[i];
  }

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other
  // and stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = ArraySize(input0_dims, i);
    const int extent1 = ArraySize(input1_dims, i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}

// Copies dims to desc, calculating strides.
template <int N>
inline void CopyDimsToDesc(const RuntimeShape& input_shape,
                           NdArrayDesc<N>* desc_out) {
  int desc_stride = 1;
  for (int i = N - 1; i >= 0; --i) {
    desc_out->extents[i] = input_shape.Dims(i);
    desc_out->strides[i] = desc_stride;
    desc_stride *= input_shape.Dims(i);
  }
}

template <int N>
inline void NdArrayDescsForElementwiseBroadcast(
    const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
    NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
  auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);

  // Copy dims to desc, calculating strides.
  CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
  CopyDimsToDesc<N>(extended_input1_shape, desc1_out);

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other
  // and stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = extended_input0_shape.Dims(i);
    const int extent1 = extended_input1_shape.Dims(i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}
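
// Example (illustrative sketch, not part of the original header): broadcasting
// a {1, 1, 1, 64} bias against a {1, 16, 16, 64} activation. After the call,
// desc1 has extents {1, 16, 16, 64} and strides {64, 0, 0, 1}, so the same 64
// bias values are revisited at every (y, x) position.
//
//   NdArrayDesc<4> desc0, desc1;
//   NdArrayDescsForElementwiseBroadcast<4>(RuntimeShape({1, 16, 16, 64}),
//                                          RuntimeShape({1, 1, 1, 64}),
//                                          &desc0, &desc1);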

template <int N>
inline void NdArrayDescsForElementwiseBroadcast(
    const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
    const RuntimeShape& input2_shape, NdArrayDesc<N>* desc0_out,
    NdArrayDesc<N>* desc1_out, NdArrayDesc<N>* desc2_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);
  TFLITE_DCHECK(desc2_out != nullptr);

  auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
  auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
  auto extended_input2_shape = RuntimeShape::ExtendedShape(N, input2_shape);

  // Copy dims to desc, calculating strides.
  CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
  CopyDimsToDesc<N>(extended_input1_shape, desc1_out);
  CopyDimsToDesc<N>(extended_input2_shape, desc2_out);

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other
  // and stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = extended_input0_shape.Dims(i);
    const int extent1 = extended_input1_shape.Dims(i);
    const int extent2 = extended_input2_shape.Dims(i);

    int extent = extent0;
    if (extent1 != 1) extent = extent1;
    if (extent2 != 1) extent = extent2;

    TFLITE_DCHECK(extent0 == 1 || extent0 == extent);
    TFLITE_DCHECK(extent1 == 1 || extent1 == extent);
    TFLITE_DCHECK(extent2 == 1 || extent2 == extent);

    if (!(extent0 == extent1 && extent1 == extent2)) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent;
      }
      if (extent1 == 1) {
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent;
      }
      if (extent2 == 1) {
        desc2_out->strides[i] = 0;
        desc2_out->extents[i] = extent;
      }
    }
  }
}

// Detailed implementation of NDOpsHelper, the indexes must be a zero array.
// This implementation is equivalent to N nested loops. E.g., if N=4, it can be
// re-written as:
// for (int b = 0; b < output.extents[0]; ++b) {
//   for (int y = 0; y < output.extents[1]; ++y) {
//     for (int x = 0; x < output.extents[2]; ++x) {
//       for (int c = 0; c < output.extents[3]; ++c) {
//         calc({b,y,x,c});
//       }
//     }
//   }
// }
template <int N, int DIM, typename Calc>
typename std::enable_if<DIM != N - 1, void>::type NDOpsHelperImpl(
    const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
  for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
    NDOpsHelperImpl<N, DIM + 1, Calc>(output, calc, indexes);
  }
}

template <int N, int DIM, typename Calc>
typename std::enable_if<DIM == N - 1, void>::type NDOpsHelperImpl(
    const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
  for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
    calc(indexes);
  }
}

// Execute the calc function in the innermost iteration based on the shape of
// the output. The calc function should take a single argument of type int[N].
template <int N, typename Calc>
inline void NDOpsHelper(const NdArrayDesc<N>& output, const Calc& calc) {
  int indexes[N] = {0};
  NDOpsHelperImpl<N, 0, Calc>(output, calc, indexes);
}
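
// Example (illustrative sketch, not part of the original header): the body of
// a hypothetical broadcast Add, built from the helpers above. 'desc0', 'desc1'
// and 'output_desc' are assumed to come from
// NdArrayDescsForElementwiseBroadcast / CopyDimsToDesc, and the *_data
// pointers from the caller.
//
//   auto add_one_element = [&](int* indexes) {
//     output_data[SubscriptToIndex(output_desc, indexes[0], indexes[1],
//                                  indexes[2], indexes[3])] =
//         input0_data[SubscriptToIndex(desc0, indexes[0], indexes[1],
//                                      indexes[2], indexes[3])] +
//         input1_data[SubscriptToIndex(desc1, indexes[0], indexes[1],
//                                      indexes[2], indexes[3])];
//   };
//   NDOpsHelper<4>(output_desc, add_one_element);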

// Copied from gemmlowp::RoundDown when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded down to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundDown(Integer i) {
  return i - (i % Modulus);
}

// Copied from gemmlowp::RoundUp when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded up to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundUp(Integer i) {
  return RoundDown<Modulus>(i + Modulus - 1);
}

// Copied from gemmlowp::CeilQuotient when we dropped direct dependency on
// gemmlowp.
//
// Returns the quotient a / b rounded up ('ceil') to the nearest integer.
template <typename Integer>
Integer CeilQuotient(Integer a, Integer b) {
  return (a + b - 1) / b;
}
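
// Example (illustrative, not part of the original header):
//
//   RoundDown<8>(13);     // == 8
//   RoundUp<8>(13);       // == 16
//   CeilQuotient(13, 8);  // == 2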

// This function is a copy of gemmlowp::HowManyThreads, copied when we dropped
// the direct dependency of internal/optimized/ on gemmlowp.
//
// It computes a reasonable number of threads to use for a GEMM of shape
// (rows, cols, depth).
//
// TODO(b/131910176): get rid of this function by switching each call site
// to its own more sensible logic for its own workload.
template <int KernelRows>
inline int LegacyHowManyThreads(int max_num_threads, int rows, int cols,
                                int depth) {
  // Early-exit in the default case where multi-threading is disabled.
  if (max_num_threads == 1) {
    return 1;
  }

  // Ensure that each thread has KernelRows rows to process, if at all possible.
  int thread_count = std::min(max_num_threads, rows / KernelRows);

  // Limit the number of threads according to the overall size of the problem.
  if (thread_count > 1) {
    // Empirically determined value.
    static constexpr std::uint64_t min_cubic_size_per_thread = 64 * 1024;

    // We can only multiply two out of three sizes without risking overflow
    const std::uint64_t cubic_size =
        std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth);

    thread_count = std::min(
        thread_count, static_cast<int>(cubic_size / min_cubic_size_per_thread));
  }

  if (thread_count < 1) {
    thread_count = 1;
  }

  assert(thread_count > 0 && thread_count <= max_num_threads);
  return thread_count;
}

template <typename T>
void optimized_ops_preload_l1_stream(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 0 means no locality */ 0);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_preload_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_prefetch_write_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 1 means write */ 1, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_