• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===-- Square root of x86 long double numbers ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H
10 #define LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H
11 
12 #include "FPBits.h"
13 #include "Sqrt.h"
14 
15 #include "utils/CPP/TypeTraits.h"
16 
17 namespace __llvm_libc {
18 namespace fputil {
19 
20 #if (defined(__x86_64__) || defined(__i386__))
21 namespace internal {
22 
23 template <>
24 inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
25   // Use binary search to shift the leading 1 bit similar to float.
26   // With MantissaWidth<long double> = 63, it will take
27   // ceil(log2(63)) = 6 steps checking the mantissa bits.
28   constexpr int nsteps = 6; // = ceil(log2(MantissaWidth))
29   constexpr __uint128_t bounds[nsteps] = {
30       __uint128_t(1) << 32, __uint128_t(1) << 48, __uint128_t(1) << 56,
31       __uint128_t(1) << 60, __uint128_t(1) << 62, __uint128_t(1) << 63};
32   constexpr int shifts[nsteps] = {32, 16, 8, 4, 2, 1};
33 
34   for (int i = 0; i < nsteps; ++i) {
35     if (mantissa < bounds[i]) {
36       exponent -= shifts[i];
37       mantissa <<= shifts[i];
38     }
39   }
40 }
41 
42 } // namespace internal
43 
44 // Correctly rounded SQRT with round to nearest, ties to even.
45 // Shift-and-add algorithm.
46 template <> inline long double sqrt<long double, 0>(long double x) {
47   using UIntType = typename FPBits<long double>::UIntType;
48   constexpr UIntType One = UIntType(1)
49                            << int(MantissaWidth<long double>::value);
50 
51   FPBits<long double> bits(x);
52 
53   if (bits.isInfOrNaN()) {
54     if (bits.sign && (bits.mantissa == 0)) {
55       // sqrt(-Inf) = NaN
56       return FPBits<long double>::buildNaN(One >> 1);
57     } else {
58       // sqrt(NaN) = NaN
59       // sqrt(+Inf) = +Inf
60       return x;
61     }
62   } else if (bits.isZero()) {
63     // sqrt(+0) = +0
64     // sqrt(-0) = -0
65     return x;
66   } else if (bits.sign) {
67     // sqrt( negative numbers ) = NaN
68     return FPBits<long double>::buildNaN(One >> 1);
69   } else {
70     int xExp = bits.getExponent();
71     UIntType xMant = bits.mantissa;
72 
73     // Step 1a: Normalize denormal input
74     if (bits.implicitBit) {
75       xMant |= One;
76     } else if (bits.exponent == 0) {
77       internal::normalize<long double>(xExp, xMant);
78     }
79 
80     // Step 1b: Make sure the exponent is even.
81     if (xExp & 1) {
82       --xExp;
83       xMant <<= 1;
84     }
85 
86     // After step 1b, x = 2^(xExp) * xMant, where xExp is even, and
87     // 1 <= xMant < 4.  So sqrt(x) = 2^(xExp / 2) * y, with 1 <= y < 2.
88     // Notice that the output of sqrt is always in the normal range.
89     // To perform shift-and-add algorithm to find y, let denote:
90     //   y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
91     //   r(n) = 2^n ( xMant - y(n)^2 ).
92     // That leads to the following recurrence formula:
93     //   r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
94     // with the initial conditions: y(0) = 1, and r(0) = x - 1.
95     // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
96     //   y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
97     //         0 otherwise.
98     UIntType y = One;
99     UIntType r = xMant - One;
100 
101     for (UIntType current_bit = One >> 1; current_bit; current_bit >>= 1) {
102       r <<= 1;
103       UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
104       if (r >= tmp) {
105         r -= tmp;
106         y += current_bit;
107       }
108     }
109 
110     // We compute one more iteration in order to round correctly.
111     bool lsb = y & 1; // Least significant bit
112     bool rb = false;  // Round bit
113     r <<= 2;
114     UIntType tmp = (y << 2) + 1;
115     if (r >= tmp) {
116       r -= tmp;
117       rb = true;
118     }
119 
120     // Append the exponent field.
121     xExp = ((xExp >> 1) + FPBits<long double>::exponentBias);
122     y |= (static_cast<UIntType>(xExp)
123           << (MantissaWidth<long double>::value + 1));
124 
125     // Round to nearest, ties to even
126     if (rb && (lsb || (r != 0))) {
127       ++y;
128     }
129 
130     // Extract output
131     FPBits<long double> out(0.0L);
132     out.exponent = xExp;
133     out.implicitBit = 1;
134     out.mantissa = (y & (One - 1));
135 
136     return out;
137   }
138 }
139 #endif // defined(__x86_64__) || defined(__i386__)
140 
141 } // namespace fputil
142 } // namespace __llvm_libc
143 
144 #endif // LLVM_LIBC_UTILS_FPUTIL_SQRT_LONG_DOUBLE_X86_H
145