1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
17 
18 #include <algorithm>
19 #ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
20 #ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
21 #define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
22 #endif
23 #endif
24 
25 #include <functional>
26 
27 #include "fixedpoint/fixedpoint.h"
28 #include "tensorflow/lite/kernels/internal/cppmath.h"
29 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
30 #include "tensorflow/lite/kernels/internal/types.h"
31 
32 namespace tflite {
33 
// NOTE(review): kReverseShift is -1; presumably a factor used to flip the
// sign (direction) of a shift amount. No uses are visible in this section of
// the file, so confirm its meaning against callers.
constexpr int kReverseShift = -1;
35 
GetActivationMinMax(FusedActivationFunctionType ac,float * output_activation_min,float * output_activation_max)36 inline void GetActivationMinMax(FusedActivationFunctionType ac,
37                                 float* output_activation_min,
38                                 float* output_activation_max) {
39   switch (ac) {
40     case FusedActivationFunctionType::kNone:
41       *output_activation_min = std::numeric_limits<float>::lowest();
42       *output_activation_max = std::numeric_limits<float>::max();
43       break;
44     case FusedActivationFunctionType::kRelu:
45       *output_activation_min = 0.f;
46       *output_activation_max = std::numeric_limits<float>::max();
47       break;
48     case FusedActivationFunctionType::kRelu1:
49       *output_activation_min = -1.f;
50       *output_activation_max = 1.f;
51       break;
52     case FusedActivationFunctionType::kRelu6:
53       *output_activation_min = 0.f;
54       *output_activation_max = 6.f;
55       break;
56   }
57 }
58 
template <typename T>
inline T ActivationFunctionWithMinMax(T x, T output_activation_min,
                                      T output_activation_max) {
  // Clamp x into [output_activation_min, output_activation_max].
  // - First raise x to the lower bound, then cap it at the upper bound.
  // - If the bounds are inverted, the upper bound wins.
  // - The using-declarations keep the calls unqualified, so overloads found
  //   by argument-dependent lookup still take part.
  using std::max;
  using std::min;
  const auto& lower_bounded = max(x, output_activation_min);
  return min(lower_bounded, output_activation_max);
}
66 
67 // Legacy function, left for compatibility only.
68 template <FusedActivationFunctionType Ac>
ActivationFunction(float x)69 float ActivationFunction(float x) {
70   float output_activation_min, output_activation_max;
71   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
72   return ActivationFunctionWithMinMax(x, output_activation_min,
73                                       output_activation_max);
74 }
75 
BiasAndClamp(float clamp_min,float clamp_max,int bias_size,const float * bias_data,int array_size,float * array_data)76 inline void BiasAndClamp(float clamp_min, float clamp_max, int bias_size,
77                          const float* bias_data, int array_size,
78                          float* array_data) {
79   if (bias_size == 0) return;
80   // Note: see b/132215220: in May 2019 we thought it would be OK to replace
81   // this with the Eigen one-liner:
82   //   return (array.colwise() + bias).cwiseMin(clamp_max).cwiseMin(clamp_max).
83   // This turned out to severely regress performance: +4ms (i.e. 8%) on
84   // MobileNet v2 / 1.0 / 224. So we keep custom NEON code for now.
85   TFLITE_DCHECK_EQ((array_size % bias_size), 0);
86 #ifdef USE_NEON
87   float* array_ptr = array_data;
88   float* array_end_ptr = array_ptr + array_size;
89   const auto clamp_min_vec = vdupq_n_f32(clamp_min);
90   const auto clamp_max_vec = vdupq_n_f32(clamp_max);
91   for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
92     int i = 0;
93     for (; i <= bias_size - 16; i += 16) {
94       auto b0 = vld1q_f32(bias_data + i);
95       auto b1 = vld1q_f32(bias_data + i + 4);
96       auto b2 = vld1q_f32(bias_data + i + 8);
97       auto b3 = vld1q_f32(bias_data + i + 12);
98       auto a0 = vld1q_f32(array_ptr + i);
99       auto a1 = vld1q_f32(array_ptr + i + 4);
100       auto a2 = vld1q_f32(array_ptr + i + 8);
101       auto a3 = vld1q_f32(array_ptr + i + 12);
102       auto x0 = vaddq_f32(a0, b0);
103       auto x1 = vaddq_f32(a1, b1);
104       auto x2 = vaddq_f32(a2, b2);
105       auto x3 = vaddq_f32(a3, b3);
106       x0 = vmaxq_f32(clamp_min_vec, x0);
107       x1 = vmaxq_f32(clamp_min_vec, x1);
108       x2 = vmaxq_f32(clamp_min_vec, x2);
109       x3 = vmaxq_f32(clamp_min_vec, x3);
110       x0 = vminq_f32(clamp_max_vec, x0);
111       x1 = vminq_f32(clamp_max_vec, x1);
112       x2 = vminq_f32(clamp_max_vec, x2);
113       x3 = vminq_f32(clamp_max_vec, x3);
114       vst1q_f32(array_ptr + i, x0);
115       vst1q_f32(array_ptr + i + 4, x1);
116       vst1q_f32(array_ptr + i + 8, x2);
117       vst1q_f32(array_ptr + i + 12, x3);
118     }
119     for (; i <= bias_size - 4; i += 4) {
120       auto b = vld1q_f32(bias_data + i);
121       auto a = vld1q_f32(array_ptr + i);
122       auto x = vaddq_f32(a, b);
123       x = vmaxq_f32(clamp_min_vec, x);
124       x = vminq_f32(clamp_max_vec, x);
125       vst1q_f32(array_ptr + i, x);
126     }
127     for (; i < bias_size; i++) {
128       array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
129                                                   clamp_min, clamp_max);
130     }
131   }
132 #else  // not NEON
133   for (int array_offset = 0; array_offset < array_size;
134        array_offset += bias_size) {
135     for (int i = 0; i < bias_size; i++) {
136       array_data[array_offset + i] = ActivationFunctionWithMinMax(
137           array_data[array_offset + i] + bias_data[i], clamp_min, clamp_max);
138     }
139   }
140 #endif
141 }
142 
143 // Single-rounding MultiplyByQuantizedMultiplier
144 #if TFLITE_SINGLE_ROUNDING
MultiplyByQuantizedMultiplier(int32_t x,int32_t quantized_multiplier,int shift)145 inline int32_t MultiplyByQuantizedMultiplier(int32_t x,
146                                              int32_t quantized_multiplier,
147                                              int shift) {
148   TFLITE_DCHECK(quantized_multiplier >= 0);
149   TFLITE_DCHECK(shift >= -31 && shift <= 30);
150 
151   const int64_t total_shift = 31 - shift;
152   const int64_t round = static_cast<int64_t>(1) << (total_shift - 1);
153   int64_t result = x * static_cast<int64_t>(quantized_multiplier) + round;
154   result = result >> total_shift;
155 
156   TFLITE_DCHECK(result >= std::numeric_limits<int32_t>::min() &&
157                 result <= std::numeric_limits<int32_t>::max());
158   return static_cast<int32_t>(result);
159 }
160 
MultiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x,int32_t quantized_multiplier,int shift)161 inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(
162     int32_t x, int32_t quantized_multiplier, int shift) {
163   TFLITE_DCHECK_LE(shift, 0);
164   return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift);
165 }
166 
MultiplyByQuantizedMultiplierGreaterThanOne(int32_t x,int32_t quantized_multiplier,int shift)167 inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(
168     int32_t x, int32_t quantized_multiplier, int shift) {
169   TFLITE_DCHECK_GE(shift, 0);
170   return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift);
171 }
172 
MultiplyByQuantizedMultiplier(int64_t x,int32_t quantized_multiplier,int shift)173 inline int32_t MultiplyByQuantizedMultiplier(int64_t x,
174                                              int32_t quantized_multiplier,
175                                              int shift) {
176   // Inputs:
177   // - quantized_multiplier has fixed point at bit 31
178   // - shift is -31 to +7 (negative for right shift)
179   //
180   // Assumptions: The following input ranges are assumed
181   // - quantize_scale>=0  (the usual range is (1<<30) to (1>>31)-1)
182   // - scaling is chosen so final scaled result fits in int32_t
183   // - input x is in the range -(1<<47) <= x < (1<<47)
184   TFLITE_DCHECK(quantized_multiplier >= 0);
185   TFLITE_DCHECK(shift >= -31 && shift < 8);
186   TFLITE_DCHECK(x >= -(static_cast<int64_t>(1) << 47) &&
187                 x < (static_cast<int64_t>(1) << 47));
188 
189   const int32_t reduced_multiplier =
190       (quantized_multiplier < 0x7FFF0000)
191           ? ((quantized_multiplier + (1 << 15)) >> 16)
192           : 0x7FFF;
193   const int64_t total_shift = 15 - shift;
194   const int64_t round = static_cast<int64_t>(1) << (total_shift - 1);
195   int64_t result = x * static_cast<int64_t>(reduced_multiplier) + round;
196   result = result >> total_shift;
197 
198   TFLITE_DCHECK(result >= std::numeric_limits<int32_t>::min() &&
199                 result <= std::numeric_limits<int32_t>::max());
200   return static_cast<int32_t>(result);
201 }
202 
203 #ifdef USE_NEON
MultiplyByQuantizedMultiplier4Rows(int32x4x4_t input_val,int32_t quantized_multiplier,int shift)204 inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
205     int32x4x4_t input_val, int32_t quantized_multiplier, int shift) {
206   TFLITE_DCHECK(quantized_multiplier >= 0);
207 
208   const int right_shift = std::min(-1, shift);
209   const int left_shift = shift - right_shift;
210 
211   const int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
212   const int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
213   const int32x4_t right_shift_dup = vdupq_n_s32(right_shift);
214 
215   int32x4x4_t result;
216   result.val[0] = vrshlq_s32(
217       vqdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup), multiplier_dup),
218       right_shift_dup);
219 
220   result.val[1] = vrshlq_s32(
221       vqdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup), multiplier_dup),
222       right_shift_dup);
223 
224   result.val[2] = vrshlq_s32(
225       vqdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup), multiplier_dup),
226       right_shift_dup);
227 
228   result.val[3] = vrshlq_s32(
229       vqdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup), multiplier_dup),
230       right_shift_dup);
231 
232   return result;
233 }
234 #endif  // USE_NEON
235 // Double-rounding MultiplyByQuantizedMultiplier
236 #else
MultiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x,int32_t quantized_multiplier,int left_shift)237 inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(
238     int32_t x, int32_t quantized_multiplier, int left_shift) {
239   using gemmlowp::RoundingDivideByPOT;
240   using gemmlowp::SaturatingRoundingDoublingHighMul;
241   return RoundingDivideByPOT(
242       SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift);
243 }
244 
MultiplyByQuantizedMultiplierGreaterThanOne(int32_t x,int32_t quantized_multiplier,int left_shift)245 inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(
246     int32_t x, int32_t quantized_multiplier, int left_shift) {
247   using gemmlowp::SaturatingRoundingDoublingHighMul;
248   return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
249                                            quantized_multiplier);
250 }
251 
MultiplyByQuantizedMultiplier(int32_t x,int32_t quantized_multiplier,int shift)252 inline int32_t MultiplyByQuantizedMultiplier(int32_t x,
253                                              int32_t quantized_multiplier,
254                                              int shift) {
255   using gemmlowp::RoundingDivideByPOT;
256   using gemmlowp::SaturatingRoundingDoublingHighMul;
257   int left_shift = shift > 0 ? shift : 0;
258   int right_shift = shift > 0 ? 0 : -shift;
259   return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
260                                  x * (1 << left_shift), quantized_multiplier),
261                              right_shift);
262 }
263 
inline int32_t MultiplyByQuantizedMultiplier(int64_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  // Inputs:
  // - quantized_multiplier has fixed point at bit 31
  // - shift is -31 to +7 (negative for right shift)
  //
  // Assumptions: The following input ranges are assumed
  // - quantized_multiplier >= 0 (the usual range is (1<<30) to (1<<31)-1)
  // - scaling is chosen so final scaled result fits in int32_t
  // - input x is in the range -(1<<47) <= x < (1<<47)
  assert(quantized_multiplier >= 0);
  assert(shift >= -31 && shift < 8);
  assert(x >= -(static_cast<int64_t>(1) << 47) &&
         x < (static_cast<int64_t>(1) << 47));

  // Round the Q31 multiplier down to 16 bits (at most 0x7FFF), so that the
  // product with |x| < 2^47 stays well inside int64_t.
  // Near the top of the range, adding the rounding offset would round up to
  // 0x8000 or overflow int32_t, so those multipliers saturate to 0x7FFF.
  const int32_t reduced_multiplier =
      (quantized_multiplier < 0x7FFF0000)
          ? ((quantized_multiplier + (1 << 15)) >> 16)
          : 0x7FFF;
  // The reduced multiplier has 15 fractional bits; apply `shift` on top and
  // round once (add half, then arithmetic-shift right).
  const int total_shift = 15 - shift;
  const int64_t round = static_cast<int64_t>(1) << (total_shift - 1);
  const int64_t result =
      (x * static_cast<int64_t>(reduced_multiplier) + round) >> total_shift;

  // Same post-condition as the single-rounding variant: the caller's scaling
  // must keep the result within int32_t.
  assert(result >= std::numeric_limits<int32_t>::min() &&
         result <= std::numeric_limits<int32_t>::max());
  return static_cast<int32_t>(result);
}
288 
289 #ifdef USE_NEON
290 // Round uses ARM's rounding shift right.
MultiplyByQuantizedMultiplier4Rows(int32x4x4_t input_val,int32_t quantized_multiplier,int shift)291 inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
292     int32x4x4_t input_val, int32_t quantized_multiplier, int shift) {
293   const int left_shift = std::max(shift, 0);
294   const int right_shift = std::min(shift, 0);
295   int32x4x4_t result;
296 
297   int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
298   int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
299   int32x4_t right_shift_dup = vdupq_n_s32(right_shift);
300 
301   result.val[0] =
302       vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup),
303                                multiplier_dup),
304                  right_shift_dup);
305 
306   result.val[1] =
307       vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup),
308                                multiplier_dup),
309                  right_shift_dup);
310 
311   result.val[2] =
312       vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup),
313                                multiplier_dup),
314                  right_shift_dup);
315 
316   result.val[3] =
317       vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup),
318                                multiplier_dup),
319                  right_shift_dup);
320 
321   return result;
322 }
323 #endif  // USE_NEON
324 #endif  // TFLITE_SINGLE_ROUNDING
325 
// Returns the number of leading zero bits of `integer_input`, counted relative
// to the width of T (numeric_limits<T>::digits). Zero yields that full width.
template <typename T>
int CountLeadingZeros(T integer_input) {
  static_assert(std::is_unsigned<T>::value,
                "Only unsigned integer types handled.");
  // Every bit of zero is a leading zero. Handling it here also keeps the
  // builtins below away from their undefined zero-input case.
  if (integer_input == 0) {
    return std::numeric_limits<T>::digits;
  }
#if defined(__GNUC__)
  // __builtin_clz works on unsigned int and __builtin_clzll on unsigned long
  // long. Passing a wider T to __builtin_clz would truncate it, so pick the
  // builtin wide enough for T. Then subtract the zero bits that widening T
  // added, so the count is relative to T rather than to the builtin's type.
  if (sizeof(T) <= sizeof(unsigned int)) {
    return __builtin_clz(static_cast<unsigned int>(integer_input)) -
           (std::numeric_limits<unsigned int>::digits -
            std::numeric_limits<T>::digits);
  }
  if (sizeof(T) <= sizeof(unsigned long long)) {
    return __builtin_clzll(static_cast<unsigned long long>(integer_input)) -
           (std::numeric_limits<unsigned long long>::digits -
            std::numeric_limits<T>::digits);
  }
#endif
  // Portable fallback, also reached for types wider than unsigned long long:
  // shift left until the most significant bit is set.
  const T one_in_leading_positive = static_cast<T>(1)
                                    << (std::numeric_limits<T>::digits - 1);
  int leading_zeros = 0;
  while (integer_input < one_in_leading_positive) {
    integer_input <<= 1;
    ++leading_zeros;
  }
  return leading_zeros;
}
348 
// Returns the number of bits after the sign bit that equal the sign bit,
// counted relative to the width of T. Zero and -1 both yield
// numeric_limits<T>::digits; the most negative value yields 0.
template <typename T>
inline int CountLeadingSignBits(T integer_input) {
  static_assert(std::is_signed<T>::value, "Only signed integer types handled.");
#if defined(__GNUC__) && !defined(__clang__)
  // __builtin_clrsb works on int and __builtin_clrsbll on long long. Passing
  // a wider T to __builtin_clrsb would truncate it, so pick the builtin wide
  // enough for T. Widening sign-extends the value, so subtracting the extra
  // width gives the count relative to T.
  if (sizeof(T) <= sizeof(int)) {
    return __builtin_clrsb(static_cast<int>(integer_input)) -
           (std::numeric_limits<int>::digits - std::numeric_limits<T>::digits);
  }
  if (sizeof(T) <= sizeof(long long)) {
    return __builtin_clrsbll(static_cast<long long>(integer_input)) -
           (std::numeric_limits<long long>::digits -
            std::numeric_limits<T>::digits);
  }
#endif
  // Portable path. For negative x, the leading sign bits of x are the leading
  // zeros of ~x. For x >= 0 they are the leading zeros of x. In both cases
  // the zero sign bit itself is included and is subtracted. Casting back to U
  // also undoes the int promotion of narrow types, so CountLeadingZeros sees
  // an unsigned type.
  using U = typename std::make_unsigned<T>::type;
  const U sign_folded =
      static_cast<U>(integer_input < 0 ? ~integer_input : integer_input);
  return CountLeadingZeros(sign_folded) - 1;
}
364 
// Fast floor(log2(n)) for a positive 32- or 64-bit signed integer, built on
// the leading-sign-bit count.
template <typename Integer>
inline Integer FloorLog2(Integer n) {
  static_assert(std::is_integral<Integer>::value, "");
  static_assert(std::is_signed<Integer>::value, "");
  static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, "");
  TFLITE_CHECK_GT(n, 0);
  // For positive n, the highest set bit sits directly below the run of
  // leading sign bits. Counting down from the top value bit (30 or 62) gives
  // its index.
  const int top_value_bit = (sizeof(Integer) == 4) ? 30 : 62;
  return top_value_bit - CountLeadingSignBits(n);
}
378 
379 // The size of the LUT depends on the type of input. For int8 inputs a simple
380 // 256 entries LUT is used. For int16 inputs the high 9 bits are used for
381 // indexing and the 7 remaining bits are used for interpolation. We thus use a
382 // 513-entries LUT for int16 cases, 512 for the 9-bit indexing and 1 extra entry
383 // to interpolate the last value.
template <typename LutInT>
constexpr int lut_size() {
  static_assert(std::is_same<LutInT, int8_t>::value ||
                    std::is_same<LutInT, int16_t>::value,
                "Only LUTs with int8 or int16 inputs are supported.");
  // int16 tables carry one extra end-point entry for interpolation.
  return std::is_same<LutInT, int16_t>::value ? 513 : 256;
}
391 
392 // Generate a LUT for 'func' which can be used to approximate functions like
393 // exp, log, ...
394 //
// - func: the function to build the LUT for (e.g. exp(x))
396 // - input_min, input_max: range of the func inputs
397 // - output_min, output_max: range of the func outputs
398 // - lut: pointer to the LUT table to fill, the table must be of size
399 // lut_size<LutInT>()
template <typename FloatT, typename LutInT, typename LutOutT>
inline void gen_lut(FloatT (*func)(FloatT), FloatT input_min, FloatT input_max,
                    FloatT output_min, FloatT output_max, LutOutT* lut) {
  static_assert(std::is_same<LutInT, int8_t>::value ||
                    std::is_same<LutInT, int16_t>::value,
                "Only LUTs with int8 or int16 inputs are supported.");
  static_assert(std::is_same<LutOutT, int8_t>::value ||
                    std::is_same<LutOutT, int16_t>::value,
                "Only LUTs with int8 or int16 outputs are supported.");
  static_assert(std::is_floating_point<FloatT>::value,
                "FloatT must be a floating-point type.");

  // int16 inputs use 512 segments plus one trailing end-point entry for
  // interpolation; int8 inputs use 256 plain samples.
  const bool has_interpolation_endpoint = std::is_same<LutInT, int16_t>::value;
  const int nb_steps = has_interpolation_endpoint ? 512 : 256;
  const FloatT step = (input_max - input_min) / nb_steps;
  const FloatT half_step = step / 2;
  // Scales a function value onto the full LutOutT code range.
  const FloatT output_scaling_inv =
      static_cast<FloatT>(std::numeric_limits<LutOutT>::max() -
                          std::numeric_limits<LutOutT>::min() + 1) /
      (output_max - output_min);
  const FloatT table_min =
      static_cast<FloatT>(std::numeric_limits<LutOutT>::min());
  const FloatT table_max =
      static_cast<FloatT>(std::numeric_limits<LutOutT>::max());

  // Clamps a rounded sample into the representable LutOutT range.
  const auto saturate = [table_min, table_max](FloatT v) -> LutOutT {
    return static_cast<LutOutT>(
        std::min<FloatT>(std::max<FloatT>(v, table_min), table_max));
  };

  for (int i = 0; i < nb_steps; ++i) {
    const FloatT val = func(input_min + i * step);
    const FloatT val_midpoint = func(input_min + i * step + half_step);
    const FloatT val_next = func(input_min + (i + 1) * step);

    const FloatT sample_val = TfLiteRound(val * output_scaling_inv);
    // Compare two values at the segment midpoint:
    // - what linear interpolation between this sample and the next gives;
    // - the true (scaled, rounded) function value.
    // Then bias the stored sample by half of that error, rounded.
    const FloatT midpoint_interp_val =
        TfLiteRound((val_next * output_scaling_inv + sample_val) / 2);
    const FloatT midpoint_val = TfLiteRound(val_midpoint * output_scaling_inv);
    const FloatT midpoint_err = midpoint_interp_val - midpoint_val;
    const FloatT bias = TfLiteRound(midpoint_err / 2);

    lut[i] = saturate(sample_val - bias);
  }

  if (has_interpolation_endpoint) {
    lut[nb_steps] = saturate(TfLiteRound(func(input_max) * output_scaling_inv));
  }
}
451 
// Interpolating lookup for int16_t inputs. The LUT must have 513 values:
// - 512 segment base values;
// - one trailing entry, lut[512], used only to compute the slope of the last
//   segment.
template <typename LutOutT>
inline LutOutT lut_lookup_with_interpolation(int16_t value,
                                             const LutOutT* lut) {
  static_assert(std::is_same<LutOutT, int8_t>::value ||
                    std::is_same<LutOutT, int16_t>::value,
                "Only LUTs with int8 or int16 outputs are supported.");
  // The top 9 bits of the input (biased by 256) select the segment.
  const uint16_t index = static_cast<uint16_t>(256 + (value >> 7));
  assert(index < 512 && "LUT index out of range.");
  // The low 7 bits give the position within the segment (Q0.7).
  const int16_t offset = value & 0x7f;

  // Base and slope are Q0.x. The slope is kept in int: the difference of two
  // LutOutT entries need not fit in LutOutT (e.g. 127 - (-128) for int8).
  // Narrowing it would wrap and reverse the interpolation direction.
  const LutOutT base = lut[index];
  const int slope = lut[index + 1] - lut[index];

  // Q0.x * Q0.7 = Q0.(x + 7)
  // Round and convert from Q0.(x + 7) to Q0.x
  const int delta = (slope * offset + 64) >> 7;

  // Q0.x + Q0.x
  return static_cast<LutOutT>(base + delta);
}
475 
476 // int16_t -> int16_t table lookup with interpolation
477 // LUT must have 513 values
lut_lookup(int16_t value,const int16_t * lut)478 inline int16_t lut_lookup(int16_t value, const int16_t* lut) {
479   return lut_lookup_with_interpolation(value, lut);
480 }
481 
482 // int16_t -> int8_t table lookup with interpolation
483 // LUT must have 513 values
lut_lookup(int16_t value,const int8_t * lut)484 inline int8_t lut_lookup(int16_t value, const int8_t* lut) {
485   return lut_lookup_with_interpolation(value, lut);
486 }
487 
488 // int8_t -> int8_t table lookup without interpolation
489 // LUT must have 256 values
inline int8_t lut_lookup(int8_t value, const int8_t* lut) {
  // Shift the signed input from [-128, 127] to a table index in [0, 255].
  const int index = static_cast<int>(value) + 128;
  return lut[index];
}
493 
494 // int8_t -> int16_t table lookup without interpolation
495 // LUT must have 256 values
inline int16_t lut_lookup(int8_t value, const int16_t* lut) {
  // Shift the signed input from [-128, 127] to a table index in [0, 255].
  const int index = static_cast<int>(value) + 128;
  return lut[index];
}
499 
// Table of sigmoid(i / 24) for i = 0..255, stored as unsigned Q0.16 values
// (approximately sigmoid(i / 24) * 65536; entry 0 is 0.5 -> 32768).

