• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
11 
12 #include "sqrt_80_bit_long_double.h"
13 #include "src/__support/CPP/bit.h" // countl_zero
14 #include "src/__support/CPP/type_traits.h"
15 #include "src/__support/FPUtil/FEnvImpl.h"
16 #include "src/__support/FPUtil/FPBits.h"
17 #include "src/__support/FPUtil/rounding_mode.h"
18 #include "src/__support/common.h"
19 #include "src/__support/uint128.h"
20 
21 namespace LIBC_NAMESPACE {
22 namespace fputil {
23 
24 namespace internal {
25 
/// Compile-time flag consulted by `sqrt` to decide whether a type bypasses the
/// generic shift-and-add algorithm. `VALUE` is false for every type unless the
/// specialization in this block sets it.
template <class T> struct SpecialLongDouble {
  static constexpr bool VALUE{false};
};

#ifdef LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
/// On targets where long double is the x86 80-bit format, `sqrt` dispatches
/// long double to `x86::sqrt` instead of the generic path.
template <> struct SpecialLongDouble<long double> {
  static constexpr bool VALUE{true};
};
#endif // LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
35 
36 template <typename T>
37 LIBC_INLINE void normalize(int &exponent,
38                            typename FPBits<T>::StorageType &mantissa) {
39   const int shift =
40       cpp::countl_zero(mantissa) -
41       (8 * static_cast<int>(sizeof(mantissa)) - 1 - FPBits<T>::FRACTION_LEN);
42   exponent -= shift;
43   mantissa <<= shift;
44 }
45 
46 #ifdef LIBC_TYPES_LONG_DOUBLE_IS_FLOAT64
47 template <>
48 LIBC_INLINE void normalize<long double>(int &exponent, uint64_t &mantissa) {
49   normalize<double>(exponent, mantissa);
50 }
51 #elif !defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80)
52 template <>
53 LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {
54   const uint64_t hi_bits = static_cast<uint64_t>(mantissa >> 64);
55   const int shift =
56       hi_bits ? (cpp::countl_zero(hi_bits) - 15)
57               : (cpp::countl_zero(static_cast<uint64_t>(mantissa)) + 49);
58   exponent -= shift;
59   mantissa <<= shift;
60 }
61 #endif
62 
63 } // namespace internal
64 
65 // Correctly rounded IEEE 754 SQRT for all rounding modes.
66 // Shift-and-add algorithm.
67 template <typename T>
68 LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
69 
70   if constexpr (internal::SpecialLongDouble<T>::VALUE) {
71     // Special 80-bit long double.
72     return x86::sqrt(x);
73   } else {
74     // IEEE floating points formats.
75     using FPBits_t = typename fputil::FPBits<T>;
76     using StorageType = typename FPBits_t::StorageType;
77     constexpr StorageType ONE = StorageType(1) << FPBits_t::FRACTION_LEN;
78     constexpr auto FLT_NAN = FPBits_t::quiet_nan().get_val();
79 
80     FPBits_t bits(x);
81 
82     if (bits == FPBits_t::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
83       // sqrt(+Inf) = +Inf
84       // sqrt(+0) = +0
85       // sqrt(-0) = -0
86       // sqrt(NaN) = NaN
87       // sqrt(-NaN) = -NaN
88       return x;
89     } else if (bits.is_neg()) {
90       // sqrt(-Inf) = NaN
91       // sqrt(-x) = NaN
92       return FLT_NAN;
93     } else {
94       int x_exp = bits.get_exponent();
95       StorageType x_mant = bits.get_mantissa();
96 
97       // Step 1a: Normalize denormal input and append hidden bit to the mantissa
98       if (bits.is_subnormal()) {
99         ++x_exp; // let x_exp be the correct exponent of ONE bit.
100         internal::normalize<T>(x_exp, x_mant);
101       } else {
102         x_mant |= ONE;
103       }
104 
105       // Step 1b: Make sure the exponent is even.
106       if (x_exp & 1) {
107         --x_exp;
108         x_mant <<= 1;
109       }
110 
111       // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
112       // 1 <= x_mant < 4.  So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
113       // Notice that the output of sqrt is always in the normal range.
114       // To perform shift-and-add algorithm to find y, let denote:
115       //   y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
116       //   r(n) = 2^n ( x_mant - y(n)^2 ).
117       // That leads to the following recurrence formula:
118       //   r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
119       // with the initial conditions: y(0) = 1, and r(0) = x - 1.
120       // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
121       //   y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
122       //         0 otherwise.
123       StorageType y = ONE;
124       StorageType r = x_mant - ONE;
125 
126       for (StorageType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
127         r <<= 1;
128         StorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
129         if (r >= tmp) {
130           r -= tmp;
131           y += current_bit;
132         }
133       }
134 
135       // We compute one more iteration in order to round correctly.
136       bool lsb = static_cast<bool>(y & 1); // Least significant bit
137       bool rb = false;                     // Round bit
138       r <<= 2;
139       StorageType tmp = (y << 2) + 1;
140       if (r >= tmp) {
141         r -= tmp;
142         rb = true;
143       }
144 
145       // Remove hidden bit and append the exponent field.
146       x_exp = ((x_exp >> 1) + FPBits_t::EXP_BIAS);
147 
148       y = (y - ONE) |
149           (static_cast<StorageType>(x_exp) << FPBits_t::FRACTION_LEN);
150 
151       switch (quick_get_round()) {
152       case FE_TONEAREST:
153         // Round to nearest, ties to even
154         if (rb && (lsb || (r != 0)))
155           ++y;
156         break;
157       case FE_UPWARD:
158         if (rb || (r != 0))
159           ++y;
160         break;
161       }
162 
163       return cpp::bit_cast<T>(y);
164     }
165   }
166 }
167 
168 } // namespace fputil
169 } // namespace LIBC_NAMESPACE
170 
171 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
172