• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_FPUTIL_SQRT_H
10 #define LLVM_LIBC_UTILS_FPUTIL_SQRT_H
11 
12 #include "FPBits.h"
13 
14 #include "utils/CPP/TypeTraits.h"
15 
16 namespace __llvm_libc {
17 namespace fputil {
18 
19 namespace internal {
20 
21 template <typename T>
22 static inline void normalize(int &exponent,
23                              typename FPBits<T>::UIntType &mantissa);
24 
25 template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) {
26   // Use binary search to shift the leading 1 bit.
27   // With MantissaWidth<float> = 23, it will take
28   // ceil(log2(23)) = 5 steps checking the mantissa bits as followed:
29   // Step 1: 0000 0000 0000 XXXX XXXX XXXX
30   // Step 2: 0000 00XX XXXX XXXX XXXX XXXX
31   // Step 3: 000X XXXX XXXX XXXX XXXX XXXX
32   // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX
33   // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX
34   constexpr int nsteps = 5; // = ceil(log2(MantissaWidth))
35   constexpr uint32_t bounds[nsteps] = {1 << 12, 1 << 18, 1 << 21, 1 << 22,
36                                        1 << 23};
37   constexpr int shifts[nsteps] = {12, 6, 3, 2, 1};
38 
39   for (int i = 0; i < nsteps; ++i) {
40     if (mantissa < bounds[i]) {
41       exponent -= shifts[i];
42       mantissa <<= shifts[i];
43     }
44   }
45 }
46 
47 template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) {
48   // Use binary search to shift the leading 1 bit similar to float.
49   // With MantissaWidth<double> = 52, it will take
50   // ceil(log2(52)) = 6 steps checking the mantissa bits.
51   constexpr int nsteps = 6; // = ceil(log2(MantissaWidth))
52   constexpr uint64_t bounds[nsteps] = {1ULL << 26, 1ULL << 39, 1ULL << 46,
53                                        1ULL << 49, 1ULL << 51, 1ULL << 52};
54   constexpr int shifts[nsteps] = {27, 14, 7, 4, 2, 1};
55 
56   for (int i = 0; i < nsteps; ++i) {
57     if (mantissa < bounds[i]) {
58       exponent -= shifts[i];
59       mantissa <<= shifts[i];
60     }
61   }
62 }
63 
#if !(defined(__x86_64__) || defined(__i386__))
// On x86 the long double sqrt comes from SqrtLongDoubleX86.h instead. Elsewhere
// this assumes long double has a 112-bit stored mantissa (IEEE binary128) held
// in a __uint128_t; NOTE(review): confirm for every non-x86 target built.
template <>
inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
  // Moves the leading 1 bit of a subnormal mantissa (0 < mantissa < 2^112) up
  // to bit 112 = MantissaWidth<long double>, lowering the exponent by the same
  // amount.
  //
  // Binary search over the shift amount: a trial shift `s` is applied only
  // when mantissa < 2^(113 - s), i.e. when the leading bit is at or below bit
  // (112 - s), so it never moves past bit 112. The decreasing sequence
  // 57, 29, 15, 8, 4, 2, 1 takes ceil(log2(112)) = 7 steps and leaves the
  // leading bit exactly on bit 112.
  constexpr int kLeadingBitPos = 112; // = MantissaWidth<long double>
  constexpr int kTrialShifts[] = {57, 29, 15, 8, 4, 2, 1};

  for (const int s : kTrialShifts) {
    const __uint128_t threshold = __uint128_t(1) << (kLeadingBitPos + 1 - s);
    if (mantissa < threshold) {
      mantissa <<= s;
      exponent -= s;
    }
  }
}
#endif // !(defined(__x86_64__) || defined(__i386__))
85 
86 } // namespace internal
87 
88 // Correctly rounded IEEE 754 SQRT with round to nearest, ties to even.
89 // Shift-and-add algorithm.
90 template <typename T,
91           cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, int> = 0>
sqrt(T x)92 static inline T sqrt(T x) {
93   using UIntType = typename FPBits<T>::UIntType;
94   constexpr UIntType One = UIntType(1) << MantissaWidth<T>::value;
95 
96   FPBits<T> bits(x);
97 
98   if (bits.isInfOrNaN()) {
99     if (bits.sign && (bits.mantissa == 0)) {
100       // sqrt(-Inf) = NaN
101       return FPBits<T>::buildNaN(One >> 1);
102     } else {
103       // sqrt(NaN) = NaN
104       // sqrt(+Inf) = +Inf
105       return x;
106     }
107   } else if (bits.isZero()) {
108     // sqrt(+0) = +0
109     // sqrt(-0) = -0
110     return x;
111   } else if (bits.sign) {
112     // sqrt( negative numbers ) = NaN
113     return FPBits<T>::buildNaN(One >> 1);
114   } else {
115     int xExp = bits.getExponent();
116     UIntType xMant = bits.mantissa;
117 
118     // Step 1a: Normalize denormal input and append hiddent bit to the mantissa
119     if (bits.exponent == 0) {
120       ++xExp; // let xExp be the correct exponent of One bit.
121       internal::normalize<T>(xExp, xMant);
122     } else {
123       xMant |= One;
124     }
125 
126     // Step 1b: Make sure the exponent is even.
127     if (xExp & 1) {
128       --xExp;
129       xMant <<= 1;
130     }
131 
132     // After step 1b, x = 2^(xExp) * xMant, where xExp is even, and
133     // 1 <= xMant < 4.  So sqrt(x) = 2^(xExp / 2) * y, with 1 <= y < 2.
134     // Notice that the output of sqrt is always in the normal range.
135     // To perform shift-and-add algorithm to find y, let denote:
136     //   y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
137     //   r(n) = 2^n ( xMant - y(n)^2 ).
138     // That leads to the following recurrence formula:
139     //   r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
140     // with the initial conditions: y(0) = 1, and r(0) = x - 1.
141     // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
142     //   y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
143     //         0 otherwise.
144     UIntType y = One;
145     UIntType r = xMant - One;
146 
147     for (UIntType current_bit = One >> 1; current_bit; current_bit >>= 1) {
148       r <<= 1;
149       UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
150       if (r >= tmp) {
151         r -= tmp;
152         y += current_bit;
153       }
154     }
155 
156     // We compute one more iteration in order to round correctly.
157     bool lsb = y & 1; // Least significant bit
158     bool rb = false;  // Round bit
159     r <<= 2;
160     UIntType tmp = (y << 2) + 1;
161     if (r >= tmp) {
162       r -= tmp;
163       rb = true;
164     }
165 
166     // Remove hidden bit and append the exponent field.
167     xExp = ((xExp >> 1) + FPBits<T>::exponentBias);
168 
169     y = (y - One) | (static_cast<UIntType>(xExp) << MantissaWidth<T>::value);
170     // Round to nearest, ties to even
171     if (rb && (lsb || (r != 0))) {
172       ++y;
173     }
174 
175     return *reinterpret_cast<T *>(&y);
176   }
177 }
178 
179 } // namespace fputil
180 } // namespace __llvm_libc
181 
182 #if (defined(__x86_64__) || defined(__i386__))
183 #include "SqrtLongDoubleX86.h"
184 #endif // defined(__x86_64__) || defined(__i386__)
185 
186 #endif // LLVM_LIBC_UTILS_FPUTIL_SQRT_H
187