// We use a combined sigmoid and tanh look-up table, since
// tanh(x) = 2 * sigmoid(2 * x) - 1.
// Because sigmoid(-x) = 1 - sigmoid(x) and tanh(-x) = -tanh(x), the table only
// needs to cover the absolute value of the input.
//
// NOTE: `static` at namespace scope in a header gives every translation unit
// that includes this header its own copy of the table.
static const uint16_t sigmoid_table_uint16[256] = {
    32768, 33451, 34133, 34813, 35493, 36169, 36843, 37513, 38180, 38841, 39498,
    40149, 40794, 41432, 42064, 42688, 43304, 43912, 44511, 45102, 45683, 46255,
    46817, 47369, 47911, 48443, 48964, 49475, 49975, 50464, 50942, 51409, 51865,
    52311, 52745, 53169, 53581, 53983, 54374, 54755, 55125, 55485, 55834, 56174,
    56503, 56823, 57133, 57433, 57724, 58007, 58280, 58544, 58800, 59048, 59288,
    59519, 59743, 59959, 60168, 60370, 60565, 60753, 60935, 61110, 61279, 61441,
    61599, 61750, 61896, 62036, 62172, 62302, 62428, 62549, 62666, 62778, 62886,
    62990, 63090, 63186, 63279, 63368, 63454, 63536, 63615, 63691, 63765, 63835,
    63903, 63968, 64030, 64090, 64148, 64204, 64257, 64308, 64357, 64405, 64450,
    64494, 64536, 64576, 64614, 64652, 64687, 64721, 64754, 64786, 64816, 64845,
    64873, 64900, 64926, 64950, 64974, 64997, 65019, 65039, 65060, 65079, 65097,
    65115, 65132, 65149, 65164, 65179, 65194, 65208, 65221, 65234, 65246, 65258,
    65269, 65280, 65291, 65301, 65310, 65319, 65328, 65337, 65345, 65352, 65360,
    65367, 65374, 65381, 65387, 65393, 65399, 65404, 65410, 65415, 65420, 65425,
    65429, 65433, 65438, 65442, 65445, 65449, 65453, 65456, 65459, 65462, 65465,
    65468, 65471, 65474, 65476, 65479, 65481, 65483, 65485, 65488, 65489, 65491,
    65493, 65495, 65497, 65498, 65500, 65501, 65503, 65504, 65505, 65507, 65508,
    65509, 65510, 65511, 65512, 65513, 65514, 65515, 65516, 65517, 65517, 65518,
    65519, 65520, 65520, 65521, 65522, 65522, 65523, 65523, 65524, 65524, 65525,
    65525, 65526, 65526, 65526, 65527, 65527, 65528, 65528, 65528, 65529, 65529,
    65529, 65529, 65530, 65530, 65530, 65530, 65531, 65531, 65531, 65531, 65531,
    65532, 65532, 65532, 65532, 65532, 65532, 65533, 65533, 65533, 65533, 65533,
    65533, 65533, 65533, 65534, 65534, 65534, 65534, 65534, 65534, 65534, 65534,
    65534, 65534, 65535};
