• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
17 
18 #ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
19 #ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
20 #define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
21 #endif
22 #endif
23 
24 #include <functional>
25 
26 #include "fixedpoint/fixedpoint.h"
27 #include "tensorflow/lite/kernels/internal/cppmath.h"
28 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
29 #include "tensorflow/lite/kernels/internal/types.h"
30 
31 namespace tflite {
32 
// Sign factor with value -1. NOTE(review): its uses are outside this view;
// presumably callers multiply a shift by it to flip between the
// "positive = right shift" and "positive = left shift" conventions — confirm.
constexpr int kReverseShift{-1};
34 
GetActivationMinMax(FusedActivationFunctionType ac,float * output_activation_min,float * output_activation_max)35 inline void GetActivationMinMax(FusedActivationFunctionType ac,
36                                 float* output_activation_min,
37                                 float* output_activation_max) {
38   switch (ac) {
39     case FusedActivationFunctionType::kNone:
40       *output_activation_min = std::numeric_limits<float>::lowest();
41       *output_activation_max = std::numeric_limits<float>::max();
42       break;
43     case FusedActivationFunctionType::kRelu:
44       *output_activation_min = 0.f;
45       *output_activation_max = std::numeric_limits<float>::max();
46       break;
47     case FusedActivationFunctionType::kRelu1:
48       *output_activation_min = -1.f;
49       *output_activation_max = 1.f;
50       break;
51     case FusedActivationFunctionType::kRelu6:
52       *output_activation_min = 0.f;
53       *output_activation_max = 6.f;
54       break;
55   }
56 }
57 
// Clamps x to [output_activation_min, output_activation_max]: the lower bound
// is applied first, then the upper bound. The unqualified min/max calls (with
// std:: versions brought in) also let argument-dependent lookup pick up
// overloads for user-defined T.
template <typename T>
inline T ActivationFunctionWithMinMax(T x, T output_activation_min,
                                      T output_activation_max) {
  using std::max;
  using std::min;
  const T lower_clamped = max(x, output_activation_min);
  return min(lower_clamped, output_activation_max);
}
65 
66 // Legacy function, left for compatibility only.
67 template <FusedActivationFunctionType Ac>
ActivationFunction(float x)68 float ActivationFunction(float x) {
69   float output_activation_min, output_activation_max;
70   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
71   return ActivationFunctionWithMinMax(x, output_activation_min,
72                                       output_activation_max);
73 }
74 
BiasAndClamp(float clamp_min,float clamp_max,int bias_size,const float * bias_data,int array_size,float * array_data)75 inline void BiasAndClamp(float clamp_min, float clamp_max, int bias_size,
76                          const float* bias_data, int array_size,
77                          float* array_data) {
78   // Note: see b/132215220: in May 2019 we thought it would be OK to replace
79   // this with the Eigen one-liner:
80   //   return (array.colwise() + bias).cwiseMin(clamp_max).cwiseMin(clamp_max).
81   // This turned out to severely regress performance: +4ms (i.e. 8%) on
82   // MobileNet v2 / 1.0 / 224. So we keep custom NEON code for now.
83   TFLITE_DCHECK_EQ((array_size % bias_size), 0);
84 #ifdef USE_NEON
85   float* array_ptr = array_data;
86   float* array_end_ptr = array_ptr + array_size;
87   const auto clamp_min_vec = vdupq_n_f32(clamp_min);
88   const auto clamp_max_vec = vdupq_n_f32(clamp_max);
89   for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
90     int i = 0;
91     for (; i <= bias_size - 16; i += 16) {
92       auto b0 = vld1q_f32(bias_data + i);
93       auto b1 = vld1q_f32(bias_data + i + 4);
94       auto b2 = vld1q_f32(bias_data + i + 8);
95       auto b3 = vld1q_f32(bias_data + i + 12);
96       auto a0 = vld1q_f32(array_ptr + i);
97       auto a1 = vld1q_f32(array_ptr + i + 4);
98       auto a2 = vld1q_f32(array_ptr + i + 8);
99       auto a3 = vld1q_f32(array_ptr + i + 12);
100       auto x0 = vaddq_f32(a0, b0);
101       auto x1 = vaddq_f32(a1, b1);
102       auto x2 = vaddq_f32(a2, b2);
103       auto x3 = vaddq_f32(a3, b3);
104       x0 = vmaxq_f32(clamp_min_vec, x0);
105       x1 = vmaxq_f32(clamp_min_vec, x1);
106       x2 = vmaxq_f32(clamp_min_vec, x2);
107       x3 = vmaxq_f32(clamp_min_vec, x3);
108       x0 = vminq_f32(clamp_max_vec, x0);
109       x1 = vminq_f32(clamp_max_vec, x1);
110       x2 = vminq_f32(clamp_max_vec, x2);
111       x3 = vminq_f32(clamp_max_vec, x3);
112       vst1q_f32(array_ptr + i, x0);
113       vst1q_f32(array_ptr + i + 4, x1);
114       vst1q_f32(array_ptr + i + 8, x2);
115       vst1q_f32(array_ptr + i + 12, x3);
116     }
117     for (; i <= bias_size - 4; i += 4) {
118       auto b = vld1q_f32(bias_data + i);
119       auto a = vld1q_f32(array_ptr + i);
120       auto x = vaddq_f32(a, b);
121       x = vmaxq_f32(clamp_min_vec, x);
122       x = vminq_f32(clamp_max_vec, x);
123       vst1q_f32(array_ptr + i, x);
124     }
125     for (; i < bias_size; i++) {
126       array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
127                                                   clamp_min, clamp_max);
128     }
129   }
130 #else  // not NEON
131   for (int array_offset = 0; array_offset < array_size;
132        array_offset += bias_size) {
133     for (int i = 0; i < bias_size; i++) {
134       array_data[array_offset + i] = ActivationFunctionWithMinMax(
135           array_data[array_offset + i] + bias_data[i], clamp_min, clamp_max);
136     }
137   }
138 #endif
139 }
140 
MultiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x,int32_t quantized_multiplier,int left_shift)141 inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(
142     int32_t x, int32_t quantized_multiplier, int left_shift) {
143   using gemmlowp::RoundingDivideByPOT;
144   using gemmlowp::SaturatingRoundingDoublingHighMul;
145   return RoundingDivideByPOT(
146       SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift);
147 }
148 
MultiplyByQuantizedMultiplierGreaterThanOne(int32_t x,int32_t quantized_multiplier,int left_shift)149 inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(
150     int32_t x, int32_t quantized_multiplier, int left_shift) {
151   using gemmlowp::SaturatingRoundingDoublingHighMul;
152   return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
153                                            quantized_multiplier);
154 }
155 
MultiplyByQuantizedMultiplier(int32_t x,int32_t quantized_multiplier,int shift)156 inline int32_t MultiplyByQuantizedMultiplier(int32_t x,
157                                              int32_t quantized_multiplier,
158                                              int shift) {
159   using gemmlowp::RoundingDivideByPOT;
160   using gemmlowp::SaturatingRoundingDoublingHighMul;
161   int left_shift = shift > 0 ? shift : 0;
162   int right_shift = shift > 0 ? 0 : -shift;
163   return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
164                                  x * (1 << left_shift), quantized_multiplier),
165                              right_shift);
166 }
167 
// Multiplies a 64-bit value x by quantized_multiplier * 2^(shift - 31) and
// returns the int32_t result. Rounding is to nearest; ties go toward
// +infinity, assuming arithmetic right shift of negative int64_t.
//
// Inputs:
// - quantized_multiplier has fixed point at bit 31
// - shift is -31 to +7 (negative for right shift)
//
// Assumptions: The following input ranges are assumed
// - quantized_multiplier >= 0 (the usual range is (1<<30) to (1<<31)-1)
// - scaling is chosen so final scaled result fits in int32_t
// - input x is in the range -(1<<47) <= x < (1<<47)
//
// Only the rounded top 16 bits of quantized_multiplier are used (a Q0.15
// multiplier). This keeps x * multiplier inside int64_t for the x range
// above.
inline int32_t MultiplyByQuantizedMultiplier(int64_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  assert(quantized_multiplier >= 0);
  assert(shift >= -31 && shift < 8);
  assert(x >= -(static_cast<int64_t>(1) << 47) &&
         x < (static_cast<int64_t>(1) << 47));

  // Round Q0.31 -> Q0.15. Multipliers at or above 0x7FFF0000 saturate to
  // 0x7FFF. This keeps the rounded value below 1 << 15 and prevents
  // quantized_multiplier + (1 << 15) from overflowing int32_t.
  const int32_t reduced_multiplier =
      (quantized_multiplier < 0x7FFF0000)
          ? ((quantized_multiplier + (1 << 15)) >> 16)
          : 0x7FFF;
  // The product carries 15 fractional bits. `shift` then moves the binary
  // point, so the total right shift is 15 - shift, which lies in [8, 46].
  const int total_shift = 15 - shift;
  const int64_t rounding_offset = static_cast<int64_t>(1) << (total_shift - 1);
  const int64_t scaled =
      x * static_cast<int64_t>(reduced_multiplier) + rounding_offset;
  return static_cast<int32_t>(scaled >> total_shift);
}
192 
193 #ifdef USE_NEON
194 // Round uses ARM's rounding shift right.
MultiplyByQuantizedMultiplier4Rows(int32x4x4_t input_val,int32_t quantized_multiplier,int shift)195 inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
196     int32x4x4_t input_val, int32_t quantized_multiplier, int shift) {
197   const int left_shift = std::max(shift, 0);
198   const int right_shift = std::min(shift, 0);
199   int32x4x4_t result;
200 
201   int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
202   int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
203   int32x4_t right_shift_dup = vdupq_n_s32(right_shift);
204 
205   result.val[0] =
206       vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup),
207                                multiplier_dup),
208                  right_shift_dup);
209 
210   result.val[1] =
211       vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup),
212                                multiplier_dup),
213                  right_shift_dup);
214 
215   result.val[2] =
216       vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup),
217                                multiplier_dup),
218                  right_shift_dup);
219 
220   result.val[3] =
221       vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup),
222                                multiplier_dup),
223                  right_shift_dup);
224 
225   return result;
226 }
227 #endif
228 
// Returns the number of leading zero bits of `integer_input`, counted within
// the width of T (std::numeric_limits<T>::digits). For 0 it returns that full
// width; the compiler builtins are undefined for 0.
//
// __builtin_clz works on unsigned int and __builtin_clzll on unsigned long
// long. The input is widened to the narrowest of the two that can hold T, and
// the extra leading zeros added by the widening are subtracted. (Previously
// __builtin_clz was used for every T, which truncated 64-bit inputs and
// over-counted for 8/16-bit ones.) Types wider than unsigned long long fall
// back to the portable loop.
template <typename T>
int CountLeadingZeros(T integer_input) {
  static_assert(std::is_unsigned<T>::value,
                "Only unsigned integer types handled.");
  constexpr int kTypeDigits = std::numeric_limits<T>::digits;
  if (integer_input == 0) {
    return kTypeDigits;
  }
#if defined(__GNUC__)
  constexpr int kUIntDigits = std::numeric_limits<unsigned int>::digits;
  constexpr int kULongLongDigits =
      std::numeric_limits<unsigned long long>::digits;
  if (kTypeDigits <= kUIntDigits) {
    return __builtin_clz(static_cast<unsigned int>(integer_input)) -
           (kUIntDigits - kTypeDigits);
  }
  if (kTypeDigits <= kULongLongDigits) {
    return __builtin_clzll(static_cast<unsigned long long>(integer_input)) -
           (kULongLongDigits - kTypeDigits);
  }
#endif
  // Portable fallback: shift left until the top bit of T is set.
  const T one_in_leading_positive = static_cast<T>(1) << (kTypeDigits - 1);
  int leading_zeros = 0;
  while (integer_input < one_in_leading_positive) {
    integer_input <<= 1;
    ++leading_zeros;
  }
  return leading_zeros;
}
251 
// Returns the number of redundant sign bits of `integer_input`: the number of
// bits after the sign bit that equal the sign bit, within the width of T.
// For 0 this is std::numeric_limits<T>::digits.
//
// With GCC, __builtin_clrsb works on int and __builtin_clrsbll on long long.
// The input is sign-extended to the narrowest of the two that can hold T, and
// the extra copies of the sign bit added by the extension are subtracted.
// (Previously __builtin_clrsb was used for every T, which truncated int64_t
// inputs, e.g. FloorLog2<int64_t>.) Other compilers, and types wider than
// long long, use the portable formulation via CountLeadingZeros.
template <typename T>
inline int CountLeadingSignBits(T integer_input) {
  static_assert(std::is_signed<T>::value, "Only signed integer types handled.");
  constexpr int kTypeDigits = std::numeric_limits<T>::digits;
  if (integer_input == 0) {
    return kTypeDigits;
  }
#if defined(__GNUC__) && !defined(__clang__)
  constexpr int kIntDigits = std::numeric_limits<int>::digits;
  constexpr int kLongLongDigits = std::numeric_limits<long long>::digits;
  if (kTypeDigits <= kIntDigits) {
    return __builtin_clrsb(static_cast<int>(integer_input)) -
           (kIntDigits - kTypeDigits);
  }
  if (kTypeDigits <= kLongLongDigits) {
    return __builtin_clrsbll(static_cast<long long>(integer_input)) -
           (kLongLongDigits - kTypeDigits);
  }
#endif
  using U = typename std::make_unsigned<T>::type;
  if (integer_input > 0) {
    return CountLeadingZeros(static_cast<U>(integer_input)) - 1;
  }
  if (integer_input == std::numeric_limits<T>::min()) {
    return 0;
  }
  // The outer cast keeps the argument unsigned even when integer promotion
  // turns the arithmetic on 8/16-bit U into int.
  return CountLeadingZeros(
      static_cast<U>(2 * static_cast<U>(-integer_input) - 1));
}
267 
// Use "count leading zeros" helper functions to do a fast Floor(log_2(x)).
//
// For a positive signed n the highest set bit sits at position
//   digits - 1 - CountLeadingSignBits(n),
// which is 30 - CLS(n) for 32-bit and 62 - CLS(n) for 64-bit types.
// n <= 0 fails TFLITE_CHECK_GT.
template <typename Integer>
inline Integer FloorLog2(Integer n) {
  static_assert(std::is_integral<Integer>::value, "");
  static_assert(std::is_signed<Integer>::value, "");
  static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, "");
  TFLITE_CHECK_GT(n, 0);
  constexpr int kHighestValueBit = std::numeric_limits<Integer>::digits - 1;
  return kHighestValueBit - CountLeadingSignBits(n);
}
281 
282 // generate INT16 LUT for function(), e.g., table exp(x) and 1/(1+x) used in
283 // softmax
284 // func - the function to build the LUT for (e.g exp(x))
285 // min,max - table limits
286 // table - pointer to buffer
287 // num - number of elements in the LUT
gen_lut(double (* func)(double),double min,double max,int16_t * table,const int num)288 inline void gen_lut(double (*func)(double), double min, double max,
289                     int16_t* table, const int num) {
290   // size of table should equal to num + 1
291   // last element only for slope calculation
292   double step = (max - min) / (num - 1);
293   double half_step = step / 2.0;
294   for (int i = 0; i < num - 1; i++) {
295     double sample_val = TfLiteRound(func(min + i * step) * 32768.0);
296     double midpoint_interp_val =
297         TfLiteRound((func(min + (i + 1) * step) * 32768.0 +
298                      TfLiteRound(func(min + i * step) * 32768.0)) /
299                     2.0);
300     double midpoint_val =
301         TfLiteRound(func(min + i * step + half_step) * 32768.0);
302     double midpoint_err = midpoint_interp_val - midpoint_val;
303     double bias = TfLiteRound(midpoint_err / 2.0);
304     table[i] = std::min<double>(std::max<double>(sample_val - bias, -32768.0),
305                                 32767.0);
306   }
307   table[num - 1] = std::min<double>(
308       std::max<double>(TfLiteRound(func(max) * 32768.0), -32768.0), 32767.0);
309 }
310 
311 // generate INT16 LUT for function(), e.g., table exp(x) and 1/(1+x) used in
312 // softmax
313 // func - the function to build the LUT for (e.g exp(x))
314 // min,max - table limits
315 // table - pointer to buffer
316 // num - number of elements in the LUT
gen_lut(float (* func)(float),float min,float max,int16_t * table,const int num)317 inline void gen_lut(float (*func)(float), float min, float max, int16_t* table,
318                     const int num) {
319   // size of table should equal to num + 1
320   // last element only for slope calculation
321   float step = (max - min) / (num - 1);
322   float half_step = step / 2.0f;
323   for (int i = 0; i < num - 1; i++) {
324     float sample_val = TfLiteRound(func(min + i * step) * 32768.0f);
325     float midpoint_interp_val =
326         TfLiteRound((func(min + (i + 1) * step) * 32768.0f +
327                      TfLiteRound(func(min + i * step) * 32768.0f)) /
328                     2.0f);
329     float midpoint_val =
330         TfLiteRound(func(min + i * step + half_step) * 32768.0f);
331     float midpoint_err = midpoint_interp_val - midpoint_val;
332     float bias = TfLiteRound(midpoint_err / 2.0f);
333     table[i] = std::min<float>(std::max<float>(sample_val - bias, -32768.0f),
334                                32767.0f);
335   }
336   table[num - 1] = std::min<float>(
337       std::max<float>(TfLiteRound(func(max) * 32768.0f), -32768.0f), 32767.0f);
338 }
339 
// int16_t func table lookup, e.g., lookup exp() and 1/(1+x) used in softmax.
//
// Linearly interpolated lookup:
// - the upper 9 bits of `value` (value >> 7, in [-256, 255]) select one of
//   512 base entries;
// - the low 7 bits interpolate towards the next entry.
// `lut` must therefore hold at least 513 entries; lut[512] is read only as
// the right end point of the last interval. The result always lies between
// lut[index] and lut[index + 1].
inline int16_t generic_int16_table_lookup(int16_t value, const int16_t* lut) {
  const uint16_t index = static_cast<uint16_t>(256 + (value >> 7));
  assert(index < 512 && "LUT index out of range.");
  const int32_t offset = value & 0x7f;

  // base and slope are Q0.15. The slope is kept in 32 bits: the difference
  // of two int16_t entries can reach +/-65535 and would wrap in int16_t.
  const int32_t base = lut[index];
  const int32_t slope = static_cast<int32_t>(lut[index + 1]) - lut[index];

  // Q0.15 * Q0.7 = Q0.22
  // Round and convert from Q0.22 to Q0.15
  const int32_t delta = (slope * offset + 64) >> 7;

  // Q0.15 + Q0.15. Since offset <= 127, |delta| <= |slope|, so the sum stays
  // between the two table entries and fits in int16_t.
  return static_cast<int16_t>(base + delta);
}
358 
// Table of sigmoid(i/24) at 0.16 format - 256 elements.
//
// Entry i is sigmoid(i / 24) scaled by 2^16 and stored as uint16_t (unsigned
// Q0.16), for i = 0 .. 255, i.e. inputs 0 .. 255/24 (about 10.6). For
// example, entry 0 is 32768 = 0.5 * 2^16 = sigmoid(0) * 2^16. The last entry
// is 65535, the largest value a uint16_t can hold.

