1 //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef LLVM_LIBC_UTILS_FPUTIL_SQRT_H
10 #define LLVM_LIBC_UTILS_FPUTIL_SQRT_H
11
12 #include "FPBits.h"
13
14 #include "utils/CPP/TypeTraits.h"
15
16 namespace __llvm_libc {
17 namespace fputil {
18
19 namespace internal {
20
21 template <typename T>
22 static inline void normalize(int &exponent,
23 typename FPBits<T>::UIntType &mantissa);
24
25 template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) {
26 // Use binary search to shift the leading 1 bit.
27 // With MantissaWidth<float> = 23, it will take
28 // ceil(log2(23)) = 5 steps checking the mantissa bits as followed:
29 // Step 1: 0000 0000 0000 XXXX XXXX XXXX
30 // Step 2: 0000 00XX XXXX XXXX XXXX XXXX
31 // Step 3: 000X XXXX XXXX XXXX XXXX XXXX
32 // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX
33 // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX
34 constexpr int nsteps = 5; // = ceil(log2(MantissaWidth))
35 constexpr uint32_t bounds[nsteps] = {1 << 12, 1 << 18, 1 << 21, 1 << 22,
36 1 << 23};
37 constexpr int shifts[nsteps] = {12, 6, 3, 2, 1};
38
39 for (int i = 0; i < nsteps; ++i) {
40 if (mantissa < bounds[i]) {
41 exponent -= shifts[i];
42 mantissa <<= shifts[i];
43 }
44 }
45 }
46
47 template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) {
48 // Use binary search to shift the leading 1 bit similar to float.
49 // With MantissaWidth<double> = 52, it will take
50 // ceil(log2(52)) = 6 steps checking the mantissa bits.
51 constexpr int nsteps = 6; // = ceil(log2(MantissaWidth))
52 constexpr uint64_t bounds[nsteps] = {1ULL << 26, 1ULL << 39, 1ULL << 46,
53 1ULL << 49, 1ULL << 51, 1ULL << 52};
54 constexpr int shifts[nsteps] = {27, 14, 7, 4, 2, 1};
55
56 for (int i = 0; i < nsteps; ++i) {
57 if (mantissa < bounds[i]) {
58 exponent -= shifts[i];
59 mantissa <<= shifts[i];
60 }
61 }
62 }
63
#if !(defined(__x86_64__) || defined(__i386__))
// normalize for long double with a 112-bit mantissa
// (MantissaWidth<long double> = 112). Only compiled when neither __x86_64__
// nor __i386__ is defined.
//
// Same binary search as the float version: ceil(log2(112)) = 7 conditional
// shifts move the leading 1 bit of a nonzero denormal mantissa up to bit 112,
// and `exponent` is lowered by the total amount shifted.
template <>
inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
  struct Step {
    __uint128_t bound; // The step fires when the mantissa is below this.
    int shift;         // Left shift applied to the mantissa when it fires.
  };
  constexpr Step steps[] = {
      {__uint128_t(1) << 56, 57},  {__uint128_t(1) << 84, 29},
      {__uint128_t(1) << 98, 15},  {__uint128_t(1) << 105, 8},
      {__uint128_t(1) << 109, 4},  {__uint128_t(1) << 111, 2},
      {__uint128_t(1) << 112, 1}};

  for (const Step &step : steps) {
    if (mantissa >= step.bound)
      continue; // Leading bit is already high enough for this step.
    mantissa <<= step.shift;
    exponent -= step.shift;
  }
}
#endif
85
86 } // namespace internal
87
88 // Correctly rounded IEEE 754 SQRT with round to nearest, ties to even.
89 // Shift-and-add algorithm.
90 template <typename T,
91 cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, int> = 0>
sqrt(T x)92 static inline T sqrt(T x) {
93 using UIntType = typename FPBits<T>::UIntType;
94 constexpr UIntType One = UIntType(1) << MantissaWidth<T>::value;
95
96 FPBits<T> bits(x);
97
98 if (bits.isInfOrNaN()) {
99 if (bits.sign && (bits.mantissa == 0)) {
100 // sqrt(-Inf) = NaN
101 return FPBits<T>::buildNaN(One >> 1);
102 } else {
103 // sqrt(NaN) = NaN
104 // sqrt(+Inf) = +Inf
105 return x;
106 }
107 } else if (bits.isZero()) {
108 // sqrt(+0) = +0
109 // sqrt(-0) = -0
110 return x;
111 } else if (bits.sign) {
112 // sqrt( negative numbers ) = NaN
113 return FPBits<T>::buildNaN(One >> 1);
114 } else {
115 int xExp = bits.getExponent();
116 UIntType xMant = bits.mantissa;
117
118 // Step 1a: Normalize denormal input and append hiddent bit to the mantissa
119 if (bits.exponent == 0) {
120 ++xExp; // let xExp be the correct exponent of One bit.
121 internal::normalize<T>(xExp, xMant);
122 } else {
123 xMant |= One;
124 }
125
126 // Step 1b: Make sure the exponent is even.
127 if (xExp & 1) {
128 --xExp;
129 xMant <<= 1;
130 }
131
132 // After step 1b, x = 2^(xExp) * xMant, where xExp is even, and
133 // 1 <= xMant < 4. So sqrt(x) = 2^(xExp / 2) * y, with 1 <= y < 2.
134 // Notice that the output of sqrt is always in the normal range.
135 // To perform shift-and-add algorithm to find y, let denote:
136 // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
137 // r(n) = 2^n ( xMant - y(n)^2 ).
138 // That leads to the following recurrence formula:
139 // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
140 // with the initial conditions: y(0) = 1, and r(0) = x - 1.
141 // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
142 // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
143 // 0 otherwise.
144 UIntType y = One;
145 UIntType r = xMant - One;
146
147 for (UIntType current_bit = One >> 1; current_bit; current_bit >>= 1) {
148 r <<= 1;
149 UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
150 if (r >= tmp) {
151 r -= tmp;
152 y += current_bit;
153 }
154 }
155
156 // We compute one more iteration in order to round correctly.
157 bool lsb = y & 1; // Least significant bit
158 bool rb = false; // Round bit
159 r <<= 2;
160 UIntType tmp = (y << 2) + 1;
161 if (r >= tmp) {
162 r -= tmp;
163 rb = true;
164 }
165
166 // Remove hidden bit and append the exponent field.
167 xExp = ((xExp >> 1) + FPBits<T>::exponentBias);
168
169 y = (y - One) | (static_cast<UIntType>(xExp) << MantissaWidth<T>::value);
170 // Round to nearest, ties to even
171 if (rb && (lsb || (r != 0))) {
172 ++y;
173 }
174
175 return *reinterpret_cast<T *>(&y);
176 }
177 }
178
179 } // namespace fputil
180 } // namespace __llvm_libc
181
182 #if (defined(__x86_64__) || defined(__i386__))
183 #include "SqrtLongDoubleX86.h"
184 #endif // defined(__x86_64__) || defined(__i386__)
185
186 #endif // LLVM_LIBC_UTILS_FPUTIL_SQRT_H
187