531 
// TODO(b/77858996): Add these to gemmlowp.
// Saturating addition for raw integer types. The primary template fails to
// compile for any type it is instantiated with; only the explicit
// specializations below are usable.
template <typename IntegerType>
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// int32_t: the sum is formed exactly in 64 bits, then clamped back into the
// int32_t range.
template <>
inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
  const std::int64_t wide_sum =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  const std::int64_t upper = std::numeric_limits<std::int32_t>::max();
  const std::int64_t lower = std::numeric_limits<std::int32_t>::min();
  if (wide_sum > upper) return std::numeric_limits<std::int32_t>::max();
  if (wide_sum < lower) return std::numeric_limits<std::int32_t>::min();
  return static_cast<std::int32_t>(wide_sum);
}
550 
551 template <typename tRawType, int tIntegerBits>
SaturatingAddNonGemmlowp(gemmlowp::FixedPoint<tRawType,tIntegerBits> a,gemmlowp::FixedPoint<tRawType,tIntegerBits> b)552 gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
553     gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
554     gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
555   return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
556       SaturatingAddNonGemmlowp(a.raw(), b.raw()));
557 }
558 
// Saturating subtraction for raw integer types. The primary template fails to
// compile for any type it is instantiated with; only the explicit
// specializations below are usable.
template <typename IntegerType>
IntegerType SaturatingSub(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// int16_t: the difference is exact in int32_t, then clamped to int16_t.
template <>
inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
  const std::int32_t wide_diff =
      static_cast<std::int32_t>(a) - static_cast<std::int32_t>(b);
  const std::int32_t upper = std::numeric_limits<std::int16_t>::max();
  const std::int32_t lower = std::numeric_limits<std::int16_t>::min();
  if (wide_diff > upper) return static_cast<std::int16_t>(upper);
  if (wide_diff < lower) return static_cast<std::int16_t>(lower);
  return static_cast<std::int16_t>(wide_diff);
}

