1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_CCSRC_UTILS_COPLEX_H_
17 #define MINDSPORE_CCSRC_UTILS_COPLEX_H_
18
#include <complex>
#include <cstdint>
#include <limits>
#include <ostream>
#ifdef ENABLE_GPU
#include <thrust/complex.h>
#include <cublas_v2.h>
#endif
#include "base/float16.h"
26 #if defined(__CUDACC__)
27 #define HOST_DEVICE __host__ __device__
28 #else
29 #define HOST_DEVICE
30 #endif
31
32 namespace mindspore {
33 namespace utils {
34 // Implement Complex for mindspore, inspired by std::complex.
35 template <typename T>
36 struct alignas(sizeof(T) * 2) Complex {
37 Complex() = default;
38 ~Complex() = default;
39
40 Complex(const Complex<T> &other) noexcept = default;
41 Complex(Complex<T> &&other) noexcept = default;
42
43 Complex &operator=(const Complex<T> &other) noexcept = default;
44 Complex &operator=(Complex<T> &&other) noexcept = default;
45
real_Complex46 HOST_DEVICE inline constexpr Complex(const T &real, const T &imag = T()) : real_(real), imag_(imag) {}
47
48 template <typename U>
ComplexComplex49 inline explicit constexpr Complex(const std::complex<U> &other) : Complex(other.real(), other.imag()) {}
50 template <typename U>
51 inline explicit constexpr operator std::complex<U>() const {
52 return std::complex<U>(std::complex<T>(real(), imag()));
53 }
54
ComplexComplex55 HOST_DEVICE inline explicit constexpr Complex(const float16 &real) : real_(static_cast<T>(real)), imag_(T()) {}
56 #if defined(__CUDACC__)
57 template <typename U>
ComplexComplex58 HOST_DEVICE inline explicit Complex(const thrust::complex<U> &other) : real_(other.real()), imag_(other.imag()) {}
59
60 template <typename U>
61 HOST_DEVICE inline HOST_DEVICE explicit operator thrust::complex<U>() const {
62 return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
63 }
64 #endif
65 template <typename U = T>
ComplexComplex66 HOST_DEVICE inline explicit Complex(const Complex<U> &other)
67 : real_(static_cast<T>(other.real())), imag_(static_cast<T>(other.imag())) {}
68
69 HOST_DEVICE inline explicit operator bool() const { return static_cast<bool>(real_) || static_cast<bool>(imag_); }
70 HOST_DEVICE inline explicit operator signed char() const { return static_cast<signed char>(real_); }
71 HOST_DEVICE inline explicit operator unsigned char() const { return static_cast<unsigned char>(real_); }
72 HOST_DEVICE inline explicit operator double() const { return static_cast<double>(real_); }
73 HOST_DEVICE inline explicit operator float() const { return static_cast<float>(real_); }
int16_tComplex74 HOST_DEVICE inline explicit operator int16_t() const { return static_cast<int16_t>(real_); }
uint16_tComplex75 HOST_DEVICE inline explicit operator uint16_t() const { return static_cast<uint16_t>(real_); }
int32_tComplex76 HOST_DEVICE inline explicit operator int32_t() const { return static_cast<int32_t>(real_); }
uint32_tComplex77 HOST_DEVICE inline explicit operator uint32_t() const { return static_cast<uint32_t>(real_); }
int64_tComplex78 HOST_DEVICE inline explicit operator int64_t() const { return static_cast<int64_t>(real_); }
uint64_tComplex79 HOST_DEVICE inline explicit operator uint64_t() const { return static_cast<uint64_t>(real_); }
80 #if defined(__CUDACC__)
halfComplex81 HOST_DEVICE inline explicit operator half() const { return static_cast<half>(real_); }
82 #else
float16Complex83 inline explicit operator float16() const { return static_cast<float16>(real_); }
84 #endif
85
86 HOST_DEVICE inline Complex<T> &operator=(const T &real) {
87 real_ = real;
88 imag_ = T();
89 return *this;
90 }
91
92 HOST_DEVICE inline Complex<T> &operator+=(const T &real) {
93 real_ += real;
94 return *this;
95 }
96
97 HOST_DEVICE inline Complex<T> &operator-=(const T &real) {
98 real_ -= real;
99 return *this;
100 }
101
102 HOST_DEVICE inline Complex<T> &operator*=(const T &real) {
103 real_ *= real;
104 imag_ *= real;
105 return *this;
106 }
107
108 // Note: check division by zero before use it.
109 HOST_DEVICE inline Complex<T> &operator/=(const T &real) {
110 real_ /= real;
111 imag_ /= real;
112 return *this;
113 }
114
115 template <typename U>
116 HOST_DEVICE inline Complex<T> &operator=(const Complex<U> &z) {
117 real_ = z.real();
118 imag_ = z.imag();
119 return *this;
120 }
121 template <typename U>
122 HOST_DEVICE inline Complex<T> &operator+=(const Complex<U> &z) {
123 real_ += z.real();
124 imag_ += z.imag();
125 return *this;
126 }
127 template <typename U>
128 HOST_DEVICE inline Complex<T> &operator-=(const Complex<U> &z) {
129 real_ -= z.real();
130 imag_ -= z.imag();
131 return *this;
132 }
133 template <typename U>
134 HOST_DEVICE inline Complex<T> &operator*=(const Complex<U> &z);
135
136 // Note: check division by zero before use it.
137 template <typename U>
138 HOST_DEVICE inline Complex<T> &operator/=(const Complex<U> &z);
139
realComplex140 HOST_DEVICE inline constexpr T real() const { return real_; }
imagComplex141 HOST_DEVICE inline constexpr T imag() const { return imag_; }
realComplex142 HOST_DEVICE inline void real(T val) { real_ = val; }
imagComplex143 HOST_DEVICE inline void imag(T val) { imag_ = val; }
144
145 private:
146 T real_;
147 T imag_;
148 };
149
150 template <typename T>
151 template <typename U>
152 HOST_DEVICE inline Complex<T> &Complex<T>::operator*=(const Complex<U> &z) {
153 const T real = real_ * z.real() - imag_ * z.imag();
154 imag_ = real_ * z.imag() + imag_ * z.real();
155 real_ = real;
156 return *this;
157 }
158
159 // Note: check division by zero before use it.
160 template <typename T>
161 template <typename U>
162 HOST_DEVICE inline Complex<T> &Complex<T>::operator/=(const Complex<U> &z) {
163 T a = real_;
164 T b = imag_;
165 U c = z.real();
166 U d = z.imag();
167 auto denominator = c * c + d * d;
168 real_ = (a * c + b * d) / denominator;
169 imag_ = (b * c - a * d) / denominator;
170 return *this;
171 }
172
173 template <typename T>
174 HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const Complex<T> &rhs) {
175 Complex<T> result = lhs;
176 result += rhs;
177 return result;
178 }
179
180 template <typename T>
181 HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const T &rhs) {
182 Complex<T> result = lhs;
183 result += rhs;
184 return result;
185 }
186
187 template <typename T>
188 HOST_DEVICE inline Complex<T> operator+(const T &lhs, const Complex<T> &rhs) {
189 Complex<T> result = rhs;
190 result += lhs;
191 return result;
192 }
193
194 template <typename T>
195 HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const Complex<T> &rhs) {
196 Complex<T> result = lhs;
197 result -= rhs;
198 return result;
199 }
200
201 template <typename T>
202 HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const T &rhs) {
203 Complex<T> result = lhs;
204 result -= rhs;
205 return result;
206 }
207
208 template <typename T>
209 HOST_DEVICE inline Complex<T> operator-(const T &lhs, const Complex<T> &rhs) {
210 Complex<T> result(lhs, -rhs.imag());
211 result -= rhs.real();
212 return result;
213 }
214
215 template <typename T>
216 HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const Complex<T> &rhs) {
217 Complex<T> result = lhs;
218 result *= rhs;
219 return result;
220 }
221
222 template <typename T>
223 HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const T &rhs) {
224 Complex<T> result = lhs;
225 result *= rhs;
226 return result;
227 }
228
229 template <typename T>
230 HOST_DEVICE inline Complex<T> operator*(const T &lhs, const Complex<T> &rhs) {
231 Complex<T> result = rhs;
232 result *= lhs;
233 return result;
234 }
235
236 // Note: check division by zero before use it.
237 template <typename T>
238 HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const Complex<T> &rhs) {
239 Complex<T> result = lhs;
240 result /= rhs;
241 return result;
242 }
243
244 // Note: check division by zero before use it.
245 template <typename T>
246 HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const T &rhs) {
247 Complex<T> result = lhs;
248 result /= rhs;
249 return result;
250 }
251
252 // Note: check division by zero before use it.
253 template <typename T>
254 HOST_DEVICE inline Complex<T> operator/(const T &lhs, const Complex<T> &rhs) {
255 Complex<T> result = lhs;
256 result /= rhs;
257 return result;
258 }
259
260 template <typename T>
261 HOST_DEVICE inline Complex<T> operator+(const Complex<T> &z) {
262 return z;
263 }
264
265 template <typename T>
266 HOST_DEVICE inline Complex<T> operator-(const Complex<T> &z) {
267 return Complex<T>(-z.real(), -z.imag());
268 }
269
270 template <typename T>
271 HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const Complex<T> &rhs) {
272 return lhs.real() == rhs.real() && lhs.imag() == rhs.imag();
273 }
274
275 template <typename T>
276 HOST_DEVICE inline bool operator==(const T &lhs, const Complex<T> &rhs) {
277 return lhs == rhs.real() && rhs.imag() == 0;
278 }
279
280 template <typename T>
281 HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const T &rhs) {
282 return lhs.real() == rhs && lhs.imag() == 0;
283 }
284
285 template <typename T>
286 HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const Complex<T> &rhs) {
287 return !(lhs == rhs);
288 }
289
290 template <typename T>
291 HOST_DEVICE inline bool operator!=(const T &lhs, const Complex<T> &rhs) {
292 return !(lhs == rhs);
293 }
294
295 template <typename T>
296 HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const T &rhs) {
297 return !(lhs == rhs);
298 }
299
300 template <typename T>
301 inline std::ostream &operator<<(std::ostream &os, const Complex<T> &v) {
302 return (os << std::noshowpos << v.real() << std::showpos << v.imag() << 'j');
303 }
304
305 template <typename T>
abs(const Complex<T> & z)306 HOST_DEVICE inline T abs(const Complex<T> &z) {
307 #if defined(__CUDACC__)
308 return thrust::abs(thrust::complex<T>(z));
309 #else
310 return std::abs(std::complex<T>(z));
311 #endif
312 }
313 } // namespace utils
314 } // namespace mindspore
315
316 template <typename T>
317 using Complex = mindspore::utils::Complex<T>;
318
319 namespace std {
320
321 template <typename T>
322 class numeric_limits<mindspore::utils::Complex<T>> : public numeric_limits<T> {};
323
324 } // namespace std
325
326 #endif // MINDSPORE_CCSRC_UTILS_COPLEX_H_
327