// We use combined sigmoid and tanh look-up table, since
// tanh(x) = 2*sigmoid(2*x) -1.
// Both functions are symmetric, so the LUT table is only needed
// for the absolute value of the input.
// (Precisely: sigmoid(-x) = 1 - sigmoid(x) and tanh(-x) = -tanh(x).)
static const uint16_t sigmoid_table_uint16[256] = {
    32768, 33451, 34133, 34813, 35493, 36169, 36843, 37513, 38180, 38841, 39498,
    40149, 40794, 41432, 42064, 42688, 43304, 43912, 44511, 45102, 45683, 46255,
    46817, 47369, 47911, 48443, 48964, 49475, 49975, 50464, 50942, 51409, 51865,
    52311, 52745, 53169, 53581, 53983, 54374, 54755, 55125, 55485, 55834, 56174,
    56503, 56823, 57133, 57433, 57724, 58007, 58280, 58544, 58800, 59048, 59288,
    59519, 59743, 59959, 60168, 60370, 60565, 60753, 60935, 61110, 61279, 61441,
    61599, 61750, 61896, 62036, 62172, 62302, 62428, 62549, 62666, 62778, 62886,
    62990, 63090, 63186, 63279, 63368, 63454, 63536, 63615, 63691, 63765, 63835,
    63903, 63968, 64030, 64090, 64148, 64204, 64257, 64308, 64357, 64405, 64450,
    64494, 64536, 64576, 64614, 64652, 64687, 64721, 64754, 64786, 64816, 64845,
    64873, 64900, 64926, 64950, 64974, 64997, 65019, 65039, 65060, 65079, 65097,
    65115, 65132, 65149, 65164, 65179, 65194, 65208, 65221, 65234, 65246, 65258,
    65269, 65280, 65291, 65301, 65310, 65319, 65328, 65337, 65345, 65352, 65360,
    65367, 65374, 65381, 65387, 65393, 65399, 65404, 65410, 65415, 65420, 65425,
    65429, 65433, 65438, 65442, 65445, 65449, 65453, 65456, 65459, 65462, 65465,
    65468, 65471, 65474, 65476, 65479, 65481, 65483, 65485, 65488, 65489, 65491,
    65493, 65495, 65497, 65498, 65500, 65501, 65503, 65504, 65505, 65507, 65508,
    65509, 65510, 65511, 65512, 65513, 65514, 65515, 65516, 65517, 65517, 65518,
    65519, 65520, 65520, 65521, 65522, 65522, 65523, 65523, 65524, 65524, 65525,
    65525, 65526, 65526, 65526, 65527, 65527, 65528, 65528, 65528, 65529, 65529,
    65529, 65529, 65530, 65530, 65530, 65530, 65531, 65531, 65531, 65531, 65531,
    65532, 65532, 65532, 65532, 65532, 65532, 65533, 65533, 65533, 65533, 65533,
    65533, 65533, 65533, 65534, 65534, 65534, 65534, 65534, 65534, 65534, 65534,
    65534, 65534, 65535};