// int32_t: same approach with an int64_t intermediate.
template <>
inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
  const std::int64_t wide_diff =
      static_cast<std::int64_t>(a) - static_cast<std::int64_t>(b);
  const std::int64_t upper = std::numeric_limits<std::int32_t>::max();
  const std::int64_t lower = std::numeric_limits<std::int32_t>::min();
  if (wide_diff > upper) return std::numeric_limits<std::int32_t>::max();
  if (wide_diff < lower) return std::numeric_limits<std::int32_t>::min();
  return static_cast<std::int32_t>(wide_diff);
}
586 
587 template <typename tRawType, int tIntegerBits>
SaturatingSub(gemmlowp::FixedPoint<tRawType,tIntegerBits> a,gemmlowp::FixedPoint<tRawType,tIntegerBits> b)588 gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
589     gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
590     gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
591   return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
592       SaturatingSub(a.raw(), b.raw()));
593 }
594 // End section to be moved to gemmlowp.
595 
596 template <typename IntegerType>
SaturatingRoundingMultiplyByPOTParam(IntegerType x,int exponent)597 IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
598   if (exponent == 0) {
599     return x;
600   }
601   using ScalarIntegerType =
602       typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
603   const IntegerType min =
604       gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
605   const IntegerType max =
606       gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
607   const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
608 
609   const std::int32_t threshold =
610       ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
611   const IntegerType positive_mask =
612       gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
613   const IntegerType negative_mask =
614       gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
615 
616   IntegerType result = gemmlowp::ShiftLeft(x, exponent);
617   result = gemmlowp::SelectUsingMask(positive_mask, max, result);
618   result = gemmlowp::SelectUsingMask(negative_mask, min, result);
619   return result;
620 }
621 
622 // If we want to leave IntegerBits fixed, then multiplication
623 // by a power of two has to be saturating/rounding, not exact anymore.
624 template <typename tRawType, int tIntegerBits>
625 gemmlowp::FixedPoint<tRawType, tIntegerBits>
SaturatingRoundingMultiplyByPOTParam(gemmlowp::FixedPoint<tRawType,tIntegerBits> a,int exponent)626 SaturatingRoundingMultiplyByPOTParam(
627     gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
628   return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
629       SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
630 }
631 
632 // Convert int32_t multiplier to int16_t with rounding.
DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32_t,int16_t * multiplier_int16_t)633 inline void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32_t,
634                                             int16_t* multiplier_int16_t) {
635   TFLITE_DCHECK_GE(multiplier_int32_t, 0);
636   static constexpr int32_t kRoundingOffset = 1 << 15;
637   if (multiplier_int32_t >=
638       std::numeric_limits<int32_t>::max() - kRoundingOffset) {
639     *multiplier_int16_t = std::numeric_limits<int16_t>::max();
640     return;
641   }
642   const int32_t result = (multiplier_int32_t + kRoundingOffset) >> 16;
643   TFLITE_DCHECK_LE(result << 16, multiplier_int32_t + kRoundingOffset);
644   TFLITE_DCHECK_GT(result << 16, multiplier_int32_t - kRoundingOffset);
645   *multiplier_int16_t = result;
646   TFLITE_DCHECK_EQ(*multiplier_int16_t, result);
647 }
648 
649 // Minimum output bits to accommodate log of maximum input range.  It actually
650 // does not matter if one considers, say, [-64,64] or [-64,64).
651 //
652 // For example, run this through Octave:
653 // [0:127; ...
654 //  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
655 //  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
constexpr int min_log_x_output_bits(int input_bits) {
  // Thresholds checked from the smallest input width upward.
  return input_bits <= 1    ? 1
         : input_bits <= 4  ? 2
         : input_bits <= 10 ? 3
         : input_bits <= 21 ? 4
         : input_bits <= 44 ? 5
         : input_bits <= 90 ? 6
                            : 7;
}
665 
// Fixed-point natural logarithm of `input_val`, returned with
// OutputIntegerBits integer bits. Outline of the computation below:
//   1. Reinterpret the raw input as a FixedPoint0 (Q0.31) value z_a and use
//      leading-zero counts to split it into a normalized mantissa and a
//      power-of-two exponent. This is done for two candidates (based on z_a
//      and on z_b = z_a * sqrt(0.5)); the smaller mantissa `r` and the larger
//      adjusted exponent `z_pow_2_adj` are kept.
//   2. From `r`, form p = (r + sqrt(sqrt(0.5))) / 2 and
//      q = 2 * (r - sqrt(sqrt(0.5))) and evaluate the rational term
//      num / (1 + denom_minus_one_0) with the alpha_* coefficients.
//   3. Return z_pow_2_adj * log(2) plus that rational term, accumulated with
//      one extra integer bit and then rescaled to OutputIntegerBits.
// Callers normally go through log_x_for_x_greater_than_or_equal_to_1, which
// statically checks that OutputIntegerBits is large enough.
666 // Although currently the name of this function says that it cannot handle
667 // values less than 1, in practice it can handle as low as 1/x_max, where
668 // x_max is the largest representable input.  In other words, the output range
669 // is symmetric.
670 template <int OutputIntegerBits, int InputIntegerBits>
671 inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(gemmlowp::FixedPoint<int32_t,InputIntegerBits> input_val)672 log_x_for_x_greater_than_or_equal_to_1_impl(
673     gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
674   // assert(__builtin_clz(0u) >= std::numeric_limits<uint32_t>::digits - 1);
675   // assert(__builtin_clz(0u) <= std::numeric_limits<uint32_t>::digits);
676   using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
677   // The reason for accumulating the result with an extra bit of headroom is
678   // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
679   // recip_denom will otherwise introduce an error.
680   static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
681   using FixedPointAccum = gemmlowp::FixedPoint<int32_t, kAccumIntegerBits>;
682
  // Q0.31 constants: each raw value is paired with the real value it encodes
  // (e.g. 1488522236 / 2^31 ~= log(2)).
683   const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
684       FixedPoint0, 1488522236, std::log(2.0));
685   const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
686       FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
687   const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
688       FixedPoint0, 1518500250, std::sqrt(0.5));
689   const FixedPoint0 one_quarter =
690       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
691
  // Coefficients of the rational term evaluated from p and q further below.
