1 //===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H 10 #define LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H 11 12 #include "utils/CPP/TypeTraits.h" 13 #include "utils/UnitTest/Test.h" 14 15 #include <stdint.h> 16 17 namespace __llvm_libc { 18 namespace testing { 19 namespace mpfr { 20 21 enum class Operation : int { 22 // Operations with take a single floating point number as input 23 // and produce a single floating point number as output. The input 24 // and output floating point numbers are of the same kind. 25 BeginUnaryOperationsSingleOutput, 26 Abs, 27 Ceil, 28 Cos, 29 Exp, 30 Exp2, 31 Floor, 32 Round, 33 Sin, 34 Sqrt, 35 Trunc, 36 EndUnaryOperationsSingleOutput, 37 38 // Operations which take a single floating point nubmer as input 39 // but produce two outputs. The first ouput is a floating point 40 // number of the same type as the input. The second output is of type 41 // 'int'. 42 BeginUnaryOperationsTwoOutputs, 43 Frexp, // Floating point output, the first output, is the fractional part. 44 EndUnaryOperationsTwoOutputs, 45 46 // Operations wich take two floating point nubmers of the same type as 47 // input and produce a single floating point number of the same type as 48 // output. 49 BeginBinaryOperationsSingleOutput, 50 Hypot, 51 EndBinaryOperationsSingleOutput, 52 53 // Operations which take two floating point numbers of the same type as 54 // input and produce two outputs. The first output is a floating nubmer of 55 // the same type as the inputs. The second output is af type 'int'. 56 BeginBinaryOperationsTwoOutputs, 57 RemQuo, // The first output, the floating point output, is the remainder. 58 EndBinaryOperationsTwoOutputs, 59 60 BeginTernaryOperationsSingleOuput, 61 // TODO: Add operations like fma. 62 EndTernaryOperationsSingleOutput, 63 }; 64 65 template <typename T> struct BinaryInput { 66 static_assert( 67 __llvm_libc::cpp::IsFloatingPointType<T>::Value, 68 "Template parameter of BinaryInput must be a floating point type."); 69 70 using Type = T; 71 T x, y; 72 }; 73 74 template <typename T> struct TernaryInput { 75 static_assert( 76 __llvm_libc::cpp::IsFloatingPointType<T>::Value, 77 "Template parameter of TernaryInput must be a floating point type."); 78 79 using Type = T; 80 T x, y, z; 81 }; 82 83 template <typename T> struct BinaryOutput { 84 T f; 85 int i; 86 }; 87 88 namespace internal { 89 90 template <typename T1, typename T2> 91 struct AreMatchingBinaryInputAndBinaryOutput { 92 static constexpr bool value = false; 93 }; 94 95 template <typename T> 96 struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> { 97 static constexpr bool value = cpp::IsFloatingPointType<T>::Value; 98 }; 99 100 template <typename T> 101 bool compareUnaryOperationSingleOutput(Operation op, T input, T libcOutput, 102 double t); 103 template <typename T> 104 bool compareUnaryOperationTwoOutputs(Operation op, T input, 105 const BinaryOutput<T> &libcOutput, 106 double t); 107 template <typename T> 108 bool compareBinaryOperationTwoOutputs(Operation op, const BinaryInput<T> &input, 109 const BinaryOutput<T> &libcOutput, 110 double t); 111 112 template <typename T> 113 bool compareBinaryOperationOneOutput(Operation op, const BinaryInput<T> &input, 114 T libcOutput, double t); 115 116 template <typename T> 117 void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue, 118 testutils::StreamWrapper &OS); 119 template <typename T> 120 void explainUnaryOperationTwoOutputsError(Operation op, T input, 121 const BinaryOutput<T> &matchValue, 122 testutils::StreamWrapper &OS); 123 template <typename T> 124 void explainBinaryOperationTwoOutputsError(Operation op, 125 const BinaryInput<T> &input, 126 const BinaryOutput<T> &matchValue, 127 testutils::StreamWrapper &OS); 128 129 template <typename T> 130 void explainBinaryOperationOneOutputError(Operation op, 131 const BinaryInput<T> &input, 132 T matchValue, 133 testutils::StreamWrapper &OS); 134 135 template <Operation op, typename InputType, typename OutputType> 136 class MPFRMatcher : public testing::Matcher<OutputType> { 137 InputType input; 138 OutputType matchValue; 139 double ulpTolerance; 140 141 public: 142 MPFRMatcher(InputType testInput, double ulpTolerance) 143 : input(testInput), ulpTolerance(ulpTolerance) {} 144 145 bool match(OutputType libcResult) { 146 matchValue = libcResult; 147 return match(input, matchValue, ulpTolerance); 148 } 149 150 void explainError(testutils::StreamWrapper &OS) override { 151 explainError(input, matchValue, OS); 152 } 153 154 private: 155 template <typename T> static bool match(T in, T out, double tolerance) { 156 return compareUnaryOperationSingleOutput(op, in, out, tolerance); 157 } 158 159 template <typename T> 160 static bool match(T in, const BinaryOutput<T> &out, double tolerance) { 161 return compareUnaryOperationTwoOutputs(op, in, out, tolerance); 162 } 163 164 template <typename T> 165 static bool match(const BinaryInput<T> &in, T out, double tolerance) { 166 return compareBinaryOperationOneOutput(op, in, out, tolerance); 167 } 168 169 template <typename T> 170 static bool match(BinaryInput<T> in, const BinaryOutput<T> &out, 171 double tolerance) { 172 return compareBinaryOperationTwoOutputs(op, in, out, tolerance); 173 } 174 175 template <typename T> 176 static bool match(const TernaryInput<T> &in, T out, double tolerance) { 177 // TODO: Implement the comparision function and error reporter. 178 } 179 180 template <typename T> 181 static void explainError(T in, T out, testutils::StreamWrapper &OS) { 182 explainUnaryOperationSingleOutputError(op, in, out, OS); 183 } 184 185 template <typename T> 186 static void explainError(T in, const BinaryOutput<T> &out, 187 testutils::StreamWrapper &OS) { 188 explainUnaryOperationTwoOutputsError(op, in, out, OS); 189 } 190 191 template <typename T> 192 static void explainError(const BinaryInput<T> &in, const BinaryOutput<T> &out, 193 testutils::StreamWrapper &OS) { 194 explainBinaryOperationTwoOutputsError(op, in, out, OS); 195 } 196 197 template <typename T> 198 static void explainError(const BinaryInput<T> &in, T out, 199 testutils::StreamWrapper &OS) { 200 explainBinaryOperationOneOutputError(op, in, out, OS); 201 } 202 }; 203 204 } // namespace internal 205 206 // Return true if the input and ouput types for the operation op are valid 207 // types. 208 template <Operation op, typename InputType, typename OutputType> 209 constexpr bool isValidOperation() { 210 return (Operation::BeginUnaryOperationsSingleOutput < op && 211 op < Operation::EndUnaryOperationsSingleOutput && 212 cpp::IsSame<InputType, OutputType>::Value && 213 cpp::IsFloatingPointType<InputType>::Value) || 214 (Operation::BeginUnaryOperationsTwoOutputs < op && 215 op < Operation::EndUnaryOperationsTwoOutputs && 216 cpp::IsFloatingPointType<InputType>::Value && 217 cpp::IsSame<OutputType, BinaryOutput<InputType>>::Value) || 218 (Operation::BeginBinaryOperationsSingleOutput < op && 219 op < Operation::EndBinaryOperationsSingleOutput && 220 cpp::IsFloatingPointType<OutputType>::Value && 221 cpp::IsSame<InputType, BinaryInput<OutputType>>::Value) || 222 (Operation::BeginBinaryOperationsTwoOutputs < op && 223 op < Operation::EndBinaryOperationsTwoOutputs && 224 internal::AreMatchingBinaryInputAndBinaryOutput<InputType, 225 OutputType>::value) || 226 (Operation::BeginTernaryOperationsSingleOuput < op && 227 op < Operation::EndTernaryOperationsSingleOutput && 228 cpp::IsFloatingPointType<OutputType>::Value && 229 cpp::IsSame<InputType, TernaryInput<OutputType>>::Value); 230 } 231 232 template <Operation op, typename InputType, typename OutputType> 233 __attribute__((no_sanitize("address"))) 234 cpp::EnableIfType<isValidOperation<op, InputType, OutputType>(), 235 internal::MPFRMatcher<op, InputType, OutputType>> 236 getMPFRMatcher(InputType input, OutputType outputUnused, double t) { 237 return internal::MPFRMatcher<op, InputType, OutputType>(input, t); 238 } 239 240 } // namespace mpfr 241 } // namespace testing 242 } // namespace __llvm_libc 243 244 #define EXPECT_MPFR_MATCH(op, input, matchValue, tolerance) \ 245 EXPECT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>( \ 246 input, matchValue, tolerance)) 247 248 #define ASSERT_MPFR_MATCH(op, input, matchValue, tolerance) \ 249 ASSERT_THAT(matchValue, __llvm_libc::testing::mpfr::getMPFRMatcher<op>( \ 250 input, matchValue, tolerance)) 251 252 #endif // LLVM_LIBC_UTILS_TESTUTILS_MPFRUTILS_H 253