1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <memory>
17
18 #include "common/common_test.h"
19 #include "utils/complex.h"
20
21 namespace mindspore {
22
23 class TestComplex : public UT::Common {
24 public:
TestComplex()25 TestComplex() {}
26 };
27
// Layout check: Complex<T> must occupy exactly two T, and its alignment must equal that full
// two-element width, for both float and double.
TEST_F(TestComplex, test_size) {
  constexpr auto kFloatPairBytes = 2 * sizeof(float);
  constexpr auto kDoublePairBytes = 2 * sizeof(double);

  // Size first, then alignment.
  ASSERT_EQ(sizeof(Complex<float>), kFloatPairBytes);
  ASSERT_EQ(sizeof(Complex<double>), kDoublePairBytes);
  ASSERT_EQ(alignof(Complex<float>), kFloatPairBytes);
  ASSERT_EQ(alignof(Complex<double>), kDoublePairBytes);
}
34
35 template <typename T>
test_construct()36 void test_construct() {
37 constexpr T real = T(1.11f);
38 constexpr T imag = T(2.22f);
39 ASSERT_EQ(Complex<T>().real(), T());
40 ASSERT_EQ(Complex<T>().imag(), T());
41 ASSERT_EQ(Complex<T>(real, imag).real(), real);
42 ASSERT_EQ(Complex<T>(real, imag).imag(), imag);
43 ASSERT_EQ(Complex<T>(real).real(), real);
44 ASSERT_EQ(Complex<T>(real).imag(), T());
45 }
46
47 template <typename T1, typename T2>
test_conver_construct()48 void test_conver_construct() {
49 ASSERT_EQ(Complex<T1>(Complex<T2>(T2(1.11f), T2(2.22f))).real(), T1(1.11f));
50 ASSERT_EQ(Complex<T1>(Complex<T2>(T2(1.11f), T2(2.22f))).imag(), T1(2.22f));
51 }
52
53 template <typename T>
test_conver_std_construct()54 void test_conver_std_construct() {
55 ASSERT_EQ(Complex<T>(std::complex<T>(T(1.11f), T(2.22f))).real(), T(1.11f));
56 ASSERT_EQ(Complex<T>(std::complex<T>(T(1.11f), T(2.22f))).imag(), T(2.22f));
57 }
58
// Runs every constructor helper for the float / double combinations.
TEST_F(TestComplex, test_construct) {
  // Value constructors.
  test_construct<float>();
  test_construct<double>();

  // Complex<T2> -> Complex<T1>, same-width first, then widening and narrowing.
  test_conver_construct<float, float>();
  test_conver_construct<double, double>();
  test_conver_construct<float, double>();
  test_conver_construct<double, float>();

  // std::complex<T> -> Complex<T>.
  test_conver_std_construct<float>();
  test_conver_std_construct<double>();
}
69
70 template <typename T>
test_convert_operator(T && a)71 void test_convert_operator(T &&a) {
72 ASSERT_EQ(static_cast<T>(Complex<float>(a)), a);
73 }
74
// Converts a Complex<float> back to a range of scalar types and checks the value survives the
// round trip.
TEST_F(TestComplex, test_convert_operator) {
  // bool and the character-sized integers.
  test_convert_operator<bool>(true);
  test_convert_operator<signed char>(1);
  test_convert_operator<unsigned char>(1);

  // double is compared with a tolerance instead of exactly.
  // NOTE(review): presumably because 1.11 loses precision once held by Complex<float> - confirm.
  ASSERT_NEAR(static_cast<double>(Complex<float>(1.11)), 1.11, 0.001);
  test_convert_operator<float>(1.11f);

  // Fixed-width integers, signed and unsigned.
  test_convert_operator<int16_t>(1);
  test_convert_operator<uint16_t>(1);
  test_convert_operator<int32_t>(1);
  test_convert_operator<uint32_t>(1);
  test_convert_operator<int64_t>(1);
  test_convert_operator<uint64_t>(1);

  // float16 is checked inline with a named value rather than through the helper.
  float16 half_value(1.11f);
  ASSERT_EQ(static_cast<float16>(Complex<float>(half_value)), half_value);
}
90
// Checks initialisation from a scalar and assignment from a Complex of a different element type.
TEST_F(TestComplex, test_assign_operator) {
  // Copy-initialise from a float: real part is the scalar, imaginary part equals float().
  Complex<float> value = 1.11f;
  std::cout << value << std::endl;  // also exercises the stream insertion operator
  ASSERT_EQ(value.real(), 1.11f);
  ASSERT_EQ(value.imag(), float());

  // Assign a Complex<double> into the Complex<float>; both parts must match the float literals.
  value = Complex<double>(2.22f, 1.11f);
  ASSERT_EQ(value.real(), 2.22f);
  ASSERT_EQ(value.imag(), 1.11f);
}
100
// Checks `lhs + rhs == r`, and additionally `(lhs += rhs) == r` unless lhs is a bare float or
// double. Comparison is exact (ASSERT_EQ uses operator==), so `r` must match bit-for-bit.
// NOTE(review): the compound form is presumably skipped for scalar lhs because a scalar cannot
// hold a complex result - confirm.
template <typename T1, typename T2, typename T3>
void test_arithmetic_add(T1 lhs, T2 rhs, T3 r) {
  constexpr bool kScalarLhs = std::is_same_v<T1, float> || std::is_same_v<T1, double>;

  ASSERT_EQ(lhs + rhs, r);
  if constexpr (!kScalarLhs) {
    // lhs is a by-value copy, so mutating it here does not affect the caller.
    ASSERT_EQ(lhs += rhs, r);
  }
}
// Checks `lhs - rhs == r`, and additionally `(lhs -= rhs) == r` unless lhs is a bare float or
// double. Comparison is exact (ASSERT_EQ uses operator==).
template <typename T1, typename T2, typename T3>
void test_arithmetic_sub(T1 lhs, T2 rhs, T3 r) {
  constexpr bool kScalarLhs = std::is_same_v<T1, float> || std::is_same_v<T1, double>;

  ASSERT_EQ(lhs - rhs, r);
  if constexpr (!kScalarLhs) {
    // lhs is a by-value copy, so mutating it here does not affect the caller.
    ASSERT_EQ(lhs -= rhs, r);
  }
}
// Checks `lhs * rhs == r`, and additionally `(lhs *= rhs) == r` unless lhs is a bare float or
// double. Comparison is exact (ASSERT_EQ uses operator==).
// NOTE(review): exact equality on a computed floating-point product depends on the operator's
// rounding behaviour; confirm this holds on every target the suite runs on.
template <typename T1, typename T2, typename T3>
void test_arithmetic_mul(T1 lhs, T2 rhs, T3 r) {
  constexpr bool kScalarLhs = std::is_same_v<T1, float> || std::is_same_v<T1, double>;

  ASSERT_EQ(lhs * rhs, r);
  if constexpr (!kScalarLhs) {
    // lhs is a by-value copy, so mutating it here does not affect the caller.
    ASSERT_EQ(lhs *= rhs, r);
  }
}
// Checks `lhs / rhs == r`, and additionally `(lhs /= rhs) == r` unless lhs is a bare float or
// double. Comparison is exact (ASSERT_EQ uses operator==).
// NOTE(review): exact equality on a computed floating-point quotient depends on the operator's
// rounding behaviour; confirm this holds on every target the suite runs on.
template <typename T1, typename T2, typename T3>
void test_arithmetic_div(T1 lhs, T2 rhs, T3 r) {
  constexpr bool kScalarLhs = std::is_same_v<T1, float> || std::is_same_v<T1, double>;

  ASSERT_EQ(lhs / rhs, r);
  if constexpr (!kScalarLhs) {
    // lhs is a by-value copy, so mutating it here does not affect the caller.
    ASSERT_EQ(lhs /= rhs, r);
  }
}
129
// Exercises +, -, * and / on Complex<float> for the three operand pairings
// Complex op Complex, Complex op float and float op Complex. The helpers also check the
// compound-assignment form whenever the left operand is not a bare scalar.
TEST_F(TestComplex, test_arithmetic) {
  using Cf = Complex<float>;

  // Addition.
  test_arithmetic_add<Cf, Cf, Cf>(Cf(1.11, 2.22), Cf(1.11, 2.22), Cf(2.22, 4.44));
  test_arithmetic_add<Cf, float, Cf>(Cf(1.11, 2.22), 1.11, Cf(2.22, 2.22));
  test_arithmetic_add<float, Cf, Cf>(1.11, Cf(1.11, 2.22), Cf(2.22, 2.22));

  // Subtraction.
  test_arithmetic_sub<Cf, Cf, Cf>(Cf(1.11, 2.22), Cf(1.11, 2.22), Cf(0, 0));
  test_arithmetic_sub<Cf, float, Cf>(Cf(1.11, 2.22), 1.11, Cf(0, 2.22));
  test_arithmetic_sub<float, Cf, Cf>(1.11, Cf(1.11, 2.22), Cf(0, -2.22));

  // Multiplication: (1.11 + 2.22i)^2 = -3.6963 + 4.9284i.
  test_arithmetic_mul<Cf, Cf, Cf>(Cf(1.11, 2.22), Cf(1.11, 2.22), Cf(-3.6963, 4.9284));
  test_arithmetic_mul<Cf, float, Cf>(Cf(1.11, 2.22), 1.11, Cf(1.2321, 2.4642));
  test_arithmetic_mul<float, Cf, Cf>(1.11, Cf(1.11, 2.22), Cf(1.2321, 2.4642));

  // Division: 1.11 / (1.11 + 2.22i) = 0.2 - 0.4i.
  test_arithmetic_div<Cf, Cf, Cf>(Cf(1.11, 2.22), Cf(1.11, 2.22), Cf(1, 0));
  test_arithmetic_div<Cf, float, Cf>(Cf(1.11, 2.22), 1.11, Cf(1, 2));
  test_arithmetic_div<float, Cf, Cf>(1.11, Cf(1.11, 2.22), Cf(0.2, -0.4));
}
157
158 } // namespace mindspore
159