692   const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
693       FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
694   const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
695       FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
696   const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
697       FixedPoint0, 1057819769,
698       2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
699   const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
700       FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
701
  // 0.25 in the accumulator format; used to offset the exponent candidates.
702   const FixedPointAccum shifted_quarter =
703       gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
704
705   // Reinterpret the input value as Q0.31, because we will figure out the
706   // required shift "ourselves" instead of using, say, Rescale.
707   FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
708   // z_a_pow_2 = input_integer_bits - z_a_headroom;
709   int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32_t>(z_a.raw()));
710   FixedPoint0 r_a_tmp =
711       SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
712   const int32_t r_a_raw =
713       SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
714   // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
715   // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
716   //                   InputIntegerBits - z_b_headroom - 0.25);
  // (The two lines above are the same formula, with
  // z_pow_2_a = InputIntegerBits - (z_a_headroom_plus_1 - 1) and
  // z_pow_2_b = InputIntegerBits - z_b_headroom.)
  // Shifting the integer exponent left by 31 - kAccumIntegerBits turns it into
  // a FixedPointAccum raw value.
717   const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
718       FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
719           static_cast<int32_t>(InputIntegerBits - z_a_headroom_plus_1),
720           31 - kAccumIntegerBits)),
721       shifted_quarter);
722
723   // z_b is treated like z_a, but premultiplying by sqrt(0.5).
724   FixedPoint0 z_b = z_a * sqrt_half;
725   int z_b_headroom = CountLeadingZeros(static_cast<uint32_t>(z_b.raw())) - 1;
  // Note: the shift derived from z_b's headroom is applied to z_a itself.
726   const int32_t r_b_raw =
727       SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
728   const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
729       FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
730           static_cast<int32_t>(InputIntegerBits - z_b_headroom),
731           31 - kAccumIntegerBits)),
732       shifted_quarter);
733
  // Combine the two candidates: smaller mantissa, larger adjusted exponent.
734   const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
735   const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
736       std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
737
  // p = (r + sqrt(sqrt(0.5))) / 2 and q = 2 * (r - sqrt(sqrt(0.5))).
