• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/int8/fixed_point.h"
18 
// Fixed-point product of two Q0.31 values.
// Reading a and b as a / 2^31 and b / 2^31 (values in [-1, 1)), the Q0.31
// product is (a / 2^31) * (b / 2^31) * 2^31 = (a * b) / 2^31, i.e. the high
// 32 bits of the doubled 64-bit product 2 * a * b.
// The result is rounded to nearest; exact halves round toward +infinity.
// INT_MIN * INT_MIN (= +1.0) is the only overflowing case and saturates to INT_MAX.
int SaturatingRoundingDoublingHighMul(int a, int b) {
  const int both_min = (a == INT_MIN) && (b == INT_MIN);
  if (both_min) {
    return INT_MAX;
  }
  const int64_t product = (int64_t)a * (int64_t)b;
  int64_t nudge = (int64_t)1 << 30;
  if (product < 0) {
    nudge = 1 - nudge;
  }
  // Divide instead of '>>' so negative values truncate toward zero portably.
  return (int)((product + nudge) / ((int64_t)1 << 31));
}
34 
// 16-bit counterpart of SaturatingRoundingDoublingHighMul for Q0.15 values:
// returns (a * b) / 2^15 rounded to nearest, exact halves toward +infinity.
// SHRT_MIN * SHRT_MIN (= +1.0) is the only overflow and saturates to SHRT_MAX.
int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b) {
  if (a == SHRT_MIN && b == SHRT_MIN) {
    return SHRT_MAX;
  }
  const int32_t product = (int32_t)a * (int32_t)b;
  const int32_t nudge = (product >= 0) ? (1 << 14) : (1 - (1 << 14));
  // Division truncates toward zero; |product + nudge| < 2^31, so no overflow.
  return (int16_t)((product + nudge) / (1 << 15));
}
43 
// Divides x by 2^exponent (an arithmetic right shift with rounding), rounding
// to nearest with exact halves rounded away from zero.
// Expects 0 <= exponent <= 31 — NOTE(review): other values are not guarded.
// Relies on '>>' of a negative int being an arithmetic shift.
int RoundingDivideByPOT(int x, int exponent) {
  const int low_bits_mask = (int)((1ll << exponent) - 1);
  const int dropped_bits = x & low_bits_mask;
  // Rounding threshold: half of the divisor, nudged by one for negative x so
  // that negative halves round away from zero as well.
  int half = low_bits_mask >> 1;
  if (x < 0) {
    half += 1;
  }
  int quotient = x >> exponent;
  if (dropped_bits > half) {
    quotient += 1;
  }
  return quotient;
}
52 
// Computes x / 2^exponent rounded to nearest with exact halves rounded toward
// +infinity, via (x + 2^(exponent-1)) >> exponent.
// When adding the offset would overflow int32, it returns 2^(31 - exponent).
// That is the correctly rounded result for every x in that range.
// Expects 0 <= exponent <= 31 — NOTE(review): not guarded here.
int UpwardRounding(int x, int exponent) {
  int32_t offset = 0;
  if (exponent > 0) {
    offset = 1 << (exponent - 1);
  }
  const int would_overflow = x > INT32_MAX - offset;
  return would_overflow ? (1 << (31 - exponent)) : ((x + offset) >> exponent);
}
60 
// Applies a quantized multiplier to value:
//   1. scale value by 2^left_shift,
//   2. take SaturatingRoundingDoublingHighMul with multiplier,
//   3. divide by 2^(-right_shift) with rounding (RoundingDivideByPOT).
// right_shift is negated before use, so it is expected to be <= 0.
// left_shift is expected in [0, 31]. NOTE(review): confirm both conventions with callers.
//
// The pre-scaling is an unsigned shift. The former `value * (1 << left_shift)`
// was undefined behavior on signed overflow and for left_shift == 31.
// The unsigned shift gives the same bits whenever the signed product is
// representable. Otherwise the conversion back to int32_t is
// implementation-defined (two's-complement wrap on GCC/Clang).
int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) {
  const int32_t scaled_value = (int32_t)((uint32_t)value << (uint32_t)left_shift);
  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(scaled_value, multiplier), -right_shift);
}
64 
// Same as MultiplyByQuantizedMultiplier, except that the final division by
// 2^(-right_shift) uses UpwardRounding (exact halves toward +infinity).
// right_shift is expected to be <= 0 and left_shift in [0, 31].
// NOTE(review): confirm with callers.
//
// The pre-scaling is an unsigned shift. The former `value * (1 << left_shift)`
// was undefined behavior on signed overflow and for left_shift == 31.
// The unsigned shift gives the same bits whenever the signed product is
// representable.
int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift,
                                                    int32_t right_shift) {
  const int32_t scaled_value = (int32_t)((uint32_t)value << (uint32_t)left_shift);
  return UpwardRounding(SaturatingRoundingDoublingHighMul(scaled_value, multiplier), -right_shift);
}
69 
// Rounding doubling high multiply of value by multiplier, followed by a
// rounding division by 2^right_shift. Here right_shift is passed straight
// through as the (non-negative) exponent.
int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift) {
  const int high_product = SaturatingRoundingDoublingHighMul(value, multiplier);
  return RoundingDivideByPOT(high_product, right_shift);
}
73 
// Number of fractional bits of an int32 fixed-point format with integer_bits
// integer bits: 32 total bits, minus one sign bit, minus integer_bits.
int FractionsBits(int integer_bits) {
  const int total_bits = (int)(sizeof(int32_t) * 8);
  const int sign_bits = 1;
  return total_bits - sign_bits - integer_bits;
}
75 
// Raw int32 representation of 1.0 in a fixed-point format.
// With 0 integer bits, 1.0 is not representable, so the largest value
// (INT32_MAX, just below 1.0) is returned. Otherwise the result is
// 1 << fractions_bits. NOTE(review): callers are expected to pass
// fractions_bits == FractionsBits(integer_bits).
int FixedPoint_One(int integer_bits, int fractions_bits) {
  if (integer_bits == 0) {
    return INT32_MAX;
  }
  return 1 << (uint32_t)fractions_bits;
}
79 
// Returns (a + b) / 2 rounded to nearest, with exact halves rounded away from
// zero. The sum is formed in 64 bits so it cannot overflow.
int RoundingHalfSum(int32_t a, int32_t b) {
  const int64_t total = (int64_t)a + (int64_t)b;
  int64_t biased;
  if (total > 0) {
    biased = total + 1;
  } else {
    biased = total - 1;
  }
  // Division truncates toward zero; together with the +/-1 bias this rounds halves away from zero.
  return (int32_t)(biased / 2);
}
84 
// Bitwise AND of the bit patterns of a and b, computed on uint32_t so that no
// bitwise operation is applied to signed values.
int32_t BitAnd(int32_t a, int32_t b) {
  const uint32_t ua = (uint32_t)a;
  const uint32_t ub = (uint32_t)b;
  return (int32_t)(ua & ub);
}
86 
// Bitwise OR of the bit patterns of a and b, computed on uint32_t.
int32_t BitOr(int32_t a, int32_t b) {
  const uint32_t ua = (uint32_t)a;
  const uint32_t ub = (uint32_t)b;
  return (int32_t)(ua | ub);
}
88 
// Bitwise XOR of the bit patterns of a and b, computed on uint32_t.
int32_t BitXor(int32_t a, int32_t b) {
  const uint32_t ua = (uint32_t)a;
  const uint32_t ub = (uint32_t)b;
  return (int32_t)(ua ^ ub);
}
90 
// Bitwise complement of the bit pattern of a, computed on uint32_t.
int32_t BitNot(int32_t a) {
  const uint32_t ua = (uint32_t)a;
  return (int32_t)(~ua);
}
92 
// Per-bit select: where a bit of mask is 1 the bit of bound is taken, where it
// is 0 the bit of val is taken, i.e. (mask & bound) ^ (~mask & val).
// Typically used with all-ones / all-zeros masks as a branch-free select.
int BitsSelect(int mask, int bound, int val) {
  const uint32_t m = (uint32_t)mask;
  const uint32_t from_bound = m & (uint32_t)bound;
  const uint32_t from_val = ~m & (uint32_t)val;
  return (int32_t)(from_bound ^ from_val);
}
94 
// Raw value of 2^exponent in a fixed-point format with fractional_bits
// fractional bits: 1 << (fractional_bits + exponent).
int ConstantPOT(int fractional_bits, int exponent) {
  const uint32_t shift = (uint32_t)(fractional_bits + exponent);
  return 1 << shift;
}
96 
// All-ones mask (-1) when a is nonzero, all-zeros mask otherwise.
int32_t MaskIfNonZero(int32_t a) {
  if (a == 0) {
    return 0;
  }
  return (int32_t)~(uint32_t)0;
}
98 
// All-ones mask (-1) when a is zero, all-zeros mask otherwise.
int32_t MaskIfZero(int32_t a) { return (a == 0) ? (int32_t)~(uint32_t)0 : 0; }
100 
// All-ones mask (-1) when a < b, all-zeros mask otherwise.
int32_t MaskIfLessThan(int32_t a, int32_t b) { return (a < b) ? (int32_t)~(uint32_t)0 : 0; }
102 
// Number of leading zero bits in x; 32 when x == 0.
// On GCC/Clang this uses __builtin_clz. Otherwise it falls back to a portable
// shift loop. (The compiler check must test __GNUC__; a misspelled macro
// silently disables the fast path.)
int CountLeadingZeroBits(uint32_t x) {
  if (x == 0) {
    return (int)(8 * sizeof(uint32_t));  // __builtin_clz(0) is undefined
  }
#if defined(__GNUC__)
  return __builtin_clz(x);
#else
  // Keep the sentinel unsigned: 0x80000000 does not fit in int32_t.
  const uint32_t top_bit = (uint32_t)1 << (8 * sizeof(uint32_t) - 1);
  int leading_zeros = 0;
  while (x < top_bit) {
    x <<= 1;
    leading_zeros++;
  }
  return leading_zeros;
#endif
}
119 
// Number of redundant sign bits in x: how many bits directly after the sign
// bit are equal to it (the semantics of GCC's __builtin_clrsb).
// The result ranges from 0 (e.g. INT32_MAX, INT32_MIN) to 31 (0 and -1).
// For x < 0, ~x is non-negative and has the same count, so both signs reduce
// to counting zero bits below bit 31. (Counting via clz(2 * -x) is off by one
// for negative powers of two, e.g. it yields 30 for -1.)
int CountLeadingSignBits(int32_t x) {
#if defined(__GNUC__) && !defined(__clang__)
  return __builtin_clrsb(x);  // defined for every input, including 0 (-> 31)
#else
  const uint32_t bits = (uint32_t)(x >= 0 ? x : ~x);  // bit 31 is clear here
  int count = 0;
  for (uint32_t probe = (uint32_t)1 << 30; probe != 0 && (bits & probe) == 0; probe >>= 1) {
    count++;
  }
  return count;
#endif
}
127 
// Multiplies x by 2^exponent:
//   exponent > 0  -> x << exponent, saturated to [INT32_MIN, INT32_MAX];
//   exponent < 0  -> rounding division by 2^-exponent (RoundingDivideByPOT);
//   exponent == 0 -> x unchanged.
// The shift is only performed when it cannot overflow. Previously
// x * (1 << exponent) was evaluated before saturation was applied, which is
// signed-overflow UB, and exponent >= 31 shifted by the full type width.
int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent) {
  if (exponent == 0) {
    return x;
  }
  if (exponent < 0) {
    return RoundingDivideByPOT(x, -exponent);
  }
  if (exponent >= 31) {
    // Every nonzero x saturates at this scale.
    return x > 0 ? INT32_MAX : (x < 0 ? INT32_MIN : 0);
  }
  // Symmetric bound: |x| <= threshold guarantees x << exponent fits in int32_t.
  const int32_t threshold = (int32_t)(((uint32_t)1 << (uint32_t)(31 - exponent)) - 1u);
  if (x > threshold) {
    return INT32_MAX;
  }
  if (x < -threshold) {
    return INT32_MIN;
  }
  // In range: the unsigned shift is exact and avoids signed-overflow UB.
  return (int32_t)((uint32_t)x << (uint32_t)exponent);
}
146 
// Converts raw fixed-point value x from a format with integer_bits_src integer
// bits to one with integer_bits_dst integer bits. It multiplies by
// 2^(src - dst) with saturation (when growing) or rounding (when shrinking).
int32_t Rescale(int x, int integer_bits_src, int integer_bits_dst) {
  return SaturatingRoundingMultiplyByPOT(x, integer_bits_src - integer_bits_dst);
}
151 
// Newton-Raphson approximation of 1 / (1 + a) for a Q0.31 input a in [0, 1).
// The result is a Q0.31 value.
//
// The method works with d = (1 + a) / 2, which lies in [1/2, 1):
//   - Start from the linear estimate 48/17 - 32/17 * d of 1/d (Q2.29).
//   - Refine it with three Newton steps x <- x + x * (1 - d * x).
//   - Rescale(x, 1, 0) then doubles the raw value. This turns 1/d in Q2.29
//     (raw = 2^29 / d) into 1/(1 + a) = (1/d)/2 in Q0.31 (raw = 2^30 / d).
int32_t reciprocal_on_interval_between_0_1(int32_t a) {
  const int almost_one = FixedPoint_One(0, FractionsBits(0));   // INT32_MAX, just below 1.0 in Q0.31
  const int half_denominator = RoundingHalfSum(a, almost_one);  // d = (1 + a) / 2, Q0.31
  const int forty_eight_seventeenths = 1515870810;              // 48/17 in Q2.29
  const int neg_thirty_two_seventeenths = -1010580540;          // -32/17 in Q2.29
  const int one_q2_29 = FixedPoint_One(2, FractionsBits(2));    // 1.0 in Q2.29
  int estimate = forty_eight_seventeenths + SaturatingRoundingDoublingHighMul(half_denominator, neg_thirty_two_seventeenths);
  for (int step = 0; step < 3; ++step) {
    const int residual = one_q2_29 - SaturatingRoundingDoublingHighMul(half_denominator, estimate);
    // Q2.29 * Q2.29 yields Q4.27; rescale back to Q2.29 before accumulating.
    estimate += Rescale(SaturatingRoundingDoublingHighMul(estimate, residual), 2 + 2, 2);
  }
  return Rescale(estimate, 2 - 1, 0);
}
165 
// Normalized reciprocal of x.
// x is shifted left by its leading-zero count so that its top bit is set. The
// 32-bit pattern is then read as 1 + f with f in [0, 1) (Q1.31, unsigned), and
// 1 / (1 + f) is returned as a Q0.31 value.
// *recip_shift receives x_digits minus the applied shift; x_digits is
// presumably the number of integer bits of x (NOTE(review): confirm with callers).
// NOTE(review): x == 0 would shift by 32, which is undefined; callers presumably pass x != 0.
int32_t ComputerReciprocal(int32_t x, int x_digits, int *recip_shift) {
  const int headroom_plus_one = CountLeadingZeroBits((uint32_t)x);
  *recip_shift = x_digits - headroom_plus_one;
  const uint32_t normalized = (uint32_t)x << headroom_plus_one;
  const int32_t fraction = (int32_t)(normalized - ((uint32_t)1 << 31));  // f in Q0.31
  return reciprocal_on_interval_between_0_1(fraction);
}
173 
// exp(a) for a Q0.31 input a. exp_on_negative_values calls it with a in [-1/4, 0).
// It evaluates a 4th-order Taylor expansion around -1/8. With x = a + 1/8:
//   exp(a) ~= e^(-1/8) * (1 + x + x^2/2 + x^3/6 + x^4/24)
// Returns a Q0.31 value.
int exp_on_interval_values(int a) {
  const int exp_neg_one_eighth = 1895147668;  // e^(-1/8) in Q0.31 (~0.8825)
  const int one_third = 715827883;            // 1/3 in Q0.31
  const int one_eighth = ConstantPOT(FractionsBits(0), -3);
  const int x = a + one_eighth;
  const int x_sq = SaturatingRoundingDoublingHighMul(x, x);
  const int x_cube = SaturatingRoundingDoublingHighMul(x_sq, x);
  const int x_pow4 = SaturatingRoundingDoublingHighMul(x_sq, x_sq);
  const int x_pow4_div4 = SaturatingRoundingMultiplyByPOT(x_pow4, -2);
  // ((x^4/4 + x^3) / 3 + x^2) / 2 == x^4/24 + x^3/6 + x^2/2
  const int cubic_part_div3 = SaturatingRoundingDoublingHighMul(x_pow4_div4 + x_cube, one_third);
  const int higher_terms = SaturatingRoundingMultiplyByPOT(cubic_part_div3 + x_sq, -1);
  const int tail = SaturatingRoundingDoublingHighMul(exp_neg_one_eighth, x + higher_terms);
  return exp_neg_one_eighth + tail;
}
188 
// One stage of the exp barrel shifter. If bit (fractional_bits + exponent) of
// remainder is set, *result is multiplied by muliplier (Q0.31) using
// SaturatingRoundingDoublingHighMul; otherwise *result is left unchanged.
// Stages with exponent >= integer_bits do nothing. With
// fractional_bits == 31 - integer_bits, such stages would test bit 31 or higher.
void exp_barrel_shifter(int exponent, int muliplier, int integer_bits, int fractional_bits, int remainder,
                        int *result) {
  if (integer_bits <= exponent) {
    return;
  }
  const int bit_position = fractional_bits + exponent;
  const int bit_mask = MaskIfNonZero(BitAnd(remainder, (1 << (uint32_t)bit_position)));
  const int scaled = SaturatingRoundingDoublingHighMul(*result, muliplier);
  *result = BitsSelect(bit_mask, scaled, *result);
}
197 
// exp(a) for a fixed-point input a with integer_bits integer bits.
// The result is Q0.31. An input of 0 yields INT32_MAX (just below 1.0).
//
// Steps:
//   1. Split a into a fractional part in [-1/4, 0) and a remainder made of
//      multiples of 1/4.
//   2. Evaluate exp on the fractional part with exp_on_interval_values.
//   3. Multiply in exp(-2^k) for every set bit of the remainder
//      (k = -2 .. 4, barrel shifter).
//   4. When integer_bits > 5, inputs below -32 are clamped to 0.
int exp_on_negative_values(int a, const int integer_bits) {
  const int fractional_bits = FractionsBits(integer_bits);
  const int one_quarter = ConstantPOT(fractional_bits, -2);
  // Low bits of a (mod 1/4), shifted down into [-1/4, 0).
  const int frac_part = ((unsigned)(a) & (one_quarter - 1)) - one_quarter;
  int result = exp_on_interval_values(Rescale(frac_part, integer_bits, 0));
  const int remainder = frac_part - a;

  // kExpFactors[i] is exp(-2^(i - 2)) in Q0.31, for barrel-shifter exponents -2 .. +4.
  static const int kExpFactors[7] = {1672461947, 1302514674, 790015084, 290630308, 39332535, 720401, 242};
  for (int i = 0; i < 7; ++i) {
    exp_barrel_shifter(i - 2, kExpFactors[i], integer_bits, fractional_bits, remainder, &result);
  }

  if (integer_bits > 5) {
    // -32 in the input format: 2^(36 - integer_bits) / 2^fractional_bits == 32.
    const int minus_32 = -(1 << (uint32_t)(36 - integer_bits));
    result = BitsSelect(MaskIfLessThan(a, minus_32), 0, result);
  }
  return BitsSelect(MaskIfZero(a), FixedPoint_One(0, fractional_bits), result);
}
221 
// Computes a fixed-point multiplier / shift pair representing 1 / sqrt(input).
// The Newton step below, x <- 1.5 * x - (input / 2) * x^3, converges to
// 1 / sqrt(input).
//
// input:         positive integer to take the inverse square root of.
// reverse_shift: factor applied to the final *shift to select the shift
//                direction convention. NOTE(review): callers presumably pass +1 or -1.
// multiplier:    out, Q0.31 multiplier.
// shift:         out, shift exponent paired with *multiplier, multiplied by reverse_shift.
void GetSqrtQuantMultiplierExp(int32_t input, int reverse_shift, int32_t *multiplier, int32_t *shift) {
  if (input <= 1) {
    // 1 (and the invalid input 0) are handled up front because the general
    // path below can overflow its final left shift for them. The return is
    // required: without it these outputs were overwritten by that path.
    *multiplier = INT_MAX;
    *shift = 0;
    return;
  }
  // Normalize input by powers of 4 (adjusting *shift) so it lies in [2^27, 2^29).
  *shift = 11;
  while (input >= (1 << 29)) {
    input /= 4;
    ++*shift;
  }
  int max_left_shift_bits = CountLeadingSignBits(input);
  int left_shift_bit_pairs = max_left_shift_bits / 2 - 1;
  *shift -= left_shift_bit_pairs;
  input <<= 2 * left_shift_bit_pairs;
  int32_t fixedpoint_f3_input = input >> 1;  // sign: 1 bit, integer: 3 bit, fractional: 28 bit
  int32_t fp_f3_half_input = SaturatingRoundingMultiplyByPOT(fixedpoint_f3_input, -1);
  int32_t fp_f3_half_three = (1 << 28) + (1 << 27);  // 1.5 in Q3.28
  int32_t tmp = (1 << 28);                           // 1.0 in Q3.28: initial guess
  for (int i = 0; i < 5; i++) {
    // tmp^3: Q3.28 cubed is Q9.22; rescale back to Q3.28.
    int32_t tmp3 = Rescale(SaturatingRoundingDoublingHighMul(tmp, SaturatingRoundingDoublingHighMul(tmp, tmp)), 9, 3);
    // 1.5 * tmp - (input / 2) * tmp^3: Q3.28 products are Q6.25; rescale back to Q3.28.
    tmp = Rescale(SaturatingRoundingDoublingHighMul(fp_f3_half_three, tmp) -
                    SaturatingRoundingDoublingHighMul(fp_f3_half_input, tmp3),
                  6, 3);
  }
  const int32_t fp_f0_half_sqrt_2 = 1518500250;  // sqrt(2) / 2 in Q0.31
  tmp = SaturatingRoundingDoublingHighMul(tmp, fp_f0_half_sqrt_2);
  *multiplier = tmp;
  if (*shift < 0) {
    *multiplier <<= -*shift;
    *shift = 0;
  }
  *shift *= reverse_shift;
}
255 
#ifdef ENABLE_NEON
// NEON counterpart of RoundingDivideByPOT for four int32 lanes. Each lane is
// divided by 2^exponent, rounding to nearest.
// vrshlq_s32 with a negative shift count is a rounding right shift that rounds
// exact halves up. First adding (with saturation) -1 to lanes where x is
// negative and exponent > 0 makes negative halves round away from zero too.
int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) {
  const int32x4_t neg_exponent = vdupq_n_s32(-exponent);
  // Sign bit of (x & -exponent) -> -1 in lanes with x < 0 (when exponent > 0), else 0.
  const int32x4_t negative_fixup = vshrq_n_s32(vandq_s32(x, neg_exponent), 31);
  return vrshlq_s32(vqaddq_s32(x, negative_fixup), neg_exponent);
}

// NEON counterpart of SaturatingRoundingDoublingHighMul (lane-wise vqrdmulhq_s32).
int32x4_t SaturatingRoundingDoublingHighMulInt32x4(int32x4_t a, int32x4_t b) {
  return vqrdmulhq_s32(a, b);
}
#endif
266