/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include "common/common_test.h" #include "utils/complex.h" namespace mindspore { class TestComplex : public UT::Common { public: TestComplex() {} }; TEST_F(TestComplex, test_size) { ASSERT_EQ(sizeof(Complex), 2 * sizeof(float)); ASSERT_EQ(sizeof(Complex), 2 * sizeof(double)); ASSERT_EQ(alignof(Complex), 2 * sizeof(float)); ASSERT_EQ(alignof(Complex), 2 * sizeof(double)); } template void test_construct() { constexpr T real = T(1.11f); constexpr T imag = T(2.22f); ASSERT_EQ(Complex().real(), T()); ASSERT_EQ(Complex().imag(), T()); ASSERT_EQ(Complex(real, imag).real(), real); ASSERT_EQ(Complex(real, imag).imag(), imag); ASSERT_EQ(Complex(real).real(), real); ASSERT_EQ(Complex(real).imag(), T()); } template void test_conver_construct() { ASSERT_EQ(Complex(Complex(T2(1.11f), T2(2.22f))).real(), T1(1.11f)); ASSERT_EQ(Complex(Complex(T2(1.11f), T2(2.22f))).imag(), T1(2.22f)); } template void test_conver_std_construct() { ASSERT_EQ(Complex(std::complex(T(1.11f), T(2.22f))).real(), T(1.11f)); ASSERT_EQ(Complex(std::complex(T(1.11f), T(2.22f))).imag(), T(2.22f)); } TEST_F(TestComplex, test_construct) { test_construct(); test_construct(); test_conver_construct(); test_conver_construct(); test_conver_construct(); test_conver_construct(); test_conver_std_construct(); test_conver_std_construct(); } template void test_convert_operator(T &&a) { ASSERT_EQ(static_cast(Complex(a)), a); } TEST_F(TestComplex, test_convert_operator) { test_convert_operator(true); test_convert_operator(1); test_convert_operator(1); ASSERT_NEAR(static_cast(Complex(1.11)), 1.11, 0.001); test_convert_operator(1.11f); test_convert_operator(1); test_convert_operator(1); test_convert_operator(1); test_convert_operator(1); test_convert_operator(1); test_convert_operator(1); float16 a(1.11f); ASSERT_EQ(static_cast(Complex(a)), a); } TEST_F(TestComplex, test_assign_operator) { Complex a = 1.11f; std::cout << a << std::endl; ASSERT_EQ(a.real(), 1.11f); ASSERT_EQ(a.imag(), float()); a = Complex(2.22f, 1.11f); ASSERT_EQ(a.real(), 2.22f); ASSERT_EQ(a.imag(), 1.11f); } template void test_arithmetic_add(T1 lhs, T2 rhs, T3 r) { ASSERT_EQ(lhs + rhs, r); if constexpr (!(std::is_same::value || std::is_same::value)) { ASSERT_EQ(lhs += rhs, r); } } template void test_arithmetic_sub(T1 lhs, T2 rhs, T3 r) { ASSERT_EQ(lhs - rhs, r); if constexpr (!(std::is_same::value || std::is_same::value)) { ASSERT_EQ(lhs -= rhs, r); } } template void test_arithmetic_mul(T1 lhs, T2 rhs, T3 r) { ASSERT_EQ(lhs * rhs, r); if constexpr (!(std::is_same::value || std::is_same::value)) { ASSERT_EQ(lhs *= rhs, r); } } template void test_arithmetic_div(T1 lhs, T2 rhs, T3 r) { ASSERT_EQ(lhs / rhs, r); if constexpr (!(std::is_same::value || std::is_same::value)) { ASSERT_EQ(lhs /= rhs, r); } } TEST_F(TestComplex, test_arithmetic) { test_arithmetic_add, Complex, Complex>( Complex(1.11, 2.22), Complex(1.11, 2.22), Complex(2.22, 4.44)); test_arithmetic_add, float, Complex>(Complex(1.11, 2.22), 1.11, Complex(2.22, 2.22)); test_arithmetic_add, Complex>(1.11, Complex(1.11, 2.22), Complex(2.22, 2.22)); test_arithmetic_sub, Complex, Complex>(Complex(1.11, 2.22), Complex(1.11, 2.22), Complex(0, 0)); test_arithmetic_sub, float, Complex>(Complex(1.11, 2.22), 1.11, Complex(0, 2.22)); test_arithmetic_sub, Complex>(1.11, Complex(1.11, 2.22), Complex(0, -2.22)); test_arithmetic_mul, Complex, Complex>( Complex(1.11, 2.22), Complex(1.11, 2.22), Complex(-3.6963, 4.9284)); test_arithmetic_mul, float, Complex>(Complex(1.11, 2.22), 1.11, Complex(1.2321, 2.4642)); test_arithmetic_mul, Complex>(1.11, Complex(1.11, 2.22), Complex(1.2321, 2.4642)); test_arithmetic_div, Complex, Complex>(Complex(1.11, 2.22), Complex(1.11, 2.22), Complex(1, 0)); test_arithmetic_div, float, Complex>(Complex(1.11, 2.22), 1.11, Complex(1, 2)); test_arithmetic_div, Complex>(1.11, Complex(1.11, 2.22), Complex(0.2, -0.4)); } } // namespace mindspore