390 
// TODO(b/77858996): Add these to gemmlowp.
//
// Saturating addition. The primary template is deliberately unimplemented.
// Instantiating it for any type without an explicit specialization below
// trips the static_assert.
template <typename IntegerType>
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// int32_t specialization: forms the exact sum in 64 bits, then clamps it to
// [INT32_MIN, INT32_MAX].
template <>
inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
  constexpr std::int64_t kUpper = std::numeric_limits<std::int32_t>::max();
  constexpr std::int64_t kLower = std::numeric_limits<std::int32_t>::min();
  const std::int64_t exact_sum =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  if (exact_sum > kUpper) {
    return static_cast<std::int32_t>(kUpper);
  }
  if (exact_sum < kLower) {
    return static_cast<std::int32_t>(kLower);
  }
  return static_cast<std::int32_t>(exact_sum);
}
409 
410 template <typename tRawType, int tIntegerBits>
SaturatingAddNonGemmlowp(gemmlowp::FixedPoint<tRawType,tIntegerBits> a,gemmlowp::FixedPoint<tRawType,tIntegerBits> b)411 gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
412     gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
413     gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
414   return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
415       SaturatingAddNonGemmlowp(a.raw(), b.raw()));
416 }
417 
// Saturating subtraction a - b. The primary template is deliberately
// unimplemented. Instantiating it for any type without an explicit
// specialization below trips the static_assert.
template <typename IntegerType>
IntegerType SaturatingSub(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// int16_t specialization: the difference is exact in 32 bits, then clamped
// to [-32768, 32767].
template <>
inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
  const std::int32_t exact_diff =
      static_cast<std::int32_t>(a) - static_cast<std::int32_t>(b);
  if (exact_diff > 32767) {
    return static_cast<std::int16_t>(32767);
  }
  if (exact_diff < -32768) {
    return static_cast<std::int16_t>(-32768);
  }
  return static_cast<std::int16_t>(exact_diff);
}

