/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_

#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#endif
#endif

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <type_traits>

#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

constexpr int kReverseShift = -1;

inline void GetActivationMinMax(FusedActivationFunctionType ac,
                                float* output_activation_min,
                                float* output_activation_max) {
  switch (ac) {
    case FusedActivationFunctionType::kNone:
      *output_activation_min = std::numeric_limits<float>::lowest();
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu:
      *output_activation_min = 0.f;
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu1:
      *output_activation_min = -1.f;
      *output_activation_max = 1.f;
      break;
    case FusedActivationFunctionType::kRelu6:
      *output_activation_min = 0.f;
      *output_activation_max = 6.f;
      break;
  }
}

template <typename T>
inline T ActivationFunctionWithMinMax(T x, T output_activation_min,
                                      T output_activation_max) {
  using std::max;
  using std::min;
  return min(max(x, output_activation_min), output_activation_max);
}

// Legacy function, left for compatibility only.
template <FusedActivationFunctionType Ac>
float ActivationFunction(float x) {
  float output_activation_min, output_activation_max;
  GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
  return ActivationFunctionWithMinMax(x, output_activation_min,
                                      output_activation_max);
}

inline void BiasAndClamp(float clamp_min, float clamp_max, int bias_size,
                         const float* bias_data, int array_size,
                         float* array_data) {
  // Note: see b/132215220: in May 2019 we thought it would be OK to replace
  // this with the Eigen one-liner:
  //   return (array.colwise() + bias).cwiseMax(clamp_min).cwiseMin(clamp_max).
  // This turned out to severely regress performance: +4ms (i.e. 8%) on
  // MobileNet v2 / 1.0 / 224. So we keep custom NEON code for now.
  TFLITE_DCHECK_EQ((array_size % bias_size), 0);
#ifdef USE_NEON
  float* array_ptr = array_data;
  float* array_end_ptr = array_ptr + array_size;
  const auto clamp_min_vec = vdupq_n_f32(clamp_min);
  const auto clamp_max_vec = vdupq_n_f32(clamp_max);
  for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
    int i = 0;
    for (; i <= bias_size - 16; i += 16) {
      auto b0 = vld1q_f32(bias_data + i);
      auto b1 = vld1q_f32(bias_data + i + 4);
      auto b2 = vld1q_f32(bias_data + i + 8);
      auto b3 = vld1q_f32(bias_data + i + 12);
      auto a0 = vld1q_f32(array_ptr + i);
      auto a1 = vld1q_f32(array_ptr + i + 4);
      auto a2 = vld1q_f32(array_ptr + i + 8);
      auto a3 = vld1q_f32(array_ptr + i + 12);
      auto x0 = vaddq_f32(a0, b0);
      auto x1 = vaddq_f32(a1, b1);
      auto x2 = vaddq_f32(a2, b2);
      auto x3 = vaddq_f32(a3, b3);
      x0 = vmaxq_f32(clamp_min_vec, x0);
      x1 = vmaxq_f32(clamp_min_vec, x1);
      x2 = vmaxq_f32(clamp_min_vec, x2);
      x3 = vmaxq_f32(clamp_min_vec, x3);
      x0 = vminq_f32(clamp_max_vec, x0);
      x1 = vminq_f32(clamp_max_vec, x1);
      x2 = vminq_f32(clamp_max_vec, x2);
      x3 = vminq_f32(clamp_max_vec, x3);
      vst1q_f32(array_ptr + i, x0);
      vst1q_f32(array_ptr + i + 4, x1);
      vst1q_f32(array_ptr + i + 8, x2);
      vst1q_f32(array_ptr + i + 12, x3);
    }
    for (; i <= bias_size - 4; i += 4) {
      auto b = vld1q_f32(bias_data + i);
      auto a = vld1q_f32(array_ptr + i);
      auto x = vaddq_f32(a, b);
      x = vmaxq_f32(clamp_min_vec, x);
      x = vminq_f32(clamp_max_vec, x);
      vst1q_f32(array_ptr + i, x);
    }
    for (; i < bias_size; i++) {
      array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
                                                  clamp_min, clamp_max);
    }
  }
#else  // not NEON
  for (int array_offset = 0; array_offset < array_size;
       array_offset += bias_size) {
    for (int i = 0; i < bias_size; i++) {
      array_data[array_offset + i] = ActivationFunctionWithMinMax(
          array_data[array_offset + i] + bias_data[i], clamp_min, clamp_max);
    }
  }
#endif
}
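
// Example use of BiasAndClamp (illustrative only; `batch`, `channels`,
// `bias_data` and `activations` are assumed to be caller-provided): add a
// per-channel bias to a [batch, channels] float buffer with a fused ReLU6.
//   BiasAndClamp(/*clamp_min=*/0.0f, /*clamp_max=*/6.0f,
//                /*bias_size=*/channels, bias_data,
//                /*array_size=*/batch * channels, activations);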

inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(
    int32_t x, int32_t quantized_multiplier, int left_shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift);
}

inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(
    int32_t x, int32_t quantized_multiplier, int left_shift) {
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
                                           quantized_multiplier);
}

inline int32_t MultiplyByQuantizedMultiplier(int32_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  int left_shift = shift > 0 ? shift : 0;
  int right_shift = shift > 0 ? 0 : -shift;
  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                 x * (1 << left_shift), quantized_multiplier),
                             right_shift);
}
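
// Illustrative example: MultiplyByQuantizedMultiplier approximates x * M,
// where the real multiplier M = quantized_multiplier * 2^shift / 2^31 and
// quantized_multiplier is normally in [1 << 30, 1 << 31). For M ~= 0.375,
// encode M = 0.75 * 2^-1, i.e. quantized_multiplier = 1610612736 and
// shift = -1; then MultiplyByQuantizedMultiplier(100, 1610612736, -1)
// returns 38, i.e. 100 * 0.375 rounded to the nearest integer.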

inline int32_t MultiplyByQuantizedMultiplier(int64_t x,
                                             int32_t quantized_multiplier,
                                             int shift) {
  // Inputs:
  // - quantized_multiplier has fixed point at bit 31
  // - shift is -31 to +7 (negative for right shift)
  //
  // Assumptions: The following input ranges are assumed
  // - quantized_multiplier >= 0  (the usual range is (1<<30) to (1<<31)-1)
  // - scaling is chosen so final scaled result fits in int32_t
  // - input x is in the range -(1<<47) <= x < (1<<47)
  assert(quantized_multiplier >= 0);
  assert(shift >= -31 && shift < 8);
  assert(x >= -(static_cast<int64_t>(1) << 47) &&
         x < (static_cast<int64_t>(1) << 47));

  int32_t reduced_multiplier = (quantized_multiplier < 0x7FFF0000)
                                   ? ((quantized_multiplier + (1 << 15)) >> 16)
                                   : 0x7FFF;
  int total_shift = 15 - shift;
  x = (x * (int64_t)reduced_multiplier) + ((int64_t)1 << (total_shift - 1));
  int32_t result = x >> total_shift;
  return result;
}

#ifdef USE_NEON
// Round uses ARM's rounding shift right.
inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows(
    int32x4x4_t input_val, int32_t quantized_multiplier, int shift) {
  const int left_shift = std::max(shift, 0);
  const int right_shift = std::min(shift, 0);
  int32x4x4_t result;

  int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier);
  int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
  int32x4_t right_shift_dup = vdupq_n_s32(right_shift);

  result.val[0] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[1] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[2] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  result.val[3] =
      vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup),
                               multiplier_dup),
                 right_shift_dup);

  return result;
}
#endif

template <typename T>
int CountLeadingZeros(T integer_input) {
  static_assert(std::is_unsigned<T>::value,
                "Only unsigned integer types handled.");
#if defined(__GNUC__)
  return integer_input ? __builtin_clz(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  if (integer_input == 0) {
    return std::numeric_limits<T>::digits;
  }

  const T one_in_leading_positive = static_cast<T>(1)
                                    << (std::numeric_limits<T>::digits - 1);
  int leading_zeros = 0;
  while (integer_input < one_in_leading_positive) {
    integer_input <<= 1;
    ++leading_zeros;
  }
  return leading_zeros;
#endif
}

template <typename T>
inline int CountLeadingSignBits(T integer_input) {
  static_assert(std::is_signed<T>::value, "Only signed integer types handled.");
#if defined(__GNUC__) && !defined(__clang__)
  return integer_input ? __builtin_clrsb(integer_input)
                       : std::numeric_limits<T>::digits;
#else
  using U = typename std::make_unsigned<T>::type;
  return integer_input >= 0
             ? CountLeadingZeros(static_cast<U>(integer_input)) - 1
         : integer_input != std::numeric_limits<T>::min()
             ? CountLeadingZeros(2 * static_cast<U>(-integer_input) - 1)
             : 0;
#endif
}

// Use "count leading zeros" helper functions to do a fast Floor(log_2(x)).
template <typename Integer>
inline Integer FloorLog2(Integer n) {
  static_assert(std::is_integral<Integer>::value, "");
  static_assert(std::is_signed<Integer>::value, "");
  static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, "");
  TFLITE_CHECK_GT(n, 0);
  if (sizeof(Integer) == 4) {
    return 30 - CountLeadingSignBits(n);
  } else {
    return 62 - CountLeadingSignBits(n);
  }
}
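
// For example, with int32_t inputs: FloorLog2(1) == 0, FloorLog2(16) == 4,
// and FloorLog2(17) == 4, matching floor(log2(x)).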

// The size of the LUT depends on the type of input. For int8 inputs a simple
// 256-entry LUT is used. For int16 inputs the high 9 bits are used for
// indexing and the 7 remaining bits are used for interpolation. We thus use a
// 513-entry LUT for the int16 case: 512 for the 9-bit indexing and 1 extra
// entry to interpolate the last value.
template <typename LutInT>
constexpr int lut_size() {
  static_assert(std::is_same<LutInT, int8_t>::value ||
                    std::is_same<LutInT, int16_t>::value,
                "Only LUTs with int8 or int16 inputs are supported.");
  return std::is_same<LutInT, int8_t>::value ? 256 : 513;
}

// Generate a LUT for 'func' which can be used to approximate functions like
// exp, log, ...
//
// - func: the function to build the LUT for (e.g. exp(x))
// - input_min, input_max: range of the func inputs
// - output_min, output_max: range of the func outputs
// - lut: pointer to the LUT to fill; the table must be of size
//   lut_size<LutInT>()
template <typename FloatT, typename LutInT, typename LutOutT>
inline void gen_lut(FloatT (*func)(FloatT), FloatT input_min, FloatT input_max,
                    FloatT output_min, FloatT output_max, LutOutT* lut) {
  static_assert(std::is_same<LutInT, int8_t>::value ||
                    std::is_same<LutInT, int16_t>::value,
                "Only LUTs with int8 or int16 inputs are supported.");
  static_assert(std::is_same<LutOutT, int8_t>::value ||
                    std::is_same<LutOutT, int16_t>::value,
                "Only LUTs with int8 or int16 outputs are supported.");
  static_assert(std::is_floating_point<FloatT>::value,
                "FloatT must be a floating-point type.");

  const int nb_steps = std::is_same<LutInT, int8_t>::value ? 256 : 512;
  const FloatT step = (input_max - input_min) / nb_steps;
  const FloatT half_step = step / 2;
  const FloatT output_scaling_inv =
      static_cast<FloatT>(std::numeric_limits<LutOutT>::max() -
                          std::numeric_limits<LutOutT>::min() + 1) /
      (output_max - output_min);
  const FloatT table_min =
      static_cast<FloatT>(std::numeric_limits<LutOutT>::min());
  const FloatT table_max =
      static_cast<FloatT>(std::numeric_limits<LutOutT>::max());

  for (int i = 0; i < nb_steps; i++) {
    const FloatT val = func(input_min + i * step);
    const FloatT val_midpoint = func(input_min + i * step + half_step);
    const FloatT val_next = func(input_min + (i + 1) * step);

    const FloatT sample_val = TfLiteRound(val * output_scaling_inv);
    const FloatT midpoint_interp_val =
        TfLiteRound((val_next * output_scaling_inv +
                     TfLiteRound(val * output_scaling_inv)) /
                    2);
    const FloatT midpoint_val = TfLiteRound(val_midpoint * output_scaling_inv);
    const FloatT midpoint_err = midpoint_interp_val - midpoint_val;
    const FloatT bias = TfLiteRound(midpoint_err / 2);

    lut[i] = static_cast<LutOutT>(std::min<FloatT>(
        std::max<FloatT>(sample_val - bias, table_min), table_max));
  }

  const bool with_extra_interpolation_value =
      std::is_same<LutInT, int16_t>::value;
  if (with_extra_interpolation_value) {
    lut[nb_steps] = static_cast<LutOutT>(std::min<FloatT>(
        std::max<FloatT>(TfLiteRound(func(input_max) * output_scaling_inv),
                         table_min),
        table_max));
  }
}
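
// Example use of gen_lut (illustrative sketch, not tied to a particular
// kernel; the input/output ranges are arbitrary assumptions): build an
// int16 -> int16 table approximating exp(x) for x in [-10, 0], with outputs
// mapped onto [-1, 1):
//   int16_t exp_lut[lut_size<int16_t>()];  // 513 entries
//   gen_lut<double, int16_t, int16_t>(
//       [](double x) { return std::exp(x); }, -10.0, 0.0, -1.0, 1.0, exp_lut);
// A runtime int16 input spanning [-32768, 32767] then represents [-10, 0]
// linearly and is evaluated with lut_lookup(input, exp_lut) below.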

// LUT must have 513 values
template <typename LutOutT>
inline LutOutT lut_lookup_with_interpolation(int16_t value,
                                             const LutOutT* lut) {
  static_assert(std::is_same<LutOutT, int8_t>::value ||
                    std::is_same<LutOutT, int16_t>::value,
                "Only LUTs with int8 or int16 outputs are supported.");
  // 512 base values; the 513th entry (lut[512]) is only used to calculate the
  // slope of the last segment.
  const uint16_t index = static_cast<uint16_t>(256 + (value >> 7));
  assert(index < 512 && "LUT index out of range.");
  const int16_t offset = value & 0x7f;

  // Base and slope are Q0.x
  const LutOutT base = lut[index];
  const LutOutT slope = lut[index + 1] - lut[index];

  // Q0.x * Q0.7 = Q0.(x + 7)
  // Round and convert from Q0.(x + 7) to Q0.x
  const int delta = (slope * offset + 64) >> 7;

  // Q0.15 + Q0.15
  return static_cast<LutOutT>(base + delta);
}

// int16_t -> int16_t table lookup with interpolation
// LUT must have 513 values
inline int16_t lut_lookup(int16_t value, const int16_t* lut) {
  return lut_lookup_with_interpolation(value, lut);
}

// int16_t -> int8_t table lookup with interpolation
// LUT must have 513 values
inline int8_t lut_lookup(int16_t value, const int8_t* lut) {
  return lut_lookup_with_interpolation(value, lut);
}

// int8_t -> int8_t table lookup without interpolation
// LUT must have 256 values
inline int8_t lut_lookup(int8_t value, const int8_t* lut) {
  return lut[128 + value];
}

// int8_t -> int16_t table lookup without interpolation
// LUT must have 256 values
inline int16_t lut_lookup(int8_t value, const int16_t* lut) {
  return lut[128 + value];
}

// Table of sigmoid(i/24) at 0.16 format - 256 elements.

// We use a combined sigmoid and tanh look-up table, since
// tanh(x) = 2*sigmoid(2*x) - 1.
// Both functions are symmetric, so the LUT is only needed
// for the absolute value of the input.
static const uint16_t sigmoid_table_uint16[256] = {
    32768, 33451, 34133, 34813, 35493, 36169, 36843, 37513, 38180, 38841, 39498,
    40149, 40794, 41432, 42064, 42688, 43304, 43912, 44511, 45102, 45683, 46255,
    46817, 47369, 47911, 48443, 48964, 49475, 49975, 50464, 50942, 51409, 51865,
    52311, 52745, 53169, 53581, 53983, 54374, 54755, 55125, 55485, 55834, 56174,
    56503, 56823, 57133, 57433, 57724, 58007, 58280, 58544, 58800, 59048, 59288,
    59519, 59743, 59959, 60168, 60370, 60565, 60753, 60935, 61110, 61279, 61441,
    61599, 61750, 61896, 62036, 62172, 62302, 62428, 62549, 62666, 62778, 62886,
    62990, 63090, 63186, 63279, 63368, 63454, 63536, 63615, 63691, 63765, 63835,
    63903, 63968, 64030, 64090, 64148, 64204, 64257, 64308, 64357, 64405, 64450,
    64494, 64536, 64576, 64614, 64652, 64687, 64721, 64754, 64786, 64816, 64845,
    64873, 64900, 64926, 64950, 64974, 64997, 65019, 65039, 65060, 65079, 65097,
    65115, 65132, 65149, 65164, 65179, 65194, 65208, 65221, 65234, 65246, 65258,
    65269, 65280, 65291, 65301, 65310, 65319, 65328, 65337, 65345, 65352, 65360,
    65367, 65374, 65381, 65387, 65393, 65399, 65404, 65410, 65415, 65420, 65425,
    65429, 65433, 65438, 65442, 65445, 65449, 65453, 65456, 65459, 65462, 65465,
    65468, 65471, 65474, 65476, 65479, 65481, 65483, 65485, 65488, 65489, 65491,
    65493, 65495, 65497, 65498, 65500, 65501, 65503, 65504, 65505, 65507, 65508,
    65509, 65510, 65511, 65512, 65513, 65514, 65515, 65516, 65517, 65517, 65518,
    65519, 65520, 65520, 65521, 65522, 65522, 65523, 65523, 65524, 65524, 65525,
    65525, 65526, 65526, 65526, 65527, 65527, 65528, 65528, 65528, 65529, 65529,
    65529, 65529, 65530, 65530, 65530, 65530, 65531, 65531, 65531, 65531, 65531,
    65532, 65532, 65532, 65532, 65532, 65532, 65533, 65533, 65533, 65533, 65533,
    65533, 65533, 65533, 65534, 65534, 65534, 65534, 65534, 65534, 65534, 65534,
    65534, 65534, 65535};

// TODO(b/77858996): Add these to gemmlowp.
template <typename IntegerType>
IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t sum = a64 + b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          sum)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingAddNonGemmlowp(a.raw(), b.raw()));
}

template <typename IntegerType>
IntegerType SaturatingSub(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
  std::int32_t a32 = a;
  std::int32_t b32 = b;
  std::int32_t diff = a32 - b32;
  return static_cast<std::int16_t>(
      std::min(static_cast<int32_t>(32767),
               std::max(static_cast<int32_t>(-32768), diff)));
}

template <>
inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
  std::int64_t a64 = a;
  std::int64_t b64 = b;
  std::int64_t diff = a64 - b64;
  return static_cast<std::int32_t>(std::min(
      static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
      std::max(
          static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
          diff)));
}

template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
    gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingSub(a.raw(), b.raw()));
}
// End section to be moved to gemmlowp.

template <typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
  if (exponent == 0) {
    return x;
  }
  using ScalarIntegerType =
      typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  const IntegerType min =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
  const IntegerType max =
      gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
  const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);

  const std::int32_t threshold =
      ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
  const IntegerType positive_mask =
      gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
  const IntegerType negative_mask =
      gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));

  IntegerType result = gemmlowp::ShiftLeft(x, exponent);
  result = gemmlowp::SelectUsingMask(positive_mask, max, result);
  result = gemmlowp::SelectUsingMask(negative_mask, min, result);
  return result;
}

// If we want to leave IntegerBits fixed, then multiplication
// by a power of two has to be saturating/rounding, not exact anymore.
template <typename tRawType, int tIntegerBits>
gemmlowp::FixedPoint<tRawType, tIntegerBits>
SaturatingRoundingMultiplyByPOTParam(
    gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
  return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}

// Convert int32_t multiplier to int16_t with rounding.
inline void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32_t,
                                            int16_t* multiplier_int16_t) {
  TFLITE_DCHECK_GE(multiplier_int32_t, 0);
  static constexpr int32_t kRoundingOffset = 1 << 15;
  if (multiplier_int32_t >=
      std::numeric_limits<int32_t>::max() - kRoundingOffset) {
    *multiplier_int16_t = std::numeric_limits<int16_t>::max();
    return;
  }
  const int32_t result = (multiplier_int32_t + kRoundingOffset) >> 16;
  TFLITE_DCHECK_LE(result << 16, multiplier_int32_t + kRoundingOffset);
  TFLITE_DCHECK_GT(result << 16, multiplier_int32_t - kRoundingOffset);
  *multiplier_int16_t = result;
  TFLITE_DCHECK_EQ(*multiplier_int16_t, result);
}
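
// Example for DownScaleInt32ToInt16Multiplier (illustrative): a Q0.31
// multiplier of 1 << 30 (i.e. 0.5) downscales to the Q0.15 value 1 << 14:
//   int16_t m;
//   DownScaleInt32ToInt16Multiplier(1 << 30, &m);  // m == 16384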

// Minimum output bits to accommodate log of maximum input range.  It actually
// does not matter if one considers, say, [-64,64] or [-64,64).
//
// For example, run this through Octave:
// [0:127; ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
//  ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
constexpr int min_log_x_output_bits(int input_bits) {
  return input_bits > 90   ? 7
         : input_bits > 44 ? 6
         : input_bits > 21 ? 5
         : input_bits > 10 ? 4
         : input_bits > 4  ? 3
         : input_bits > 1  ? 2
                           : 1;
}
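
// For example, min_log_x_output_bits(12) == 4: log(2^12) is roughly 8.3, so
// four integer bits are needed to represent it.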

// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input.  In other words, the output range
// is symmetric.
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
    gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
  // assert(__builtin_clz(0u) >= std::numeric_limits<uint32_t>::digits - 1);
  // assert(__builtin_clz(0u) <= std::numeric_limits<uint32_t>::digits);
  using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
  // The reason for accumulating the result with an extra bit of headroom is
  // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
  // recip_denom will otherwise introduce an error.
  static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
  using FixedPointAccum = gemmlowp::FixedPoint<int32_t, kAccumIntegerBits>;

  const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1488522236, std::log(2.0));
  const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
  const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1518500250, std::sqrt(0.5));
  const FixedPoint0 one_quarter =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);

  const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 1057819769,
      2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
  const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
      FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));

  const FixedPointAccum shifted_quarter =
      gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);

  // Reinterpret the input value as Q0.31, because we will figure out the
  // required shift "ourselves" instead of using, say, Rescale.
  FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
  // z_a_pow_2 = input_integer_bits - z_a_headroom;
  int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32_t>(z_a.raw()));
  FixedPoint0 r_a_tmp =
      SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
  const int32_t r_a_raw =
      SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
  // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
  // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
  //                   InputIntegerBits - z_b_headroom - 0.25);
  const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          static_cast<int32_t>(InputIntegerBits - z_a_headroom_plus_1),
          31 - kAccumIntegerBits)),
      shifted_quarter);

  // z_b is treated like z_a, but premultiplying by sqrt(0.5).
  FixedPoint0 z_b = z_a * sqrt_half;
  int z_b_headroom = CountLeadingZeros(static_cast<uint32_t>(z_b.raw())) - 1;
  const int32_t r_b_raw =
      SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
  const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
      FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
          static_cast<int32_t>(InputIntegerBits - z_b_headroom),
          31 - kAccumIntegerBits)),
      shifted_quarter);

  const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
  const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
      std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));

  const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
  FixedPoint0 q = r - sqrt_sqrt_half;
  q = q + q;

  const FixedPoint0 common_sq = q * q;
  const FixedPoint0 num = q * r + q * common_sq * alpha_n;
  const FixedPoint0 denom_minus_one_0 =
      p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
  const FixedPoint0 recip_denom =
      one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);

  const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
  return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
                                              num_scaled * recip_denom);
}

template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32_t, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1(
    gemmlowp::FixedPoint<int32_t, InputIntegerBits> input_val) {
  static_assert(
      OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
      "Output integer bits must be sufficient to accommodate logs of inputs.");
  return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
                                                     InputIntegerBits>(
      input_val);
}

inline int32_t GetReciprocal(int32_t x, int x_integer_digits,
                             int* num_bits_over_unit) {
  int headroom_plus_one = CountLeadingZeros(static_cast<uint32_t>(x));
  // This is the number of bits to the left of the binary point above 1.0.
  // Consider x=1.25.  In that case shifted_scale=0.8 and
  // no later adjustment will be needed.
  *num_bits_over_unit = x_integer_digits - headroom_plus_one;
  const int32_t shifted_sum_minus_one =
      static_cast<int32_t>((static_cast<uint32_t>(x) << headroom_plus_one) -
                           (static_cast<uint32_t>(1) << 31));

  gemmlowp::FixedPoint<int32_t, 0> shifted_scale =
      gemmlowp::one_over_one_plus_x_for_x_in_0_1(
          gemmlowp::FixedPoint<int32_t, 0>::FromRaw(shifted_sum_minus_one));
  return shifted_scale.raw();
}

inline void GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift,
                                             int32_t* output_inv_sqrt,
                                             int* output_shift) {
  TFLITE_DCHECK_GE(input, 0);
  if (input <= 1) {
    // Handle the input value 1 separately to avoid overflow in that case
    // in the general computation below (b/143972021). Also handle 0 as if it
    // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
    // but rare/unrealistic input value. We can expect both to occur in some
    // incompletely trained models, but probably not in fully trained models.
    *output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
    *output_shift = 0;
    return;
  }
  TFLITE_DCHECK_GT(input, 1);
  *output_shift = 11;
  while (input >= (1 << 29)) {
    input /= 4;
    ++*output_shift;
  }
  const unsigned max_left_shift_bits =
      CountLeadingZeros(static_cast<uint32_t>(input)) - 1;
  const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
  const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
  *output_shift -= left_shift_bit_pairs;
  input <<= 2 * left_shift_bit_pairs;
  TFLITE_DCHECK_GE(input, (1 << 27));
  TFLITE_DCHECK_LT(input, (1 << 29));
  using gemmlowp::FixedPoint;
  using gemmlowp::Rescale;
  using gemmlowp::SaturatingRoundingMultiplyByPOT;
  // Using 3 integer bits gives us enough room for the internal arithmetic in
  // this Newton-Raphson iteration.
  using F3 = FixedPoint<int32_t, 3>;
  using F0 = FixedPoint<int32_t, 0>;
  const F3 fixedpoint_input = F3::FromRaw(input >> 1);
  const F3 fixedpoint_half_input =
      SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
  const F3 fixedpoint_half_three =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
  // Newton-Raphson iteration
  // Naive unoptimized starting guess: x = 1
  F3 x = F3::One();
  // Naive unoptimized number of iterations: 5
  for (int i = 0; i < 5; i++) {
    const F3 x3 = Rescale<3>(x * x * x);
    x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
  }
  const F0 fixedpoint_half_sqrt_2 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
  x = x * fixedpoint_half_sqrt_2;
  *output_inv_sqrt = x.raw();
  if (*output_shift < 0) {
    *output_inv_sqrt <<= -*output_shift;
    *output_shift = 0;
  }
  // Convert right shift (right is positive) to left shift.
  *output_shift *= reverse_shift;
}
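
// Usage note for GetInvSqrtQuantizedMultiplierExp (illustrative): the
// resulting (*output_inv_sqrt, *output_shift) pair acts as a quantized
// multiplier-and-shift that downstream kernels (e.g. quantized L2
// normalization) apply with the MultiplyByQuantizedMultiplier* helpers above
// to approximate scaling by 1/sqrt(input), typically passing
// reverse_shift == kReverseShift.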

// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
// rectangular array of numbers.
//
// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
// However, as Dims<N> is to be deprecated, this class exists as an adaptor
// to enable simple unoptimized implementations of element-wise broadcasting
// operations.
template <int N>
struct NdArrayDesc {
  // The "extent" of each dimension. Indices along dimension d must be in the
  // half-open interval [0, extents[d]).
  int extents[N];

  // The number of *elements* (not bytes) between consecutive indices of each
  // dimension.
  int strides[N];
};

// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// BROADCASTING.
//
// Same as Offset(), except it takes an NdArrayDesc<N> instead of Dims<N>.
inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
                            int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
  return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
         i3 * desc.strides[3];
}

inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) {
  return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
         indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
         indexes[4] * desc.strides[4];
}

inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) {
  return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
         indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
         indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] +
         indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7];
}

// Given the dimensions of the operands for an element-wise binary broadcast,
// adjusts them so that they can be directly iterated over with simple loops.
// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
//
// This function assumes that the two input shapes are compatible up to
// broadcasting and the shorter one has already been prepended with 1s to be the
// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
// Dims<N> refers to shapes in reverse order. In this case, input0_dims will be
// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
//
// When two shapes are compatible up to broadcasting, for each dimension d,
// the input extents are either equal, or one of them is 1.
//
// This function performs the following for each dimension d:
// - If the extents are equal, then do nothing since the loop that walks over
//   both of the input arrays is correct.
// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
//   and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
//   array0 to be referenced *at any index* in dimension d and still access the
//   same slice.
template <int N>
inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
                                                const Dims<N>& input1_dims,
                                                NdArrayDesc<N>* desc0_out,
                                                NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  // Copy dims to desc.
  for (int i = 0; i < N; ++i) {
    desc0_out->extents[i] = input0_dims.sizes[i];
    desc0_out->strides[i] = input0_dims.strides[i];
    desc1_out->extents[i] = input1_dims.sizes[i];
    desc1_out->strides[i] = input1_dims.strides[i];
  }

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = ArraySize(input0_dims, i);
    const int extent1 = ArraySize(input1_dims, i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}

// Copies dims to desc, calculating strides.
template <int N>
inline void CopyDimsToDesc(const RuntimeShape& input_shape,
                           NdArrayDesc<N>* desc_out) {
  int desc_stride = 1;
  for (int i = N - 1; i >= 0; --i) {
    desc_out->extents[i] = input_shape.Dims(i);
    desc_out->strides[i] = desc_stride;
    desc_stride *= input_shape.Dims(i);
  }
}

template <int N>
inline void NdArrayDescsForElementwiseBroadcast(
    const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
    NdArrayDesc<N>* desc0_out, NdArrayDesc<N>* desc1_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);

  auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
  auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);

  // Copy dims to desc, calculating strides.
  CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
  CopyDimsToDesc<N>(extended_input1_shape, desc1_out);

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = extended_input0_shape.Dims(i);
    const int extent1 = extended_input1_shape.Dims(i);
    if (extent0 != extent1) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent1;
      } else {
        TFLITE_DCHECK_EQ(extent1, 1);
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent0;
      }
    }
  }
}

template <int N>
inline void NdArrayDescsForElementwiseBroadcast(
    const RuntimeShape& input0_shape, const RuntimeShape& input1_shape,
    const RuntimeShape& input2_shape, NdArrayDesc<N>* desc0_out,
    NdArrayDesc<N>* desc1_out, NdArrayDesc<N>* desc2_out) {
  TFLITE_DCHECK(desc0_out != nullptr);
  TFLITE_DCHECK(desc1_out != nullptr);
  TFLITE_DCHECK(desc2_out != nullptr);

  auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape);
  auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape);
  auto extended_input2_shape = RuntimeShape::ExtendedShape(N, input2_shape);

  // Copy dims to desc, calculating strides.
  CopyDimsToDesc<N>(extended_input0_shape, desc0_out);
  CopyDimsToDesc<N>(extended_input1_shape, desc1_out);
  CopyDimsToDesc<N>(extended_input2_shape, desc2_out);

  // Walk over each dimension. If the extents are equal do nothing.
  // Otherwise, set the desc with extent 1 to have extent equal to the other and
  // stride 0.
  for (int i = 0; i < N; ++i) {
    const int extent0 = extended_input0_shape.Dims(i);
    const int extent1 = extended_input1_shape.Dims(i);
    const int extent2 = extended_input2_shape.Dims(i);

    int extent = extent0;
    if (extent1 != 1) extent = extent1;
    if (extent2 != 1) extent = extent2;

    TFLITE_DCHECK(extent0 == 1 || extent0 == extent);
    TFLITE_DCHECK(extent1 == 1 || extent1 == extent);
    TFLITE_DCHECK(extent2 == 1 || extent2 == extent);

    if (!(extent0 == extent1 && extent1 == extent2)) {
      if (extent0 == 1) {
        desc0_out->strides[i] = 0;
        desc0_out->extents[i] = extent;
      }
      if (extent1 == 1) {
        desc1_out->strides[i] = 0;
        desc1_out->extents[i] = extent;
      }
      if (extent2 == 1) {
        desc2_out->strides[i] = 0;
        desc2_out->extents[i] = extent;
      }
    }
  }
}

// Detailed implementation of NDOpsHelper; the indexes must be a zero array.
// This implementation is equivalent to N nested loops. E.g., if N=4, it can be
// re-written as:
// for (int b = 0; b < output.extents[0]; ++b) {
//   for (int y = 0; y < output.extents[1]; ++y) {
//     for (int x = 0; x < output.extents[2]; ++x) {
//       for (int c = 0; c < output.extents[3]; ++c) {
//           calc({b,y,x,c});
//       }
//     }
//   }
// }
template <int N, int DIM, typename Calc>
typename std::enable_if<DIM != N - 1, void>::type NDOpsHelperImpl(
    const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
  for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
    NDOpsHelperImpl<N, DIM + 1, Calc>(output, calc, indexes);
  }
}

template <int N, int DIM, typename Calc>
typename std::enable_if<DIM == N - 1, void>::type NDOpsHelperImpl(
    const NdArrayDesc<N>& output, const Calc& calc, int indexes[N]) {
  for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM]) {
    calc(indexes);
  }
}

// Execute the calc function in the innermost iteration based on the shape of
// the output. The calc function should take a single argument of type int[N].
template <int N, typename Calc>
inline void NDOpsHelper(const NdArrayDesc<N>& output, const Calc& calc) {
  int indexes[N] = {0};
  NDOpsHelperImpl<N, 0, Calc>(output, calc, indexes);
}
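
// Example use of NDOpsHelper (illustrative sketch; input0_shape, input1_shape,
// output_shape and the float buffers are assumed to be caller-provided): an
// element-wise broadcast add.
//   NdArrayDesc<5> desc0, desc1, output_desc;
//   NdArrayDescsForElementwiseBroadcast<5>(input0_shape, input1_shape, &desc0,
//                                          &desc1);
//   CopyDimsToDesc<5>(RuntimeShape::ExtendedShape(5, output_shape),
//                     &output_desc);
//   auto add_func = [&](int indexes[5]) {
//     output_data[SubscriptToIndex(output_desc, indexes)] =
//         input0_data[SubscriptToIndex(desc0, indexes)] +
//         input1_data[SubscriptToIndex(desc1, indexes)];
//   };
//   NDOpsHelper<5>(output_desc, add_func);
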
// Copied from gemmlowp::RoundDown when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded down to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundDown(Integer i) {
  return i - (i % Modulus);
}

// Copied from gemmlowp::RoundUp when we dropped direct dependency on
// gemmlowp.
//
// Returns the runtime argument rounded up to the nearest multiple of
// the fixed Modulus.
template <unsigned Modulus, typename Integer>
Integer RoundUp(Integer i) {
  return RoundDown<Modulus>(i + Modulus - 1);
}

// Copied from gemmlowp::CeilQuotient when we dropped direct dependency on
// gemmlowp.
//
// Returns the quotient a / b rounded up ('ceil') to the nearest integer.
template <typename Integer>
Integer CeilQuotient(Integer a, Integer b) {
  return (a + b - 1) / b;
}
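
// Quick examples: RoundDown<8>(13) == 8, RoundUp<8>(13) == 16, and
// CeilQuotient(13, 8) == 2.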

// This function is a copy of gemmlowp::HowManyThreads, copied when we dropped
// the direct dependency of internal/optimized/ on gemmlowp.
//
// It computes a reasonable number of threads to use for a GEMM of shape
// (rows, cols, depth).
//
// TODO(b/131910176): get rid of this function by switching each call site
// to its own more sensible logic for its own workload.
template <int KernelRows>
inline int LegacyHowManyThreads(int max_num_threads, int rows, int cols,
                                int depth) {
  // Early-exit in the default case where multi-threading is disabled.
  if (max_num_threads == 1) {
    return 1;
  }

  // Ensure that each thread has KernelRows rows to process, if at all possible.
  int thread_count = std::min(max_num_threads, rows / KernelRows);

  // Limit the number of threads according to the overall size of the problem.
  if (thread_count > 1) {
    // Empirically determined value.
    static constexpr std::uint64_t min_cubic_size_per_thread = 64 * 1024;

    // We can only multiply two out of three sizes without risking overflow
    const std::uint64_t cubic_size =
        std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth);

    thread_count = std::min(
        thread_count, static_cast<int>(cubic_size / min_cubic_size_per_thread));
  }

  if (thread_count < 1) {
    thread_count = 1;
  }

  assert(thread_count > 0 && thread_count <= max_num_threads);
  return thread_count;
}

template <typename T>
void optimized_ops_preload_l1_stream(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 0 means no locality */ 0);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_preload_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 0 means read */ 0, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

template <typename T>
void optimized_ops_prefetch_write_l1_keep(const T* ptr) {
#ifdef __GNUC__
  // builtin offered by GCC-compatible compilers including clang
  __builtin_prefetch(ptr, /* 1 means write */ 1, /* 3 means high locality */ 3);
#else
  (void)ptr;
#endif
}

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_COMMON_H_