738   const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
739   FixedPoint0 q = r - sqrt_sqrt_half;
740   q = q + q;
741
742   const FixedPoint0 common_sq = q * q;
743   const FixedPoint0 num = q * r + q * common_sq * alpha_n;
744   const FixedPoint0 denom_minus_one_0 =
745       p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
746   const FixedPoint0 recip_denom =
747       one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
748
  // Sum in the accumulator format (one extra integer bit, see
  // kAccumIntegerBits), then rescale to the requested output format.
749   const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
750   return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
751                                               num_scaled * recip_denom);
752 }
753 
754 template <int OutputIntegerBits, int InputIntegerBits>
755 inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1(gemmlowp::FixedPoint<int32_t,InputIntegerBits> input_val)756 log_x_for_x_greater_than_or_equal_to_1(
757     gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
758   static_assert(
759       OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
760       "Output integer bits must be sufficient to accommodate logs of inputs.");
761   return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
762                                                      InputIntegerBits>(
763       input_val);
764 }
765 
GetReciprocal(int32_t x,int x_integer_digits,int * num_bits_over_unit)766 inline int32_t GetReciprocal(int32_t x, int x_integer_digits,
767                              int* num_bits_over_unit) {
768   int headroom_plus_one = CountLeadingZeros(static_cast<uint32_t>(x));
769   // This is the number of bits to the left of the binary point above 1.0.
770   // Consider x=1.25.  In that case shifted_scale=0.8 and
771   // no later adjustment will be needed.
772   *num_bits_over_unit = x_integer_digits - headroom_plus_one;
773   const int32_t shifted_sum_minus_one =
774       static_cast<int32_t>((static_cast<uint32_t>(x) << headroom_plus_one) -
775                            (static_cast<uint32_t>(1) << 31));
776 
777   gemmlowp::FixedPoint<int32_t, 0> shifted_scale =
778       gemmlowp::one_over_one_plus_x_for_x_in_0_1(
779           gemmlowp::FixedPoint<int32_t, 0>::FromRaw(shifted_sum_minus_one));
780   return shifted_scale.raw();
781 }
782 
// Computes a fixed-point approximation of 1 / sqrt(input), returned as a
// multiplier `*output_inv_sqrt` and a shift `*output_shift`. The shift is first
// computed as a right shift (right is positive) and finally multiplied by
// `reverse_shift`; passing kReverseShift (-1) turns it into a left shift.
// `input` must be non-negative. Both 0 and 1 yield the maximum int32_t
// multiplier with a shift of 0.
GetInvSqrtQuantizedMultiplierExp(int32_t input,int reverse_shift,int32_t * output_inv_sqrt,int * output_shift)783 inline void GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift,
784                                              int32_t* output_inv_sqrt,
785                                              int* output_shift) {
786   TFLITE_DCHECK_GE(input, 0);
787   if (input <= 1) {
788     // Handle the input value 1 separately to avoid overflow in that case
789     // in the general computation below (b/143972021). Also handle 0 as if it
790     // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
791     // but rare/unrealistic input value. We can expect both to occur in some
792     // incompletely trained models, but probably not in fully trained models.
793     *output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
794     *output_shift = 0;
795     return;
796   }
797   TFLITE_DCHECK_GT(input, 1);
798   *output_shift = 11;
  // Bring input below 2^29. Each division by 4 halves 1 / sqrt(input), which
  // one more bit of right shift compensates for.
799   while (input >= (1 << 29)) {
800     input /= 4;
801     ++*output_shift;
802   }
  // Shift left by an even number of bits, so the square root changes by whole
  // bits, to bring input into [2^27, 2^29); the DCHECKs below verify this.
  // NOTE(review): left_shift_bit_pairs is unsigned, so the subtraction from
  // *output_shift is carried out in unsigned arithmetic and converted back
  // to int.
803   const unsigned max_left_shift_bits =
804       CountLeadingZeros(static_cast<uint32_t>(input)) - 1;
805   const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
806   const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
807   *output_shift -= left_shift_bit_pairs;
808   input <<= 2 * left_shift_bit_pairs;
809   TFLITE_DCHECK_GE(input, (1 << 27));
810   TFLITE_DCHECK_LT(input, (1 << 29));
811   using gemmlowp::FixedPoint;
812   using gemmlowp::Rescale;
813   using gemmlowp::SaturatingRoundingMultiplyByPOT;
814   // Using 3 integer bits gives us enough room for the internal arithmetic in
815   // this Newton-Raphson iteration.
816   using F3 = FixedPoint<int32_t, 3>;
817   using F0 = FixedPoint<int32_t, 0>;
  // input >> 1 read as an F3 raw value (28 fractional bits) represents
  // input / 2^29, which lies in [0.25, 1).
818   const F3 fixedpoint_input = F3::FromRaw(input >> 1);
819   const F3 fixedpoint_half_input =
820       SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
821   const F3 fixedpoint_half_three =
822       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
823   // Newton-Raphson iteration
  // Each step computes x = 1.5 * x - (fixedpoint_input / 2) * x^3, the
  // Newton-Raphson update for 1 / sqrt(fixedpoint_input).
824   // Naive unoptimized starting guess: x = 1
825   F3 x = F3::One();
826   // Naive unoptimized number of iterations: 5
827   for (int i = 0; i < 5; i++) {
828     const F3 x3 = Rescale<3>(x * x * x);
829     x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
830   }
831   const F0 fixedpoint_half_sqrt_2 =
832       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
833   x = x * fixedpoint_half_sqrt_2;
834   *output_inv_sqrt = x.raw();
  // A negative right shift is folded into the multiplier as a left shift, so
  // the shift is never negative before the final conversion below.
835   if (*output_shift < 0) {
836     *output_inv_sqrt <<= -*output_shift;
837     *output_shift = 0;
838   }
839   // Convert right shift (right is positive) to left shift.
840   *output_shift *= reverse_shift;
841 }
842 
// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// Shape and memory layout of an N-dimensional rectangular array of numbers.
// It carries the same information as Dims<N> from types.h. Since Dims<N> is to
// be deprecated, this struct exists as an adaptor for simple unoptimized
// implementations of element-wise broadcasting operations.
template <int N>
struct NdArrayDesc {
  // Number of valid indices per dimension: an index along dimension d lies in
  // the half-open interval [0, extents[d]).
  int extents[N];