// int32_t specialization: the difference is exact in 64 bits, then clamped
// to [INT32_MIN, INT32_MAX].
template <>
inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
  constexpr std::int64_t kUpper = std::numeric_limits<std::int32_t>::max();
  constexpr std::int64_t kLower = std::numeric_limits<std::int32_t>::min();
  const std::int64_t exact_diff =
      static_cast<std::int64_t>(a) - static_cast<std::int64_t>(b);
  if (exact_diff > kUpper) {
    return static_cast<std::int32_t>(kUpper);
  }
  if (exact_diff < kLower) {
    return static_cast<std::int32_t>(kLower);
  }
  return static_cast<std::int32_t>(exact_diff);
}
445 
446 template <typename tRawType, int tIntegerBits>
SaturatingSub(gemmlowp::FixedPoint<tRawType,tIntegerBits> a,gemmlowp::FixedPoint<tRawType,tIntegerBits> b)447 gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
448     gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
449     gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
450   return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
451       SaturatingSub(a.raw(), b.raw()));
452 }
453 // End section to be moved to gemmlowp.
454 
455 template <typename IntegerType>
SaturatingRoundingMultiplyByPOTParam(IntegerType x,int exponent)456 IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
457   if (exponent == 0) {
458     return x;
459   }
460   using ScalarIntegerType =
461       typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
462   const IntegerType min =
463       gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
464   const IntegerType max =
465       gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
466   const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
467 
468   const std::int32_t threshold =
469       ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
470   const IntegerType positive_mask =
471       gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
472   const IntegerType negative_mask =
473       gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
474 
475   IntegerType result = gemmlowp::ShiftLeft(x, exponent);
476   result = gemmlowp::SelectUsingMask(positive_mask, max, result);
477   result = gemmlowp::SelectUsingMask(negative_mask, min, result);
478   return result;
479 }
480 
481 // If we want to leave IntegerBits fixed, then multiplication
482 // by a power of two has to be saturating/rounding, not exact anymore.
483 template <typename tRawType, int tIntegerBits>
484 gemmlowp::FixedPoint<tRawType, tIntegerBits>
SaturatingRoundingMultiplyByPOTParam(gemmlowp::FixedPoint<tRawType,tIntegerBits> a,int exponent)485 SaturatingRoundingMultiplyByPOTParam(
486     gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
487   return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
488       SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
489 }
490 
491 // Convert int32_t multiplier to int16_t with rounding.
DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32_t,int16_t * multiplier_int16_t)492 inline void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32_t,
493                                             int16_t* multiplier_int16_t) {
494   TFLITE_DCHECK_GE(multiplier_int32_t, 0);
495   static constexpr int32_t kRoundingOffset = 1 << 15;
496   if (multiplier_int32_t >=
497       std::numeric_limits<int32_t>::max() - kRoundingOffset) {
498     *multiplier_int16_t = std::numeric_limits<int16_t>::max();
499     return;
500   }
501   const int32_t result = (multiplier_int32_t + kRoundingOffset) >> 16;
502   TFLITE_DCHECK_LE(result << 16, multiplier_int32_t + kRoundingOffset);
503   TFLITE_DCHECK_GT(result << 16, multiplier_int32_t - kRoundingOffset);
504   *multiplier_int16_t = result;
505   TFLITE_DCHECK_EQ(*multiplier_int16_t, result);
506 }
507 
// Smallest number of output integer bits whose range can hold log(x) for any
// x representable with `input_bits` integer bits. Whether the input range is
// taken as, say, [-64,64] or [-64,64) makes no difference to the result.
// Written as a single return expression so it stays a valid C++11 constexpr
// function.
//
// For example, run this through Octave:
// [0:127; ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
constexpr int min_log_x_output_bits(int input_bits) {
  return input_bits <= 1    ? 1
         : input_bits <= 4  ? 2
         : input_bits <= 10 ? 3
         : input_bits <= 21 ? 4
         : input_bits <= 44 ? 5
         : input_bits <= 90 ? 6
                            : 7;
}
524 
// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input.  In other words, the output range
// is symmetric.
//
// Fixed-point logarithm of input_val (read as a Q(InputIntegerBits) value),
// returned as a Q(OutputIntegerBits) value. The result is built as:
//   (power-of-two exponent) * log_2 + (rational approximation of log(r)),
// where log_2 = std::log(2.0), so it is a natural logarithm, and r is the
// input normalized by its power of two. The rational part is evaluated in
// q = 2 * (r - sqrt(sqrt(0.5))).
// NOTE(review): the alpha_* coefficients appear to come from a Pade-style
// fit; their derivation is not shown here.
//
// Prefer log_x_for_x_greater_than_or_equal_to_1, which statically checks that
// OutputIntegerBits is large enough for InputIntegerBits.
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
    gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
  // NOTE(review): these commented-out asserts concern __builtin_clz(0u),
  // whose result is undefined; the code below uses CountLeadingZeros, which
  // returns the type's digit count for 0.
  // assert(__builtin_clz(0u) >= std::numeric_limits<uint32_t>::digits - 1);
  // assert(__builtin_clz(0u) <= std::numeric_limits<uint32_t>::digits);
  using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
  // The reason for accumulating the result with an extra bit of headroom is
  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
  // recip_denom will otherwise introduce an error.
  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
  using FixedPointAccum = gemmlowp::FixedPoint<int32_t, kAccumIntegerBits>;

  // Q0.31 constants; GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT pairs each raw
  // value with the real-valued expression it encodes.
  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1488522236, std::log(2.0));
  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1518500250, std::sqrt(0.5));
  const FixedPoint0 one_quarter =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);

  // Coefficients of the rational approximation (see NOTE above).
  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1057819769,
      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));

  // 0.25 in the accumulator format; used to bias the two exponent candidates.
  const FixedPointAccum shifted_quarter =
      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);

  // Reinterpret the input value as Q0.31, because we will figure out the
  // required shift "ourselves" instead of using, say, Rescale.
  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
  // z_a_pow_2 = input_integer_bits - z_a_headroom;
  int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32_t>(z_a.raw()));
  // Candidate a: shift z_a left by its headroom, then multiply by
  // 2 * sqrt(0.5).
  FixedPoint0 r_a_tmp =
      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
  const int32_t r_a_raw =
      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
  //                   InputIntegerBits - z_b_headroom - 0.25);
  // The integer exponent is shifted into the integer part of a
  // FixedPointAccum raw value before the quarter is added.
  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
      shifted_quarter);

  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
  FixedPoint0 z_b = z_a * sqrt_half;
  int z_b_headroom = CountLeadingZeros(static_cast<uint32_t>(z_b.raw())) - 1;
  const int32_t r_b_raw =
      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
      shifted_quarter);

  // Pick the smaller mantissa and the larger exponent of the two candidates.
  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));

  // p = (r + sqrt(sqrt(0.5))) / 2 and q = 2 * (r - sqrt(sqrt(0.5))).
  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
  FixedPoint0 q = r - sqrt_sqrt_half;
  q = q + q;

  // Rational approximation: num / (1 + denom_minus_one_0). The unqualified
  // one_over_one_plus_x_for_x_in_0_1 call presumably resolves to the gemmlowp
  // function through argument-dependent lookup on FixedPoint.
  const FixedPoint0 common_sq = q * q;
  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
  const FixedPoint0 denom_minus_one_0 =
      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
  const FixedPoint0 recip_denom =
      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);

  // Sum in the wider accumulator format, then rescale to the output format.
  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
                                              num_scaled * recip_denom);
}
610 
611 template <int OutputIntegerBits, int InputIntegerBits>
612 inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1(gemmlowp::FixedPoint<int32_t,InputIntegerBits> input_val)613 log_x_for_x_greater_than_or_equal_to_1(
614     gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
615   static_assert(
616       OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
617       "Output integer bits must be sufficient to accommodate logs of inputs.");
618   return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
619                                                      InputIntegerBits>(
620       input_val);
621 }
622 
GetReciprocal(int32_t x,int x_integer_digits,int * num_bits_over_unit)623 inline int32_t GetReciprocal(int32_t x, int x_integer_digits,
624                              int* num_bits_over_unit) {
625   int headroom_plus_one = CountLeadingZeros(static_cast<uint32_t>(x));
626   // This is the number of bits to the left of the binary point above 1.0.
627   // Consider x=1.25.  In that case shifted_scale=0.8 and
628   // no later adjustment will be needed.
629   *num_bits_over_unit = x_integer_digits - headroom_plus_one;
630   const int32_t shifted_sum_minus_one =
631       static_cast<int32_t>((static_cast<uint32_t>(x) << headroom_plus_one) -
632                            (static_cast<uint32_t>(1) << 31));
633 
634   gemmlowp::FixedPoint<int32_t, 0> shifted_scale =
635       gemmlowp::one_over_one_plus_x_for_x_in_0_1(
636           gemmlowp::FixedPoint<int32_t, 0>::FromRaw(shifted_sum_minus_one));
637   return shifted_scale.raw();
638 }
639 
GetInvSqrtQuantizedMultiplierExp(int32_t input,int reverse_shift,int32_t * output_inv_sqrt,int * output_shift)640 inline void GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift,
641                                              int32_t* output_inv_sqrt,
642                                              int* output_shift) {
643   TFLITE_DCHECK_GE(input, 0);
644   if (input <= 1) {
645     // Handle the input value 1 separately to avoid overflow in that case
646     // in the general computation below (b/143972021). Also handle 0 as if it
647     // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
648     // but rare/unrealistic input value. We can expect both to occur in some
649     // incompletely trained models, but probably not in fully trained models.
650     *output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
651     *output_shift = 0;
652     return;
653   }
654   TFLITE_DCHECK_GT(input, 1);
655   *output_shift = 11;
656   while (input >= (1 << 29)) {
657     input /= 4;
658     ++*output_shift;
659   }
660   const unsigned max_left_shift_bits =
661       CountLeadingZeros(static_cast<uint32_t>(input)) - 1;
662   const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
663   const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
664   *output_shift -= left_shift_bit_pairs;
665   input <<= 2 * left_shift_bit_pairs;
666   TFLITE_DCHECK_GE(input, (1 << 27));
667   TFLITE_DCHECK_LT(input, (1 << 29));
668   using gemmlowp::FixedPoint;
669   using gemmlowp::Rescale;
670   using gemmlowp::SaturatingRoundingMultiplyByPOT;
671   // Using 3 integer bits gives us enough room for the internal arithmetic in
672   // this Newton-Raphson iteration.
673   using F3 = FixedPoint<int32_t, 3>;
674   using F0 = FixedPoint<int32_t, 0>;
675   const F3 fixedpoint_input = F3::FromRaw(input >> 1);
676   const F3 fixedpoint_half_input =
677       SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
678   const F3 fixedpoint_half_three =
679       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
680   // Newton-Raphson iteration
681   // Naive unoptimized starting guess: x = 1
682   F3 x = F3::One();
683   // Naive unoptimized number of iterations: 5
684   for (int i = 0; i < 5; i++) {
685     const F3 x3 = Rescale<3>(x * x * x);
686     x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
687   }
688   const F0 fixedpoint_half_sqrt_2 =
689       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
690   x = x * fixedpoint_half_sqrt_2;
691   *output_inv_sqrt = x.raw();
692   if (*output_shift < 0) {
693     *output_inv_sqrt <<= -*output_shift;
694     *output_shift = 0;
695   }
696   // Convert right shift (right is positive) to left shift.
697   *output_shift *= reverse_shift;
698 }
699 
// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// Describes the shape and memory layout of a rectangular, N-dimensional array
// of numbers.
//
// It mirrors Dims<N> from types.h. Because Dims<N> is slated for deprecation,
// NdArrayDesc<N> serves as an adaptor that lets simple, unoptimized
// element-wise broadcasting kernels be written without it.
template <int N>
struct NdArrayDesc {
  // Per-dimension size: valid indices along dimension d run from 0 up to,
  // but not including, extents[d].
  int extents[N];

