• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===-- Common header for FMA implementations -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
11 
12 #include "src/__support/CPP/bit.h"
13 #include "src/__support/CPP/type_traits.h"
14 #include "src/__support/FPUtil/FEnvImpl.h"
15 #include "src/__support/FPUtil/FPBits.h"
16 #include "src/__support/FPUtil/rounding_mode.h"
17 #include "src/__support/macros/attributes.h"   // LIBC_INLINE
18 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
19 #include "src/__support/uint128.h"
20 
21 namespace LIBC_NAMESPACE {
22 namespace fputil {
23 namespace generic {
24 
25 template <typename T> LIBC_INLINE T fma(T x, T y, T z);
26 
27 // TODO(lntue): Implement fmaf that is correctly rounded to all rounding modes.
28 // The implementation below only is only correct for the default rounding mode,
29 // round-to-nearest tie-to-even.
30 template <> LIBC_INLINE float fma<float>(float x, float y, float z) {
31   // Product is exact.
32   double prod = static_cast<double>(x) * static_cast<double>(y);
33   double z_d = static_cast<double>(z);
34   double sum = prod + z_d;
35   fputil::FPBits<double> bit_prod(prod), bitz(z_d), bit_sum(sum);
36 
37   if (!(bit_sum.is_inf_or_nan() || bit_sum.is_zero())) {
38     // Since the sum is computed in double precision, rounding might happen
39     // (for instance, when bitz.exponent > bit_prod.exponent + 5, or
40     // bit_prod.exponent > bitz.exponent + 40).  In that case, when we round
41     // the sum back to float, double rounding error might occur.
42     // A concrete example of this phenomenon is as follows:
43     //   x = y = 1 + 2^(-12), z = 2^(-53)
44     // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53)
45     // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23)
46     // On the other hand, with the default rounding mode,
47     //   double(x*y + z) = 1 + 2^(-11) + 2^(-24)
48     // and casting again to float gives us:
49     //   float(double(x*y + z)) = 1 + 2^(-11).
50     //
51     // In order to correct this possible double rounding error, first we use
52     // Dekker's 2Sum algorithm to find t such that sum - t = prod + z exactly,
53     // assuming the (default) rounding mode is round-to-the-nearest,
54     // tie-to-even.  Moreover, t satisfies the condition that t < eps(sum),
55     // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding
56     // occurs when computing the sum, we just need to use t to adjust (any) last
57     // bit of sum, so that the sticky bits used when rounding sum to float are
58     // correct (when it matters).
59     fputil::FPBits<double> t(
60         (bit_prod.get_biased_exponent() >= bitz.get_biased_exponent())
61             ? ((bit_sum.get_val() - bit_prod.get_val()) - bitz.get_val())
62             : ((bit_sum.get_val() - bitz.get_val()) - bit_prod.get_val()));
63 
64     // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are
65     // zero.
66     if (!t.is_zero() && ((bit_sum.get_mantissa() & 0xfff'ffffULL) == 0)) {
67       if (bit_sum.sign() != t.sign()) {
68         bit_sum.set_mantissa(bit_sum.get_mantissa() + 1);
69       } else if (bit_sum.get_mantissa()) {
70         bit_sum.set_mantissa(bit_sum.get_mantissa() - 1);
71       }
72     }
73   }
74 
75   return static_cast<float>(bit_sum.get_val());
76 }
77 
78 namespace internal {
79 
80 // Extract the sticky bits and shift the `mantissa` to the right by
81 // `shift_length`.
shift_mantissa(int shift_length,UInt128 & mant)82 LIBC_INLINE bool shift_mantissa(int shift_length, UInt128 &mant) {
83   if (shift_length >= 128) {
84     mant = 0;
85     return true; // prod_mant is non-zero.
86   }
87   UInt128 mask = (UInt128(1) << shift_length) - 1;
88   bool sticky_bits = (mant & mask) != 0;
89   mant >>= shift_length;
90   return sticky_bits;
91 }
92 
93 } // namespace internal
94 
95 template <> LIBC_INLINE double fma<double>(double x, double y, double z) {
96   using FPBits = fputil::FPBits<double>;
97 
98   if (LIBC_UNLIKELY(x == 0 || y == 0 || z == 0)) {
99     return x * y + z;
100   }
101 
102   int x_exp = 0;
103   int y_exp = 0;
104   int z_exp = 0;
105 
106   // Normalize denormal inputs.
107   if (LIBC_UNLIKELY(FPBits(x).is_subnormal())) {
108     x_exp -= 52;
109     x *= 0x1.0p+52;
110   }
111   if (LIBC_UNLIKELY(FPBits(y).is_subnormal())) {
112     y_exp -= 52;
113     y *= 0x1.0p+52;
114   }
115   if (LIBC_UNLIKELY(FPBits(z).is_subnormal())) {
116     z_exp -= 52;
117     z *= 0x1.0p+52;
118   }
119 
120   FPBits x_bits(x), y_bits(y), z_bits(z);
121   const Sign z_sign = z_bits.sign();
122   Sign prod_sign = (x_bits.sign() == y_bits.sign()) ? Sign::POS : Sign::NEG;
123   x_exp += x_bits.get_biased_exponent();
124   y_exp += y_bits.get_biased_exponent();
125   z_exp += z_bits.get_biased_exponent();
126 
127   if (LIBC_UNLIKELY(x_exp == FPBits::MAX_BIASED_EXPONENT ||
128                     y_exp == FPBits::MAX_BIASED_EXPONENT ||
129                     z_exp == FPBits::MAX_BIASED_EXPONENT))
130     return x * y + z;
131 
132   // Extract mantissa and append hidden leading bits.
133   UInt128 x_mant = x_bits.get_explicit_mantissa();
134   UInt128 y_mant = y_bits.get_explicit_mantissa();
135   UInt128 z_mant = z_bits.get_explicit_mantissa();
136 
137   // If the exponent of the product x*y > the exponent of z, then no extra
138   // precision beside the entire product x*y is needed.  On the other hand, when
139   // the exponent of z >= the exponent of the product x*y, the worst-case that
140   // we need extra precision is when there is cancellation and the most
141   // significant bit of the product is aligned exactly with the second most
142   // significant bit of z:
143   //      z :    10aa...a
144   // - prod :     1bb...bb....b
145   // In that case, in order to store the exact result, we need at least
146   //   (Length of prod) - (MantissaLength of z) = 2*(52 + 1) - 52 = 54.
147   // Overall, before aligning the mantissas and exponents, we can simply left-
148   // shift the mantissa of z by at least 54, and left-shift the product of x*y
149   // by (that amount - 52).  After that, it is enough to align the least
150   // significant bit, given that we keep track of the round and sticky bits
151   // after the least significant bit.
152   // We pick shifting z_mant by 64 bits so that technically we can simply use
153   // the original mantissa as high part when constructing 128-bit z_mant. So the
154   // mantissa of prod will be left-shifted by 64 - 54 = 10 initially.
155 
156   UInt128 prod_mant = x_mant * y_mant << 10;
157   int prod_lsb_exp =
158       x_exp + y_exp - (FPBits::EXP_BIAS + 2 * FPBits::FRACTION_LEN + 10);
159 
160   z_mant <<= 64;
161   int z_lsb_exp = z_exp - (FPBits::FRACTION_LEN + 64);
162   bool round_bit = false;
163   bool sticky_bits = false;
164   bool z_shifted = false;
165 
166   // Align exponents.
167   if (prod_lsb_exp < z_lsb_exp) {
168     sticky_bits = internal::shift_mantissa(z_lsb_exp - prod_lsb_exp, prod_mant);
169     prod_lsb_exp = z_lsb_exp;
170   } else if (z_lsb_exp < prod_lsb_exp) {
171     z_shifted = true;
172     sticky_bits = internal::shift_mantissa(prod_lsb_exp - z_lsb_exp, z_mant);
173   }
174 
175   // Perform the addition:
176   //   (-1)^prod_sign * prod_mant + (-1)^z_sign * z_mant.
177   // The final result will be stored in prod_sign and prod_mant.
178   if (prod_sign == z_sign) {
179     // Effectively an addition.
180     prod_mant += z_mant;
181   } else {
182     // Subtraction cases.
183     if (prod_mant >= z_mant) {
184       if (z_shifted && sticky_bits) {
185         // Add 1 more to the subtrahend so that the sticky bits remain
186         // positive. This would simplify the rounding logic.
187         ++z_mant;
188       }
189       prod_mant -= z_mant;
190     } else {
191       if (!z_shifted && sticky_bits) {
192         // Add 1 more to the subtrahend so that the sticky bits remain
193         // positive. This would simplify the rounding logic.
194         ++prod_mant;
195       }
196       prod_mant = z_mant - prod_mant;
197       prod_sign = z_sign;
198     }
199   }
200 
201   uint64_t result = 0;
202   int r_exp = 0; // Unbiased exponent of the result
203 
204   // Normalize the result.
205   if (prod_mant != 0) {
206     uint64_t prod_hi = static_cast<uint64_t>(prod_mant >> 64);
207     int lead_zeros =
208         prod_hi ? cpp::countl_zero(prod_hi)
209                 : 64 + cpp::countl_zero(static_cast<uint64_t>(prod_mant));
210     // Move the leading 1 to the most significant bit.
211     prod_mant <<= lead_zeros;
212     // The lower 64 bits are always sticky bits after moving the leading 1 to
213     // the most significant bit.
214     sticky_bits |= (static_cast<uint64_t>(prod_mant) != 0);
215     result = static_cast<uint64_t>(prod_mant >> 64);
216     // Change prod_lsb_exp the be the exponent of the least significant bit of
217     // the result.
218     prod_lsb_exp += 64 - lead_zeros;
219     r_exp = prod_lsb_exp + 63;
220 
221     if (r_exp > 0) {
222       // The result is normal.  We will shift the mantissa to the right by
223       // 63 - 52 = 11 bits (from the locations of the most significant bit).
224       // Then the rounding bit will correspond the 11th bit, and the lowest
225       // 10 bits are merged into sticky bits.
226       round_bit = (result & 0x0400ULL) != 0;
227       sticky_bits |= (result & 0x03ffULL) != 0;
228       result >>= 11;
229     } else {
230       if (r_exp < -52) {
231         // The result is smaller than 1/2 of the smallest denormal number.
232         sticky_bits = true; // since the result is non-zero.
233         result = 0;
234       } else {
235         // The result is denormal.
236         uint64_t mask = 1ULL << (11 - r_exp);
237         round_bit = (result & mask) != 0;
238         sticky_bits |= (result & (mask - 1)) != 0;
239         if (r_exp > -52)
240           result >>= 12 - r_exp;
241         else
242           result = 0;
243       }
244 
245       r_exp = 0;
246     }
247   } else {
248     // Return +0.0 when there is exact cancellation, i.e., x*y == -z exactly.
249     prod_sign = Sign::POS;
250   }
251 
252   // Finalize the result.
253   int round_mode = fputil::quick_get_round();
254   if (LIBC_UNLIKELY(r_exp >= FPBits::MAX_BIASED_EXPONENT)) {
255     if ((round_mode == FE_TOWARDZERO) ||
256         (round_mode == FE_UPWARD && prod_sign.is_neg()) ||
257         (round_mode == FE_DOWNWARD && prod_sign.is_pos())) {
258       return FPBits::max_normal(prod_sign).get_val();
259     }
260     return FPBits::inf(prod_sign).get_val();
261   }
262 
263   // Remove hidden bit and append the exponent field and sign bit.
264   result = (result & FPBits::FRACTION_MASK) |
265            (static_cast<uint64_t>(r_exp) << FPBits::FRACTION_LEN);
266   if (prod_sign.is_neg()) {
267     result |= FPBits::SIGN_MASK;
268   }
269 
270   // Rounding.
271   if (round_mode == FE_TONEAREST) {
272     if (round_bit && (sticky_bits || ((result & 1) != 0)))
273       ++result;
274   } else if ((round_mode == FE_UPWARD && prod_sign.is_pos()) ||
275              (round_mode == FE_DOWNWARD && prod_sign.is_neg())) {
276     if (round_bit || sticky_bits)
277       ++result;
278   }
279 
280   return cpp::bit_cast<double>(result);
281 }
282 
283 } // namespace generic
284 } // namespace fputil
285 } // namespace LIBC_NAMESPACE
286 
287 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
288