  // Distance, counted in elements rather than bytes, between consecutive
  // indices of each dimension. The broadcasting helpers in this file set a
  // stride to 0 so that every index along that dimension addresses the same
  // data.
  int strides[N];
};
863 
864 // DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
865 // BROADCASTING.
866 //
867 // Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
SubscriptToIndex(const NdArrayDesc<4> & desc,int i0,int i1,int i2,int i3)868 inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
869                             int i3) {
870   TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
871   TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
872   TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
873   TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
874   return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
875          i3 * desc.strides[3];
876 }
877 
SubscriptToIndex(const NdArrayDesc<5> & desc,int indexes[5])878 inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) {
879   return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
880          indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
881          indexes[4] * desc.strides[4];
882 }
883 
SubscriptToIndex(const NdArrayDesc<8> & desc,int indexes[8])884 inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) {
885   return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
886          indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
887          indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] +
888          indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7];
889 }
890 
891 // Given the dimensions of the operands for an element-wise binary broadcast,
892 // adjusts them so that they can be directly iterated over with simple loops.
893 // Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
894 // 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
895 //
896 // This function assumes that the two input shapes are compatible up to
897 // broadcasting and the shorter one has already been prepended with 1s to be the
898 // same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
899 // shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
900 // Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
901 // (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
902 //
903 // When two shapes are compatible up to broadcasting, for each dimension d,
904 // the input extents are either equal, or one of them is 1.
905 //
906 // This function performs the following for each dimension d:
907 // - If the extents are equal, then do nothing since the loop that walks over
908 //   both of the input arrays is correct.
909 // - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
910 //   and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
911 //   array0 to be referenced *at any index* in dimension d and still access the
912 //   same slice.
913 template <int N>
NdArrayDescsForElementwiseBroadcast(const Dims<N> & input0_dims,const Dims<N> & input1_dims,NdArrayDesc<N> * desc0_out,NdArrayDesc<N> * desc1_out)914 inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
915                                                 const Dims<N>& input1_dims,
916                                                 NdArrayDesc<N>* desc0_out,
917                                                 NdArrayDesc<N>* desc1_out) {
918   TFLITE_DCHECK(desc0_out != nullptr);
919   TFLITE_DCHECK(desc1_out != nullptr);
920 
921   // Copy dims to desc.
922   for (int i = 0; i < N; ++i) {
923     desc0_out->extents[i] = input0_dims.sizes[i];
924     desc0_out->strides[i] = input0_dims.strides[i];
925     desc1_out->extents[i] = input1_dims.sizes[i];
926     desc1_out->strides[i] = input1_dims.strides[i];
927   }
928 
929   // Walk over each dimension. If the extents are equal do nothing.
930   // Otherwise, set the desc with extent 1 to have extent equal to the other and
931   // stride 0.
932   for (int i = 0; i < N; ++i) {
933     const int extent0 = ArraySize(input0_dims, i);
934     const int extent1 = ArraySize(input1_dims, i);
935     if (extent0 != extent1) {
936       if (extent0 == 1) {
937         desc0_out->strides[i] = 0;
938         desc0_out->extents[i] = extent1;
939       } else {
940         TFLITE_DCHECK_EQ(extent1, 1);
941         desc1_out->strides[i] = 0;
942         desc1_out->extents[i] = extent0;
943       }
944     }
945   }
946 }
947 
948 // Copies dims to desc, calculating strides.
949 template <int N>
CopyDimsToDesc(const RuntimeShape & input_shape,NdArrayDesc<N> * desc_out)950 inline void CopyDimsToDesc(const RuntimeShape& input_shape,
951                            NdArrayDesc<N>* desc_out) {
952   int desc_stride = 1;
953   for (int i = N - 1; i >= 0; --i) {
954     desc_out->extents[i] = input_shape.Dims(i);
955     desc_out->strides[i] = desc_stride;
956     desc_stride *= input_shape.Dims(i);
957   }
958 }
959 
960 template <int N>
NdArrayDescsForElementwiseBroadcast(const RuntimeShape & input0_shape,const RuntimeShape & input1_shape,NdArrayDesc<N> * desc0_out,NdArrayDesc<N> * desc1_out)961 inline void NdArrayDescsForElementwiseBroadcast(
962     const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
963     NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
964   TFLITE_DCHECK(desc0_out != nullptr);
965   TFLITE_DCHECK(desc1_out != nullptr);
966 
967   auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
968   auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
969 
970   // Copy dims to desc, calculating strides.
971   CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
972   CopyDimsToDesc<N>(extended_input1_shape, desc1_out);
973 
974   // Walk over each dimension. If the extents are equal do nothing.
975   // Otherwise, set the desc with extent 1 to have extent equal to the other and
976   // stride 0.
977   for (int i = 0; i < N; ++i) {
978     const int extent0 = extended_input0_shape.Dims(i);
979     const int extent1 = extended_input1_shape.Dims(i);
980     if (extent0 != extent1) {
981       if (extent0 == 1) {
982         desc0_out->strides[i] = 0;
983         desc0_out->extents[i] = extent1;
984       } else {
985         TFLITE_DCHECK_EQ(extent1, 1);
986         desc1_out->strides[i] = 0;
987         desc1_out->extents[i] = extent0;
988       }
989     }
990   }
991 }
992 
993 template <int N>
NdArrayDescsForElementwiseBroadcast(const RuntimeShape & input0_shape,const RuntimeShape & input1_shape,const RuntimeShape & input2_shape,NdArrayDesc<N> * desc0_out,NdArrayDesc<N> * desc1_out,NdArrayDesc<N> * desc2_out)994 inline void NdArrayDescsForElementwiseBroadcast(
995     const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
996     const RuntimeShape& input2_shape, NdArrayDesc<N>* desc0_out,
997     NdArrayDesc<N>* desc1_out, NdArrayDesc<N>* desc2_out) {
998   TFLITE_DCHECK(desc0_out != nullptr);
999   TFLITE_DCHECK(desc1_out != nullptr);
1000   TFLITE_DCHECK(desc2_out != nullptr);
1001 
1002   auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
1003   auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
1004   auto extended_input2_shape = RuntimeShape::ExtendedShape(N, input2_shape);
1005 
1006   // Copy dims to desc, calculating strides.
1007   CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
1008   CopyDimsToDesc<N>(extended_input1_shape, desc1_out);
1009   CopyDimsToDesc<N>(extended_input2_shape, desc2_out);
1010 
1011   // Walk over each dimension. If the extents are equal do nothing.
1012   // Otherwise, set the desc with extent 1 to have extent equal to the other and
1013   // stride 0.
1014   for (int i = 0; i < N; ++i) {
1015     const int extent0 = extended_input0_shape.Dims(i);
1016     const int extent1 = extended_input1_shape.Dims(i);
1017     const int extent2 = extended_input2_shape.Dims(i);
1018 
1019     int extent = extent0;
1020     if (extent1 != 1) extent = extent1;
1021     if (extent2 != 1) extent = extent2;
1022 
1023     TFLITE_DCHECK(extent0 == 1 || extent0 == extent);
1024     TFLITE_DCHECK(extent1 == 1 || extent1 == extent);
1025     TFLITE_DCHECK(extent2 == 1 || extent2 == extent);
1026 
1027     if (!(extent0 == extent1 && extent1 == extent2)) {
1028       if (extent0 == 1) {
1029         desc0_out->strides[i] = 0;
1030         desc0_out->extents[i] = extent;
1031       }
1032       if (extent1 == 1) {
1033         desc1_out->strides[i] = 0;
1034         desc1_out->extents[i] = extent;
1035       }
1036       if (extent2 == 1) {
1037         desc2_out->strides[i] = 0;
1038         desc2_out->extents[i] = extent;
1039       }
1040     }
1041   }
1042 }
1043 
1044 // Detailed implementation of NDOpsHelper, the indexes must be a zero array.
1045 // This implementation is equivalent to N nested loops. Ex, if N=4, it can be
1046 // re-writen as:
1047 // for (int b = 0; b < output.extents[0]; ++b) {
1048 //   for (int y = 0; y < output.extents[1]; ++y) {
1049 //     for (int x = 0; x < output.extents[2]; ++x) {
1050 //       for (int c = 0; c < output.extents[3]; ++c) {
1051 //           calc({b,y,x,c});
1052 //       }
1053 //     }
1054 //   }
1055 // }
1056 template <int N, int DIM, typename Calc>
NDOpsHelperImpl(const NdArrayDesc<N> & output,const Calc & calc,int indexes[N])1057 typename std::enable_if<DIM != N - 1, void>::type NDOpsHelperImpl(
1058     const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
1059   for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
1060     NDOpsHelperImpl<N, DIM + 1, Calc>(output, calc, indexes);
1061   }
1062 }
1063 
1064 template <int N, int DIM, typename Calc>
NDOpsHelperImpl(const NdArrayDesc<N> & output,const Calc & calc,int indexes[N])1065 typename std::enable_if<DIM == N - 1, void>::type NDOpsHelperImpl(
1066     const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
1067   for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
1068     calc(indexes);
1069   }
1070 }
1071 
1072 // Execute the calc function in the innermost iteration based on the shape of
1073 // the output. The calc function should take a single argument of type int[N].
1074 template <int N, typename Calc>
NDOpsHelper(const NdArrayDesc<N> & output,const Calc & calc)1075 inline void NDOpsHelper(const NdArrayDesc<N>& output, const Calc& calc) {
1076   int indexes[N] = {0};
1077   NDOpsHelperImpl<N, 0, Calc>(output, calc, indexes);
1078 }
// Copied from gemmlowp::RoundDown when we dropped direct dependency on
// gemmlowp.
//
// Returns the largest multiple of the fixed Modulus that is <= `i`. This is a
// true floor, also for negative signed values.
//
// The arithmetic is done in `Integer`. The former `i - (i % Modulus)` mixed in
// the unsigned Modulus, which promoted a 32-bit signed `i` to unsigned but not
// a 64-bit one. Negative results were therefore type-dependent and sometimes
// wrong: RoundDown<3>(-5) gave -7, and a negative int64_t was rounded toward
// zero. Non-negative inputs give the same results as before.
template <unsigned Modulus, typename Integer>
Integer RoundDown(Integer i) {
  static_assert(Modulus > 0, "RoundDown: Modulus must be positive");
  const Integer modulus = static_cast<Integer>(Modulus);
  // `%` truncates toward zero, so for negative `i` this rounds up (toward
  // zero)...
  const Integer toward_zero = static_cast<Integer>(i - i % modulus);
  // ...in which case one more multiple is subtracted to reach the floor. For
  // unsigned `Integer` or non-negative `i` this never triggers.
  return toward_zero > i ? static_cast<Integer>(toward_zero - modulus)
                         : toward_zero;
}

