• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===-- A class to store high precision floating point numbers --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
11 
12 #include "FPBits.h"
13 #include "multiply_add.h"
14 #include "src/__support/CPP/type_traits.h"
15 #include "src/__support/big_int.h"
16 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
17 
18 #include <stddef.h>
19 
20 namespace LIBC_NAMESPACE::fputil {
21 
22 // A generic class to perform computations of high precision floating points.
23 // We store the value in dyadic format, including 3 fields:
24 //   sign    : boolean value - false means positive, true means negative
25 //   exponent: the exponent value of the least significant bit of the mantissa.
26 //   mantissa: unsigned integer of length `Bits`.
27 // So the real value that is stored is:
28 //   real value = (-1)^sign * 2^exponent * (mantissa as unsigned integer)
29 // The stored data is normal if for non-zero mantissa, the leading bit is 1.
30 // The outputs of the constructors and most functions will be normalized.
31 // To simplify and improve the efficiency, many functions will assume that the
32 // inputs are normal.
33 template <size_t Bits> struct DyadicFloat {
34   using MantissaType = LIBC_NAMESPACE::UInt<Bits>;
35 
36   Sign sign = Sign::POS;
37   int exponent = 0;
38   MantissaType mantissa = MantissaType(0);
39 
40   LIBC_INLINE constexpr DyadicFloat() = default;
41 
42   template <typename T, cpp::enable_if_t<cpp::is_floating_point_v<T>, int> = 0>
DyadicFloatDyadicFloat43   LIBC_INLINE constexpr DyadicFloat(T x) {
44     static_assert(FPBits<T>::FRACTION_LEN < Bits);
45     FPBits<T> x_bits(x);
46     sign = x_bits.sign();
47     exponent = x_bits.get_explicit_exponent() - FPBits<T>::FRACTION_LEN;
48     mantissa = MantissaType(x_bits.get_explicit_mantissa());
49     normalize();
50   }
51 
DyadicFloatDyadicFloat52   LIBC_INLINE constexpr DyadicFloat(Sign s, int e, MantissaType m)
53       : sign(s), exponent(e), mantissa(m) {
54     normalize();
55   }
56 
57   // Normalizing the mantissa, bringing the leading 1 bit to the most
58   // significant bit.
normalizeDyadicFloat59   LIBC_INLINE constexpr DyadicFloat &normalize() {
60     if (!mantissa.is_zero()) {
61       int shift_length = cpp::countl_zero(mantissa);
62       exponent -= shift_length;
63       mantissa <<= static_cast<size_t>(shift_length);
64     }
65     return *this;
66   }
67 
68   // Used for aligning exponents.  Output might not be normalized.
shift_leftDyadicFloat69   LIBC_INLINE constexpr DyadicFloat &shift_left(int shift_length) {
70     exponent -= shift_length;
71     mantissa <<= static_cast<size_t>(shift_length);
72     return *this;
73   }
74 
75   // Used for aligning exponents.  Output might not be normalized.
shift_rightDyadicFloat76   LIBC_INLINE constexpr DyadicFloat &shift_right(int shift_length) {
77     exponent += shift_length;
78     mantissa >>= static_cast<size_t>(shift_length);
79     return *this;
80   }
81 
82   // Assume that it is already normalized.  Output the unbiased exponent.
get_unbiased_exponentDyadicFloat83   LIBC_INLINE constexpr int get_unbiased_exponent() const {
84     return exponent + (Bits - 1);
85   }
86 
87   // Assume that it is already normalized.
88   // Output is rounded correctly with respect to the current rounding mode.
89   template <typename T,
90             typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
91                                             (FPBits<T>::FRACTION_LEN < Bits),
92                                         void>>
TDyadicFloat93   LIBC_INLINE explicit constexpr operator T() const {
94     if (LIBC_UNLIKELY(mantissa.is_zero()))
95       return FPBits<T>::zero(sign).get_val();
96 
97     // Assume that it is normalized, and output is also normal.
98     constexpr uint32_t PRECISION = FPBits<T>::FRACTION_LEN + 1;
99     using output_bits_t = typename FPBits<T>::StorageType;
100     constexpr output_bits_t IMPLICIT_MASK =
101         FPBits<T>::SIG_MASK - FPBits<T>::FRACTION_MASK;
102 
103     int exp_hi = exponent + static_cast<int>((Bits - 1) + FPBits<T>::EXP_BIAS);
104 
105     if (LIBC_UNLIKELY(exp_hi > 2 * FPBits<T>::EXP_BIAS)) {
106       // Results overflow.
107       T d_hi =
108           FPBits<T>::create_value(sign, 2 * FPBits<T>::EXP_BIAS, IMPLICIT_MASK)
109               .get_val();
110       return T(2) * d_hi;
111     }
112 
113     bool denorm = false;
114     uint32_t shift = Bits - PRECISION;
115     if (LIBC_UNLIKELY(exp_hi <= 0)) {
116       // Output is denormal.
117       denorm = true;
118       shift = (Bits - PRECISION) + static_cast<uint32_t>(1 - exp_hi);
119 
120       exp_hi = FPBits<T>::EXP_BIAS;
121     }
122 
123     int exp_lo = exp_hi - static_cast<int>(PRECISION) - 1;
124 
125     MantissaType m_hi =
126         shift >= MantissaType::BITS ? MantissaType(0) : mantissa >> shift;
127 
128     T d_hi = FPBits<T>::create_value(
129                  sign, exp_hi,
130                  (static_cast<output_bits_t>(m_hi) & FPBits<T>::SIG_MASK) |
131                      IMPLICIT_MASK)
132                  .get_val();
133 
134     MantissaType round_mask =
135         shift > MantissaType::BITS ? 0 : MantissaType(1) << (shift - 1);
136     MantissaType sticky_mask = round_mask - MantissaType(1);
137 
138     bool round_bit = !(mantissa & round_mask).is_zero();
139     bool sticky_bit = !(mantissa & sticky_mask).is_zero();
140     int round_and_sticky = int(round_bit) * 2 + int(sticky_bit);
141 
142     T d_lo;
143 
144     if (LIBC_UNLIKELY(exp_lo <= 0)) {
145       // d_lo is denormal, but the output is normal.
146       int scale_up_exponent = 2 * PRECISION;
147       T scale_up_factor =
148           FPBits<T>::create_value(sign, FPBits<T>::EXP_BIAS + scale_up_exponent,
149                                   IMPLICIT_MASK)
150               .get_val();
151       T scale_down_factor =
152           FPBits<T>::create_value(sign, FPBits<T>::EXP_BIAS - scale_up_exponent,
153                                   IMPLICIT_MASK)
154               .get_val();
155 
156       d_lo = FPBits<T>::create_value(sign, exp_lo + scale_up_exponent,
157                                      IMPLICIT_MASK)
158                  .get_val();
159 
160       return multiply_add(d_lo, T(round_and_sticky), d_hi * scale_up_factor) *
161              scale_down_factor;
162     }
163 
164     d_lo = FPBits<T>::create_value(sign, exp_lo, IMPLICIT_MASK).get_val();
165 
166     // Still correct without FMA instructions if `d_lo` is not underflow.
167     T r = multiply_add(d_lo, T(round_and_sticky), d_hi);
168 
169     if (LIBC_UNLIKELY(denorm)) {
170       // Exponent before rounding is in denormal range, simply clear the
171       // exponent field.
172       output_bits_t clear_exp = (output_bits_t(exp_hi) << FPBits<T>::SIG_LEN);
173       output_bits_t r_bits = FPBits<T>(r).uintval() - clear_exp;
174       if (!(r_bits & FPBits<T>::EXP_MASK)) {
175         // Output is denormal after rounding, clear the implicit bit for 80-bit
176         // long double.
177         r_bits -= IMPLICIT_MASK;
178       }
179 
180       return FPBits<T>(r_bits).get_val();
181     }
182 
183     return r;
184   }
185 
MantissaTypeDyadicFloat186   LIBC_INLINE explicit constexpr operator MantissaType() const {
187     if (mantissa.is_zero())
188       return 0;
189 
190     MantissaType new_mant = mantissa;
191     if (exponent > 0) {
192       new_mant <<= exponent;
193     } else {
194       new_mant >>= (-exponent);
195     }
196 
197     if (sign.is_neg()) {
198       new_mant = (~new_mant) + 1;
199     }
200 
201     return new_mant;
202   }
203 };
204 
205 // Quick add - Add 2 dyadic floats with rounding toward 0 and then normalize the
206 // output:
207 //   - Align the exponents so that:
208 //     new a.exponent = new b.exponent = max(a.exponent, b.exponent)
209 //   - Add or subtract the mantissas depending on the signs.
210 //   - Normalize the result.
211 // The absolute errors compared to the mathematical sum is bounded by:
212 //   | quick_add(a, b) - (a + b) | < MSB(a + b) * 2^(-Bits + 2),
213 // i.e., errors are up to 2 ULPs.
214 // Assume inputs are normalized (by constructors or other functions) so that we
215 // don't need to normalize the inputs again in this function.  If the inputs are
216 // not normalized, the results might lose precision significantly.
217 template <size_t Bits>
quick_add(DyadicFloat<Bits> a,DyadicFloat<Bits> b)218 LIBC_INLINE constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
219                                                   DyadicFloat<Bits> b) {
220   if (LIBC_UNLIKELY(a.mantissa.is_zero()))
221     return b;
222   if (LIBC_UNLIKELY(b.mantissa.is_zero()))
223     return a;
224 
225   // Align exponents
226   if (a.exponent > b.exponent)
227     b.shift_right(a.exponent - b.exponent);
228   else if (b.exponent > a.exponent)
229     a.shift_right(b.exponent - a.exponent);
230 
231   DyadicFloat<Bits> result;
232 
233   if (a.sign == b.sign) {
234     // Addition
235     result.sign = a.sign;
236     result.exponent = a.exponent;
237     result.mantissa = a.mantissa;
238     if (result.mantissa.add_overflow(b.mantissa)) {
239       // Mantissa addition overflow.
240       result.shift_right(1);
241       result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] |=
242           (uint64_t(1) << 63);
243     }
244     // Result is already normalized.
245     return result;
246   }
247 
248   // Subtraction
249   if (a.mantissa >= b.mantissa) {
250     result.sign = a.sign;
251     result.exponent = a.exponent;
252     result.mantissa = a.mantissa - b.mantissa;
253   } else {
254     result.sign = b.sign;
255     result.exponent = b.exponent;
256     result.mantissa = b.mantissa - a.mantissa;
257   }
258 
259   return result.normalize();
260 }
261 
262 // Quick Mul - Slightly less accurate but efficient multiplication of 2 dyadic
263 // floats with rounding toward 0 and then normalize the output:
264 //   result.exponent = a.exponent + b.exponent + Bits,
265 //   result.mantissa = quick_mul_hi(a.mantissa + b.mantissa)
266 //                   ~ (full product a.mantissa * b.mantissa) >> Bits.
267 // The errors compared to the mathematical product is bounded by:
268 //   2 * errors of quick_mul_hi = 2 * (UInt<Bits>::WORD_COUNT - 1) in ULPs.
269 // Assume inputs are normalized (by constructors or other functions) so that we
270 // don't need to normalize the inputs again in this function.  If the inputs are
271 // not normalized, the results might lose precision significantly.
272 template <size_t Bits>
quick_mul(DyadicFloat<Bits> a,DyadicFloat<Bits> b)273 LIBC_INLINE constexpr DyadicFloat<Bits> quick_mul(DyadicFloat<Bits> a,
274                                                   DyadicFloat<Bits> b) {
275   DyadicFloat<Bits> result;
276   result.sign = (a.sign != b.sign) ? Sign::NEG : Sign::POS;
277   result.exponent = a.exponent + b.exponent + int(Bits);
278 
279   if (!(a.mantissa.is_zero() || b.mantissa.is_zero())) {
280     result.mantissa = a.mantissa.quick_mul_hi(b.mantissa);
281     // Check the leading bit directly, should be faster than using clz in
282     // normalize().
283     if (result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] >>
284             63 ==
285         0)
286       result.shift_left(1);
287   } else {
288     result.mantissa = (typename DyadicFloat<Bits>::MantissaType)(0);
289   }
290   return result;
291 }
292 
293 // Simple polynomial approximation.
294 template <size_t Bits>
295 LIBC_INLINE constexpr DyadicFloat<Bits>
multiply_add(const DyadicFloat<Bits> & a,const DyadicFloat<Bits> & b,const DyadicFloat<Bits> & c)296 multiply_add(const DyadicFloat<Bits> &a, const DyadicFloat<Bits> &b,
297              const DyadicFloat<Bits> &c) {
298   return quick_add(c, quick_mul(a, b));
299 }
300 
301 // Simple exponentiation implementation for printf. Only handles positive
302 // exponents, since division isn't implemented.
303 template <size_t Bits>
pow_n(DyadicFloat<Bits> a,uint32_t power)304 LIBC_INLINE constexpr DyadicFloat<Bits> pow_n(DyadicFloat<Bits> a,
305                                               uint32_t power) {
306   DyadicFloat<Bits> result = 1.0;
307   DyadicFloat<Bits> cur_power = a;
308 
309   while (power > 0) {
310     if ((power % 2) > 0) {
311       result = quick_mul(result, cur_power);
312     }
313     power = power >> 1;
314     cur_power = quick_mul(cur_power, cur_power);
315   }
316   return result;
317 }
318 
319 template <size_t Bits>
mul_pow_2(DyadicFloat<Bits> a,int32_t pow_2)320 LIBC_INLINE constexpr DyadicFloat<Bits> mul_pow_2(DyadicFloat<Bits> a,
321                                                   int32_t pow_2) {
322   DyadicFloat<Bits> result = a;
323   result.exponent += pow_2;
324   return result;
325 }
326 
327 } // namespace LIBC_NAMESPACE::fputil
328 
329 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
330