  // Per-dimension step, counted in elements (not bytes), between two
  // neighbouring indices along that dimension.
  int strides[N];
};
720 
721 // DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
722 // BROADCASTING.
723 //
724 // Same as Offset(), except takes as NdArrayDesc<N> instead of Dims<N>.
SubscriptToIndex(const NdArrayDesc<4> & desc,int i0,int i1,int i2,int i3)725 inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
726                             int i3) {
727   TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
728   TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
729   TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
730   TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
731   return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
732          i3 * desc.strides[3];
733 }
734 
SubscriptToIndex(const NdArrayDesc<5> & desc,int indexes[5])735 inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) {
736   return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
737          indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
738          indexes[4] * desc.strides[4];
739 }
740 
SubscriptToIndex(const NdArrayDesc<8> & desc,int indexes[8])741 inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) {
742   return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
743          indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
744          indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] +
745          indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7];
746 }
747 
748 // Given the dimensions of the operands for an element-wise binary broadcast,
749 // adjusts them so that they can be directly iterated over with simple loops.
750 // Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
751 // 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
752 //
753 // This function assumes that the two input shapes are compatible up to
754 // broadcasting and the shorter one has already been prepended with 1s to be the
755 // same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
756 // shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
757 // Dims<N> refer to shapes in reverse order. In this case, input0_dims will be
758 // (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
759 //
760 // When two shapes are compatible up to broadcasting, for each dimension d,
761 // the input extents are either equal, or one of them is 1.
762 //
763 // This function performs the following for each dimension d:
764 // - If the extents are equal, then do nothing since the loop that walks over
765 //   both of the input arrays is correct.
766 // - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
767 //   and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
768 //   array0 to be referenced *at any index* in dimension d and still access the
769 //   same slice.
770 template <int N>
NdArrayDescsForElementwiseBroadcast(const Dims<N> & input0_dims,const Dims<N> & input1_dims,NdArrayDesc<N> * desc0_out,NdArrayDesc<N> * desc1_out)771 inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
772                                                 const Dims<N>& input1_dims,
773                                                 NdArrayDesc<N>* desc0_out,
774                                                 NdArrayDesc<N>* desc1_out) {
775   TFLITE_DCHECK(desc0_out != nullptr);
776   TFLITE_DCHECK(desc1_out != nullptr);
777 
778   // Copy dims to desc.
779   for (int i = 0; i < N; ++i) {
780     desc0_out->extents[i] = input0_dims.sizes[i];
781     desc0_out->strides[i] = input0_dims.strides[i];
782     desc1_out->extents[i] = input1_dims.sizes[i];
783     desc1_out->strides[i] = input1_dims.strides[i];
784   }
785 
786   // Walk over each dimension. If the extents are equal do nothing.
787   // Otherwise, set the desc with extent 1 to have extent equal to the other and
788   // stride 0.
789   for (int i = 0; i < N; ++i) {
790     const int extent0 = ArraySize(input0_dims, i);
791     const int extent1 = ArraySize(input1_dims, i);
792     if (extent0 != extent1) {
793       if (extent0 == 1) {
794         desc0_out->strides[i] = 0;
795         desc0_out->extents[i] = extent1;
796       } else {
797         TFLITE_DCHECK_EQ(extent1, 1);
798         desc1_out->strides[i] = 0;
799         desc1_out->extents[i] = extent0;
800       }
801     }
802   }
803 }
804 
805 // Copies dims to desc, calculating strides.
806 template <int N>
CopyDimsToDesc(const RuntimeShape & input_shape,NdArrayDesc<N> * desc_out)807 inline void CopyDimsToDesc(const RuntimeShape& input_shape,
808                            NdArrayDesc<N>* desc_out) {
809   int desc_stride = 1;
810   for (int i = N - 1; i >= 0; --i) {
811     desc_out->extents[i] = input_shape.Dims(i);
812     desc_out->strides[i] = desc_stride;
813     desc_stride *= input_shape.Dims(i);
814   }
815 }
816 
817 template <int N>
NdArrayDescsForElementwiseBroadcast(const RuntimeShape & input0_shape,const RuntimeShape & input1_shape,NdArrayDesc<N> * desc0_out,NdArrayDesc<N> * desc1_out)818 inline void NdArrayDescsForElementwiseBroadcast(
819     const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
820     NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
821   TFLITE_DCHECK(desc0_out != nullptr);
822   TFLITE_DCHECK(desc1_out != nullptr);
823 
824   auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
825   auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
826 
827   // Copy dims to desc, calculating strides.
828   CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
829   CopyDimsToDesc<N>(extended_input1_shape, desc1_out);
830 
831   // Walk over each dimension. If the extents are equal do nothing.
832   // Otherwise, set the desc with extent 1 to have extent equal to the other and
833   // stride 0.
834   for (int i = 0; i < N; ++i) {
835     const int extent0 = extended_input0_shape.Dims(i);
836     const int extent1 = extended_input1_shape.Dims(i);
837     if (extent0 != extent1) {
838       if (extent0 == 1) {
839         desc0_out->strides[i] = 0;
840         desc0_out->extents[i] = extent1;
841       } else {
842         TFLITE_DCHECK_EQ(extent1, 1);
843         desc1_out->strides[i] = 0;
844         desc1_out->extents[i] = extent0;
845       }
846     }
847   }
848 }
849 
850 template <int N>
NdArrayDescsForElementwiseBroadcast(const RuntimeShape & input0_shape,const RuntimeShape & input1_shape,const RuntimeShape & input2_shape,NdArrayDesc<N> * desc0_out,NdArrayDesc<N> * desc1_out,NdArrayDesc<N> * desc2_out)851 inline void NdArrayDescsForElementwiseBroadcast(
852     const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
853     const RuntimeShape& input2_shape, NdArrayDesc<N>* desc0_out,
854     NdArrayDesc<N>* desc1_out, NdArrayDesc<N>* desc2_out) {
855   TFLITE_DCHECK(desc0_out != nullptr);
856   TFLITE_DCHECK(desc1_out != nullptr);
857   TFLITE_DCHECK(desc2_out != nullptr);
858 
859   auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
860   auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
861   auto extended_input2_shape = RuntimeShape::ExtendedShape(N, input2_shape);
862 
863   // Copy dims to desc, calculating strides.
864   CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
865   CopyDimsToDesc<N>(extended_input1_shape, desc1_out);
866   CopyDimsToDesc<N>(extended_input2_shape, desc2_out);
867 
868   // Walk over each dimension. If the extents are equal do nothing.
869   // Otherwise, set the desc with extent 1 to have extent equal to the other and
870   // stride 0.
871   for (int i = 0; i < N; ++i) {
872     const int extent0 = extended_input0_shape.Dims(i);
873     const int extent1 = extended_input1_shape.Dims(i);
874     const int extent2 = extended_input2_shape.Dims(i);
875 
876     int extent = extent0;
877     if (extent1 != 1) extent = extent1;
878     if (extent2 != 1) extent = extent2;
879 
880     TFLITE_DCHECK(extent0 == 1 || extent0 == extent);
881     TFLITE_DCHECK(extent1 == 1 || extent1 == extent);
882     TFLITE_DCHECK(extent2 == 1 || extent2 == extent);
883 
884     if (!(extent0 == extent1 && extent1 == extent2)) {
885       if (extent0 == 1) {
886         desc0_out->strides[i] = 0;
887         desc0_out->extents[i] = extent;
888       }
889       if (extent1 == 1) {
890         desc1_out->strides[i] = 0;
891         desc1_out->extents[i] = extent;
892       }
893       if (extent2 == 1) {
894         desc2_out->strides[i] = 0;
895         desc2_out->extents[i] = extent;
896       }
897     }
898   }
899 }
900 
901 // Detailed implementation of NDOpsHelper, the indexes must be a zero array.
902 // This implementation is equivalent to N nested loops. Ex, if N=4, it can be
903 // re-writen as:
904 // for (int b = 0; b < output.extents[0]; ++b) {
905 //   for (int y = 0; y < output.extents[1]; ++y) {
906 //     for (int x = 0; x < output.extents[2]; ++x) {
907 //       for (int c = 0; c < output.extents[3]; ++c) {
908 //           calc({b,y,x,c});
909 //       }
910 //     }
911 //   }
912 // }
913 template <int N, int DIM, typename Calc>
NDOpsHelperImpl(const NdArrayDesc<N> & output,const Calc & calc,int indexes[N])914 typename std::enable_if<DIM != N - 1, void>::type NDOpsHelperImpl(
915     const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
916   for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
917     NDOpsHelperImpl<N, DIM + 1, Calc>(output, calc, indexes);
918   }
919 }
920 
921 template <int N, int DIM, typename Calc>
NDOpsHelperImpl(const NdArrayDesc<N> & output,const Calc & calc,int indexes[N])922 typename std::enable_if<DIM == N - 1, void>::type NDOpsHelperImpl(
923     const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
924   for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
925     calc(indexes);
926   }
927 }
928 
929 // Execute the calc function in the innermost iteration based on the shape of
930 // the output. The calc function should take a single argument of type int[N].
931 template <int N, typename Calc>
NDOpsHelper(const NdArrayDesc<N> & output,const Calc & calc)932 inline void NDOpsHelper(const NdArrayDesc<N>& output, const Calc& calc) {
933   int indexes[N] = {0};
934   NDOpsHelperImpl<N, 0, Calc>(output, calc, indexes);
935 }
// Copied from gemmlowp::RoundDown when we dropped direct dependency on
// gemmlowp.
//
// Rounds `i` down to a multiple of the compile-time constant Modulus by
// subtracting the remainder.
//
// Intended for non-negative `i`. For a negative `int`, the usual arithmetic
// conversions evaluate `i % Modulus` in unsigned arithmetic, so the result is
// not necessarily the mathematical floor.
template <unsigned Modulus, typename Integer>
Integer RoundDown(Integer i) {
  const auto remainder = i % Modulus;
  return i - remainder;
}
945 
946 // Copied from gemmlowp::RoundUp when we dropped direct dependency on
947 // gemmlowp.
948 //
949 // Returns the runtime argument rounded up to the nearest multiple of
950 // the fixed Modulus.
951 template <unsigned Modulus, typename Integer>
RoundUp(Integer i)952 Integer RoundUp(Integer i) {
953   return RoundDown<Modulus>(i + Modulus - 1);
954 }
955 
// Copied from gemmlowp::CeilQuotient when we dropped direct dependency on
// gemmlowp.
//
// Returns the quotient a / b rounded up ('ceil') to the nearest integer.
// `b` must be non-zero.
//
// The classic form (a + b - 1) / b overflows when `a` is close to the
// maximum of Integer (UB for signed types, silent wrap-around for unsigned
// ones). It also rounds incorrectly for negative operands. This version only
// divides once and never forms a sum larger than the quotient.
template <typename Integer>
Integer CeilQuotient(Integer a, Integer b) {
  // Built-in integer division truncates toward zero.
  const Integer quotient = a / b;
  // Truncation already rounded up when the exact quotient is negative. Bump
  // only when there is a remainder and the exact quotient is positive, i.e.
  // the operands have the same sign. The bump cannot overflow: a maximal
  // quotient implies |b| == 1, which leaves no remainder.
  const bool round_up = (a % b != 0) && ((a > 0) == (b > 0));
  return round_up ? static_cast<Integer>(quotient + 1) : quotient;
}
964 
// This function is a copy of gemmlowp::HowManyThreads, copied when we dropped
// the direct dependency of internal/optimized/ on gemmlowp.
//
// It computes a reasonable number of threads to use for a GEMM of shape
// (rows, cols, depth). The result is the smallest of:
// - max_num_threads;
// - rows / KernelRows (one thread per KernelRows rows);
// - rows * cols * depth / min_cubic_size_per_thread.
// It is never less than 1.
//
// TODO(b/131910176): get rid of this function by switching each call site
// to its own more sensible logic for its own workload.
template <int KernelRows>
inline int LegacyHowManyThreads(int max_num_threads, int rows, int cols,
                                int depth) {
  // Early-exit in the default case where multi-threading is disabled.
  if (max_num_threads == 1) {
    return 1;
  }

  // Ensure that each thread has KernelRows rows to process, if at all possible.
  int thread_count = std::min(max_num_threads, rows / KernelRows);

  // Limit the number of threads according to the overall size of the problem.
  if (thread_count > 1) {
    // Empirically determined value.
    static constexpr std::uint64_t min_cubic_size_per_thread = 64 * 1024;

    // For non-negative sizes, rows * cols fits in 64 bits (< 2^62), but also
    // multiplying by depth can overflow. Saturate instead of wrapping around
    // to a meaninglessly small value.
    constexpr std::uint64_t kMaxUint64 = ~std::uint64_t{0};
    const std::uint64_t square_size =
        static_cast<std::uint64_t>(rows) * static_cast<std::uint64_t>(cols);
    const std::uint64_t depth_u64 = static_cast<std::uint64_t>(depth);
    const std::uint64_t cubic_size =
        (depth_u64 != 0 && square_size > kMaxUint64 / depth_u64)
            ? kMaxUint64
            : square_size * depth_u64;

    // Compare in 64 bits. Narrowing the quotient to int first would truncate
    // it (e.g. 2^32 becomes 0) for large problems and wrongly force a single
    // thread.
    const std::uint64_t max_threads_for_size =
        cubic_size / min_cubic_size_per_thread;
    if (max_threads_for_size < static_cast<std::uint64_t>(thread_count)) {
      thread_count = static_cast<int>(max_threads_for_size);
    }
  }

  if (thread_count < 1) {
    thread_count = 1;
  }

  assert(thread_count > 0 && thread_count <= max_num_threads);
  return thread_count;
}
1004 
// Issues a read prefetch hint for the memory at `ptr` with no expected
// temporal locality. A no-op when the compiler is not GCC-compatible.
template <typename T>
void optimized_ops_preload_l1_stream(const T* ptr) {
#if defined(__GNUC__)
  // GCC-compatible compilers, clang included, provide this builtin.
  __builtin_prefetch(ptr, /*rw=read*/ 0, /*locality=none*/ 0);
#else
  static_cast<void>(ptr);
#endif
}
1014 
// Issues a read prefetch hint for the memory at `ptr` with high temporal
// locality, i.e. the data should stay cached. A no-op when the compiler is
// not GCC-compatible.
template <typename T>
void optimized_ops_preload_l1_keep(const T* ptr) {
#if defined(__GNUC__)
  // GCC-compatible compilers, clang included, provide this builtin.
  __builtin_prefetch(ptr, /*rw=read*/ 0, /*locality=high*/ 3);
#else
  static_cast<void>(ptr);
#endif
}
1024 
// Issues a write prefetch hint for the memory at `ptr` with high temporal
// locality. A no-op when the compiler is not GCC-compatible.
template <typename T>
void optimized_ops_prefetch_write_l1_keep(const T* ptr) {
#if defined(__GNUC__)
  // GCC-compatible compilers, clang included, provide this builtin.
  __builtin_prefetch(ptr, /*rw=write*/ 1, /*locality=high*/ 3);
#else
  static_cast<void>(ptr);
#endif
}
1034 
1035 }  // namespace tflite
1036 
1037 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
1038