// Copied from gemmlowp::RoundUp when we dropped direct dependency on
// gemmlowp.
//
// Returns the smallest multiple of the fixed Modulus that is >= `i`.
//
// This is derived from RoundDown rather than computed as
// RoundDown(i + Modulus - 1). The addition then stays in `Integer` (no
// signed/unsigned mixing), and it only overflows when the result itself is not
// representable.
template <unsigned Modulus, typename Integer>
Integer RoundUp(Integer i) {
  const Integer rounded_down = RoundDown<Modulus>(i);
  if (rounded_down == i) {
    return i;
  }
  return static_cast<Integer>(rounded_down + static_cast<Integer>(Modulus));
}
1098 
// Copied from gemmlowp::CeilQuotient when we dropped direct dependency on
// gemmlowp.
//
// Returns the quotient a / b rounded up ('ceil') to the nearest integer. `b`
// must be non-zero.
//
// This is implemented as truncating division plus a correction, instead of
// the classic (a + b - 1) / b. That formula overflows when `a` is near the top
// of Integer's range (undefined for signed types; unsigned types wrap to a
// wrong result), and it is wrong for negative operands (it gave -1 for -4 / 2).
template <typename Integer>
Integer CeilQuotient(Integer a, Integer b) {
  const Integer quotient = a / b;
  const Integer remainder = a % b;
  // Division truncates toward zero, so it already rounds up when the exact
  // quotient is negative. Bump the quotient only when the division was inexact
  // and the exact quotient is positive. That is the case when the remainder
  // (which has the sign of `a`) has the same sign as `b`.
  if (remainder != 0 && ((remainder > 0) == (b > 0))) {
    return static_cast<Integer>(quotient + 1);
  }
  return quotient;
}
1107 
// This function is a copy of gemmlowp::HowManyThreads, copied when we dropped
// the direct dependency of internal/optimized/ on gemmlowp.
//
// It computes a reasonable number of threads to use for a GEMM of shape
// (rows, cols, depth). The result is at most `max_num_threads`, at most
// rows / KernelRows, and at most rows * cols * depth / (64 * 1024). It is
// never less than 1.
//
// The size limit is computed with a saturating 64-bit multiply and compared
// against the thread count in 64 bits. Previously the three-way product could
// wrap around uint64_t, and the quotient was narrowed to int before taking the
// minimum. Very large problems (e.g. 2^20 x 2^20 x 2^12) could therefore end
// up single-threaded.
//
// TODO(b/131910176): get rid of this function by switching each call site
// to its own more sensible logic for its own workload.
template <int KernelRows>
inline int LegacyHowManyThreads(int max_num_threads, int rows, int cols,
                                int depth) {
  // Early-exit in the default case where multi-threading is disabled.
  if (max_num_threads == 1) {
    return 1;
  }

  // Ensure that each thread has KernelRows rows to process, if at all possible.
  int thread_count = std::min(max_num_threads, rows / KernelRows);

  // Limit the number of threads according to the overall size of the problem.
  if (thread_count > 1) {
    // Empirically determined value.
    static constexpr std::uint64_t min_cubic_size_per_thread = 64 * 1024;
    static constexpr std::uint64_t kMaxUint64 = ~std::uint64_t{0};

    // For non-negative int dimensions, rows * cols < 2^62 always fits, but
    // multiplying by depth as well can overflow, so saturate instead of
    // wrapping around to a small value.
    const std::uint64_t rows_times_cols =
        static_cast<std::uint64_t>(rows) * static_cast<std::uint64_t>(cols);
    const std::uint64_t depth_u64 = static_cast<std::uint64_t>(depth);
    const std::uint64_t cubic_size =
        (depth_u64 != 0 && rows_times_cols > kMaxUint64 / depth_u64)
            ? kMaxUint64
            : rows_times_cols * depth_u64;

    // Compare in 64 bits; narrowing the quotient to int first could wrap.
    const std::uint64_t size_limited_threads =
        cubic_size / min_cubic_size_per_thread;
    if (size_limited_threads < static_cast<std::uint64_t>(thread_count)) {
      thread_count = static_cast<int>(size_limited_threads);
    }
  }

  if (thread_count < 1) {
    thread_count = 1;
  }

  assert(thread_count > 0 && thread_count <= max_num_threads);
  return thread_count;
}
1147 
// Hints the CPU to prefetch the cache line containing `ptr` for reading, with
// locality hint 0 (no temporal locality expected). On compilers without
// __builtin_prefetch this does nothing.
template <typename T>
void optimized_ops_preload_l1_stream(const T* ptr) {
#if defined(__GNUC__)
  // GCC-compatible compilers (including clang) provide __builtin_prefetch.
  // Arguments: address, rw = 0 (read), locality = 0 (no locality).
  __builtin_prefetch(ptr, 0, 0);
#else
  static_cast<void>(ptr);
#endif
}
1157 
// Hints the CPU to prefetch the cache line containing `ptr` for reading, with
// locality hint 3 (keep it in cache as long as possible). On compilers without
// __builtin_prefetch this does nothing.
template <typename T>
void optimized_ops_preload_l1_keep(const T* ptr) {
#if defined(__GNUC__)
  // GCC-compatible compilers (including clang) provide __builtin_prefetch.
  // Arguments: address, rw = 0 (read), locality = 3 (high locality).
  __builtin_prefetch(ptr, 0, 3);
#else
  static_cast<void>(ptr);
#endif
}
1167 
// Hints the CPU to prefetch the cache line containing `ptr` in anticipation
// of a write, with locality hint 3 (keep it in cache as long as possible). On
// compilers without __builtin_prefetch this does nothing.
template <typename T>
void optimized_ops_prefetch_write_l1_keep(const T* ptr) {
#if defined(__GNUC__)
  // GCC-compatible compilers (including clang) provide __builtin_prefetch.
  // Arguments: address, rw = 1 (write), locality = 3 (high locality).
  __builtin_prefetch(ptr, 1, 3);
#else
  static_cast<void>(ptr);
#endif
}
1177 
1178 }  // namespace tflite
1179 
1180 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
1181