#include #include #include #include #include #include #include #include #if (defined(__CUDACC__) || defined(__HIPCC__)) #define MAYBE_GLOBAL __global__ #else #define MAYBE_GLOBAL #endif #define PI 3.141592653589793238463 namespace memory { MAYBE_GLOBAL void test_size() { static_assert(sizeof(c10::complex) == 2 * sizeof(float), ""); static_assert(sizeof(c10::complex) == 2 * sizeof(double), ""); } MAYBE_GLOBAL void test_align() { static_assert(alignof(c10::complex) == 2 * sizeof(float), ""); static_assert(alignof(c10::complex) == 2 * sizeof(double), ""); } MAYBE_GLOBAL void test_pod() { static_assert(std::is_standard_layout>::value, ""); static_assert(std::is_standard_layout>::value, ""); } TEST(TestMemory, ReinterpretCast) { { std::complex z(1, 2); c10::complex zz = *reinterpret_cast*>(&z); ASSERT_EQ(zz.real(), float(1)); ASSERT_EQ(zz.imag(), float(2)); } { c10::complex z(3, 4); std::complex zz = *reinterpret_cast*>(&z); ASSERT_EQ(zz.real(), float(3)); ASSERT_EQ(zz.imag(), float(4)); } { std::complex z(1, 2); c10::complex zz = *reinterpret_cast*>(&z); ASSERT_EQ(zz.real(), double(1)); ASSERT_EQ(zz.imag(), double(2)); } { c10::complex z(3, 4); std::complex zz = *reinterpret_cast*>(&z); ASSERT_EQ(zz.real(), double(3)); ASSERT_EQ(zz.imag(), double(4)); } } #if defined(__CUDACC__) || defined(__HIPCC__) TEST(TestMemory, ThrustReinterpretCast) { { thrust::complex z(1, 2); c10::complex zz = *reinterpret_cast*>(&z); ASSERT_EQ(zz.real(), float(1)); ASSERT_EQ(zz.imag(), float(2)); } { c10::complex z(3, 4); thrust::complex zz = *reinterpret_cast*>(&z); ASSERT_EQ(zz.real(), float(3)); ASSERT_EQ(zz.imag(), float(4)); } { thrust::complex z(1, 2); c10::complex zz = *reinterpret_cast*>(&z); ASSERT_EQ(zz.real(), double(1)); ASSERT_EQ(zz.imag(), double(2)); } { c10::complex z(3, 4); thrust::complex zz = *reinterpret_cast*>(&z); ASSERT_EQ(zz.real(), double(3)); ASSERT_EQ(zz.imag(), double(4)); } } #endif } // namespace memory namespace constructors { template C10_HOST_DEVICE void test_construct_from_scalar() { constexpr scalar_t num1 = scalar_t(1.23); constexpr scalar_t num2 = scalar_t(4.56); constexpr scalar_t zero = scalar_t(); static_assert(c10::complex(num1, num2).real() == num1, ""); static_assert(c10::complex(num1, num2).imag() == num2, ""); static_assert(c10::complex(num1).real() == num1, ""); static_assert(c10::complex(num1).imag() == zero, ""); static_assert(c10::complex().real() == zero, ""); static_assert(c10::complex().imag() == zero, ""); } template C10_HOST_DEVICE void test_construct_from_other() { constexpr other_t num1 = other_t(1.23); constexpr other_t num2 = other_t(4.56); constexpr scalar_t num3 = scalar_t(num1); constexpr scalar_t num4 = scalar_t(num2); static_assert( c10::complex(c10::complex(num1, num2)).real() == num3, ""); static_assert( c10::complex(c10::complex(num1, num2)).imag() == num4, ""); } MAYBE_GLOBAL void test_convert_constructors() { test_construct_from_scalar(); test_construct_from_scalar(); static_assert( std::is_convertible, c10::complex>::value, ""); static_assert( !std::is_convertible, c10::complex>::value, ""); static_assert( std::is_convertible, c10::complex>::value, ""); static_assert( std::is_convertible, c10::complex>::value, ""); static_assert( std::is_constructible, c10::complex>::value, ""); static_assert( std::is_constructible, c10::complex>::value, ""); static_assert( std::is_constructible, c10::complex>::value, ""); static_assert( std::is_constructible, c10::complex>::value, ""); test_construct_from_other(); test_construct_from_other(); test_construct_from_other(); test_construct_from_other(); } template C10_HOST_DEVICE void test_construct_from_std() { constexpr scalar_t num1 = scalar_t(1.23); constexpr scalar_t num2 = scalar_t(4.56); static_assert( c10::complex(std::complex(num1, num2)).real() == num1, ""); static_assert( c10::complex(std::complex(num1, num2)).imag() == num2, ""); } MAYBE_GLOBAL void test_std_conversion() { test_construct_from_std(); test_construct_from_std(); } #if defined(__CUDACC__) || defined(__HIPCC__) template void test_construct_from_thrust() { constexpr scalar_t num1 = scalar_t(1.23); constexpr scalar_t num2 = scalar_t(4.56); ASSERT_EQ( c10::complex(thrust::complex(num1, num2)).real(), num1); ASSERT_EQ( c10::complex(thrust::complex(num1, num2)).imag(), num2); } TEST(TestConstructors, FromThrust) { test_construct_from_thrust(); test_construct_from_thrust(); } #endif TEST(TestConstructors, UnorderedMap) { std::unordered_map< c10::complex, c10::complex, c10::hash>> m; auto key1 = c10::complex(2.5, 3); auto key2 = c10::complex(2, 0); auto val1 = c10::complex(2, -3.2); auto val2 = c10::complex(0, -3); m[key1] = val1; m[key2] = val2; ASSERT_EQ(m[key1], val1); ASSERT_EQ(m[key2], val2); } } // namespace constructors namespace assignment { template constexpr c10::complex one() { c10::complex result(3, 4); result = scalar_t(1); return result; } MAYBE_GLOBAL void test_assign_real() { static_assert(one().real() == float(1), ""); static_assert(one().imag() == float(), ""); static_assert(one().real() == double(1), ""); static_assert(one().imag() == double(), ""); } constexpr std::tuple, c10::complex> one_two() { constexpr c10::complex src(1, 2); c10::complex ret0; c10::complex ret1; ret0 = ret1 = src; return std::make_tuple(ret0, ret1); } MAYBE_GLOBAL void test_assign_other() { constexpr auto tup = one_two(); static_assert(std::get>(tup).real() == double(1), ""); static_assert(std::get>(tup).imag() == double(2), ""); static_assert(std::get>(tup).real() == float(1), ""); static_assert(std::get>(tup).imag() == float(2), ""); } constexpr std::tuple, c10::complex> one_two_std() { constexpr std::complex src(1, 1); c10::complex ret0; c10::complex ret1; ret0 = ret1 = src; return std::make_tuple(ret0, ret1); } MAYBE_GLOBAL void test_assign_std() { constexpr auto tup = one_two(); static_assert(std::get>(tup).real() == double(1), ""); static_assert(std::get>(tup).imag() == double(2), ""); static_assert(std::get>(tup).real() == float(1), ""); static_assert(std::get>(tup).imag() == float(2), ""); } #if defined(__CUDACC__) || defined(__HIPCC__) C10_HOST_DEVICE std::tuple, c10::complex> one_two_thrust() { thrust::complex src(1, 2); c10::complex ret0; c10::complex ret1; ret0 = ret1 = src; return std::make_tuple(ret0, ret1); } TEST(TestAssignment, FromThrust) { auto tup = one_two_thrust(); ASSERT_EQ(std::get>(tup).real(), double(1)); ASSERT_EQ(std::get>(tup).imag(), double(2)); ASSERT_EQ(std::get>(tup).real(), float(1)); ASSERT_EQ(std::get>(tup).imag(), float(2)); } #endif } // namespace assignment namespace literals { MAYBE_GLOBAL void test_complex_literals() { using namespace c10::complex_literals; static_assert(std::is_same>::value, ""); static_assert((0.5_if).real() == float(), ""); static_assert((0.5_if).imag() == float(0.5), ""); static_assert( std::is_same>::value, ""); static_assert((0.5_id).real() == float(), ""); static_assert((0.5_id).imag() == float(0.5), ""); static_assert(std::is_same>::value, ""); static_assert((1_if).real() == float(), ""); static_assert((1_if).imag() == float(1), ""); static_assert(std::is_same>::value, ""); static_assert((1_id).real() == double(), ""); static_assert((1_id).imag() == double(1), ""); } } // namespace literals namespace real_imag { template constexpr c10::complex zero_one() { c10::complex result; result.imag(scalar_t(1)); return result; } template constexpr c10::complex one_zero() { c10::complex result; result.real(scalar_t(1)); return result; } MAYBE_GLOBAL void test_real_imag_modify() { static_assert(zero_one().real() == float(0), ""); static_assert(zero_one().imag() == float(1), ""); static_assert(zero_one().real() == double(0), ""); static_assert(zero_one().imag() == double(1), ""); static_assert(one_zero().real() == float(1), ""); static_assert(one_zero().imag() == float(0), ""); static_assert(one_zero().real() == double(1), ""); static_assert(one_zero().imag() == double(0), ""); } } // namespace real_imag namespace arithmetic_assign { template constexpr c10::complex p(scalar_t value) { c10::complex result(scalar_t(2), scalar_t(2)); result += value; return result; } template constexpr c10::complex m(scalar_t value) { c10::complex result(scalar_t(2), scalar_t(2)); result -= value; return result; } template constexpr c10::complex t(scalar_t value) { c10::complex result(scalar_t(2), scalar_t(2)); result *= value; return result; } template constexpr c10::complex d(scalar_t value) { c10::complex result(scalar_t(2), scalar_t(2)); result /= value; return result; } template C10_HOST_DEVICE void test_arithmetic_assign_scalar() { constexpr c10::complex x = p(scalar_t(1)); static_assert(x.real() == scalar_t(3), ""); static_assert(x.imag() == scalar_t(2), ""); constexpr c10::complex y = m(scalar_t(1)); static_assert(y.real() == scalar_t(1), ""); static_assert(y.imag() == scalar_t(2), ""); constexpr c10::complex z = t(scalar_t(2)); static_assert(z.real() == scalar_t(4), ""); static_assert(z.imag() == scalar_t(4), ""); constexpr c10::complex t = d(scalar_t(2)); static_assert(t.real() == scalar_t(1), ""); static_assert(t.imag() == scalar_t(1), ""); } template constexpr c10::complex p( scalar_t real, scalar_t imag, c10::complex rhs) { c10::complex result(real, imag); result += rhs; return result; } template constexpr c10::complex m( scalar_t real, scalar_t imag, c10::complex rhs) { c10::complex result(real, imag); result -= rhs; return result; } template constexpr c10::complex t( scalar_t real, scalar_t imag, c10::complex rhs) { c10::complex result(real, imag); result *= rhs; return result; } template constexpr c10::complex d( scalar_t real, scalar_t imag, c10::complex rhs) { c10::complex result(real, imag); result /= rhs; return result; } template C10_HOST_DEVICE void test_arithmetic_assign_complex() { using namespace c10::complex_literals; constexpr c10::complex x2 = p(scalar_t(2), scalar_t(2), 1.0_if); static_assert(x2.real() == scalar_t(2), ""); static_assert(x2.imag() == scalar_t(3), ""); constexpr c10::complex x3 = p(scalar_t(2), scalar_t(2), 1.0_id); static_assert(x3.real() == scalar_t(2), ""); // this test is skipped due to a bug in constexpr evaluation // in nvcc. This bug has already been fixed since CUDA 11.2 #if !defined(__CUDACC__) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11020) static_assert(x3.imag() == scalar_t(3), ""); #endif constexpr c10::complex y2 = m(scalar_t(2), scalar_t(2), 1.0_if); static_assert(y2.real() == scalar_t(2), ""); static_assert(y2.imag() == scalar_t(1), ""); constexpr c10::complex y3 = m(scalar_t(2), scalar_t(2), 1.0_id); static_assert(y3.real() == scalar_t(2), ""); // this test is skipped due to a bug in constexpr evaluation // in nvcc. This bug has already been fixed since CUDA 11.2 #if !defined(__CUDACC__) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11020) static_assert(y3.imag() == scalar_t(1), ""); #endif constexpr c10::complex z2 = t(scalar_t(1), scalar_t(-2), 1.0_if); static_assert(z2.real() == scalar_t(2), ""); static_assert(z2.imag() == scalar_t(1), ""); constexpr c10::complex z3 = t(scalar_t(1), scalar_t(-2), 1.0_id); static_assert(z3.real() == scalar_t(2), ""); static_assert(z3.imag() == scalar_t(1), ""); constexpr c10::complex t2 = d(scalar_t(-1), scalar_t(2), 1.0_if); static_assert(t2.real() == scalar_t(2), ""); static_assert(t2.imag() == scalar_t(1), ""); constexpr c10::complex t3 = d(scalar_t(-1), scalar_t(2), 1.0_id); static_assert(t3.real() == scalar_t(2), ""); static_assert(t3.imag() == scalar_t(1), ""); } MAYBE_GLOBAL void test_arithmetic_assign() { test_arithmetic_assign_scalar(); test_arithmetic_assign_scalar(); test_arithmetic_assign_complex(); test_arithmetic_assign_complex(); } } // namespace arithmetic_assign namespace arithmetic { template C10_HOST_DEVICE void test_arithmetic_() { static_assert( c10::complex(1, 2) == +c10::complex(1, 2), ""); static_assert( c10::complex(-1, -2) == -c10::complex(1, 2), ""); static_assert( c10::complex(1, 2) + c10::complex(3, 4) == c10::complex(4, 6), ""); static_assert( c10::complex(1, 2) + scalar_t(3) == c10::complex(4, 2), ""); static_assert( scalar_t(3) + c10::complex(1, 2) == c10::complex(4, 2), ""); static_assert( c10::complex(1, 2) - c10::complex(3, 4) == c10::complex(-2, -2), ""); static_assert( c10::complex(1, 2) - scalar_t(3) == c10::complex(-2, 2), ""); static_assert( scalar_t(3) - c10::complex(1, 2) == c10::complex(2, -2), ""); static_assert( c10::complex(1, 2) * c10::complex(3, 4) == c10::complex(-5, 10), ""); static_assert( c10::complex(1, 2) * scalar_t(3) == c10::complex(3, 6), ""); static_assert( scalar_t(3) * c10::complex(1, 2) == c10::complex(3, 6), ""); static_assert( c10::complex(-5, 10) / c10::complex(3, 4) == c10::complex(1, 2), ""); static_assert( c10::complex(5, 10) / scalar_t(5) == c10::complex(1, 2), ""); static_assert( scalar_t(25) / c10::complex(3, 4) == c10::complex(3, -4), ""); } MAYBE_GLOBAL void test_arithmetic() { test_arithmetic_(); test_arithmetic_(); } template void test_binary_ops_for_int_type_(T real, T img, int_t num) { c10::complex c(real, img); ASSERT_EQ(c + num, c10::complex(real + num, img)); ASSERT_EQ(num + c, c10::complex(num + real, img)); ASSERT_EQ(c - num, c10::complex(real - num, img)); ASSERT_EQ(num - c, c10::complex(num - real, -img)); ASSERT_EQ(c * num, c10::complex(real * num, img * num)); ASSERT_EQ(num * c, c10::complex(num * real, num * img)); ASSERT_EQ(c / num, c10::complex(real / num, img / num)); ASSERT_EQ( num / c, c10::complex(num * real / std::norm(c), -num * img / std::norm(c))); } template void test_binary_ops_for_all_int_types_(T real, T img, int8_t i) { test_binary_ops_for_int_type_(real, img, i); test_binary_ops_for_int_type_(real, img, i); test_binary_ops_for_int_type_(real, img, i); test_binary_ops_for_int_type_(real, img, i); } TEST(TestArithmeticIntScalar, All) { test_binary_ops_for_all_int_types_(1.0, 0.1, 1); test_binary_ops_for_all_int_types_(-1.3, -0.2, -2); } } // namespace arithmetic namespace equality { template C10_HOST_DEVICE void test_equality_() { static_assert( c10::complex(1, 2) == c10::complex(1, 2), ""); static_assert(c10::complex(1, 0) == scalar_t(1), ""); static_assert(scalar_t(1) == c10::complex(1, 0), ""); static_assert( c10::complex(1, 2) != c10::complex(3, 4), ""); static_assert(c10::complex(1, 2) != scalar_t(1), ""); static_assert(scalar_t(1) != c10::complex(1, 2), ""); } MAYBE_GLOBAL void test_equality() { test_equality_(); test_equality_(); } } // namespace equality namespace io { template void test_io_() { std::stringstream ss; c10::complex a(1, 2); ss << a; ASSERT_EQ(ss.str(), "(1,2)"); ss.str("(3,4)"); ss >> a; ASSERT_TRUE(a == c10::complex(3, 4)); } TEST(TestIO, All) { test_io_(); test_io_(); } } // namespace io namespace test_std { template C10_HOST_DEVICE void test_callable_() { static_assert(std::real(c10::complex(1, 2)) == scalar_t(1), ""); static_assert(std::imag(c10::complex(1, 2)) == scalar_t(2), ""); std::abs(c10::complex(1, 2)); std::arg(c10::complex(1, 2)); static_assert(std::norm(c10::complex(3, 4)) == scalar_t(25), ""); static_assert( std::conj(c10::complex(3, 4)) == c10::complex(3, -4), ""); c10::polar(float(1), float(PI / 2)); c10::polar(double(1), double(PI / 2)); } MAYBE_GLOBAL void test_callable() { test_callable_(); test_callable_(); } template void test_values_() { ASSERT_EQ(std::abs(c10::complex(3, 4)), scalar_t(5)); ASSERT_LT(std::abs(std::arg(c10::complex(0, 1)) - PI / 2), 1e-6); ASSERT_LT( std::abs( c10::polar(scalar_t(1), scalar_t(PI / 2)) - c10::complex(0, 1)), 1e-6); } TEST(TestStd, BasicFunctions) { test_values_(); test_values_(); // CSQRT edge cases: checks for overflows which are likely to occur // if square root is computed using polar form ASSERT_LT( std::abs(std::sqrt(c10::complex(-1e20, -4988429.2)).real()), 3e-4); ASSERT_LT( std::abs(std::sqrt(c10::complex(-1e60, -4988429.2)).real()), 3e-4); } } // namespace test_std