1 //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H 10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H 11 12 #include "sqrt_80_bit_long_double.h" 13 #include "src/__support/CPP/bit.h" // countl_zero 14 #include "src/__support/CPP/type_traits.h" 15 #include "src/__support/FPUtil/FEnvImpl.h" 16 #include "src/__support/FPUtil/FPBits.h" 17 #include "src/__support/FPUtil/rounding_mode.h" 18 #include "src/__support/common.h" 19 #include "src/__support/uint128.h" 20 21 namespace LIBC_NAMESPACE { 22 namespace fputil { 23 24 namespace internal { 25 26 template <typename T> struct SpecialLongDouble { 27 static constexpr bool VALUE = false; 28 }; 29 30 #if defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80) 31 template <> struct SpecialLongDouble<long double> { 32 static constexpr bool VALUE = true; 33 }; 34 #endif // LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80 35 36 template <typename T> 37 LIBC_INLINE void normalize(int &exponent, 38 typename FPBits<T>::StorageType &mantissa) { 39 const int shift = 40 cpp::countl_zero(mantissa) - 41 (8 * static_cast<int>(sizeof(mantissa)) - 1 - FPBits<T>::FRACTION_LEN); 42 exponent -= shift; 43 mantissa <<= shift; 44 } 45 46 #ifdef LIBC_TYPES_LONG_DOUBLE_IS_FLOAT64 47 template <> 48 LIBC_INLINE void normalize<long double>(int &exponent, uint64_t &mantissa) { 49 normalize<double>(exponent, mantissa); 50 } 51 #elif !defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80) 52 template <> 53 LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) { 54 const uint64_t hi_bits = static_cast<uint64_t>(mantissa >> 64); 55 const int shift = 56 hi_bits ? (cpp::countl_zero(hi_bits) - 15) 57 : (cpp::countl_zero(static_cast<uint64_t>(mantissa)) + 49); 58 exponent -= shift; 59 mantissa <<= shift; 60 } 61 #endif 62 63 } // namespace internal 64 65 // Correctly rounded IEEE 754 SQRT for all rounding modes. 66 // Shift-and-add algorithm. 67 template <typename T> 68 LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) { 69 70 if constexpr (internal::SpecialLongDouble<T>::VALUE) { 71 // Special 80-bit long double. 72 return x86::sqrt(x); 73 } else { 74 // IEEE floating points formats. 75 using FPBits_t = typename fputil::FPBits<T>; 76 using StorageType = typename FPBits_t::StorageType; 77 constexpr StorageType ONE = StorageType(1) << FPBits_t::FRACTION_LEN; 78 constexpr auto FLT_NAN = FPBits_t::quiet_nan().get_val(); 79 80 FPBits_t bits(x); 81 82 if (bits == FPBits_t::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) { 83 // sqrt(+Inf) = +Inf 84 // sqrt(+0) = +0 85 // sqrt(-0) = -0 86 // sqrt(NaN) = NaN 87 // sqrt(-NaN) = -NaN 88 return x; 89 } else if (bits.is_neg()) { 90 // sqrt(-Inf) = NaN 91 // sqrt(-x) = NaN 92 return FLT_NAN; 93 } else { 94 int x_exp = bits.get_exponent(); 95 StorageType x_mant = bits.get_mantissa(); 96 97 // Step 1a: Normalize denormal input and append hidden bit to the mantissa 98 if (bits.is_subnormal()) { 99 ++x_exp; // let x_exp be the correct exponent of ONE bit. 100 internal::normalize<T>(x_exp, x_mant); 101 } else { 102 x_mant |= ONE; 103 } 104 105 // Step 1b: Make sure the exponent is even. 106 if (x_exp & 1) { 107 --x_exp; 108 x_mant <<= 1; 109 } 110 111 // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and 112 // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2. 113 // Notice that the output of sqrt is always in the normal range. 114 // To perform shift-and-add algorithm to find y, let denote: 115 // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be: 116 // r(n) = 2^n ( x_mant - y(n)^2 ). 117 // That leads to the following recurrence formula: 118 // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ] 119 // with the initial conditions: y(0) = 1, and r(0) = x - 1. 120 // So the nth digit y_n of the mantissa of sqrt(x) can be found by: 121 // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1) 122 // 0 otherwise. 123 StorageType y = ONE; 124 StorageType r = x_mant - ONE; 125 126 for (StorageType current_bit = ONE >> 1; current_bit; current_bit >>= 1) { 127 r <<= 1; 128 StorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1) 129 if (r >= tmp) { 130 r -= tmp; 131 y += current_bit; 132 } 133 } 134 135 // We compute one more iteration in order to round correctly. 136 bool lsb = static_cast<bool>(y & 1); // Least significant bit 137 bool rb = false; // Round bit 138 r <<= 2; 139 StorageType tmp = (y << 2) + 1; 140 if (r >= tmp) { 141 r -= tmp; 142 rb = true; 143 } 144 145 // Remove hidden bit and append the exponent field. 146 x_exp = ((x_exp >> 1) + FPBits_t::EXP_BIAS); 147 148 y = (y - ONE) | 149 (static_cast<StorageType>(x_exp) << FPBits_t::FRACTION_LEN); 150 151 switch (quick_get_round()) { 152 case FE_TONEAREST: 153 // Round to nearest, ties to even 154 if (rb && (lsb || (r != 0))) 155 ++y; 156 break; 157 case FE_UPWARD: 158 if (rb || (r != 0)) 159 ++y; 160 break; 161 } 162 163 return cpp::bit_cast<T>(y); 164 } 165 } 166 } 167 168 } // namespace fputil 169 } // namespace LIBC_NAMESPACE 170 171 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H 172