1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_CCSRC_UTILS_COPLEX_H_
17 #define MINDSPORE_CCSRC_UTILS_COPLEX_H_
18
19 #ifdef ENABLE_GPU
20 #include <thrust/complex.h>
21 #include <cublas_v2.h>
22 #endif
23 #include <complex>
24 #include <limits>
25 #include "base/float16.h"
// Execution-space annotation for functions that must be callable from both host and device code.
// Under nvcc (__CUDACC__ defined) it expands to `__host__ __device__`; for ordinary host compilers it
// expands to nothing, so annotated functions are plain host functions.
#ifdef __CUDACC__
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif
31
32 namespace mindspore {
33 namespace utils {
34 // Implement Complex for mindspore, inspired by std::complex.
35 constexpr int T_SIZE = 2;
36 template <typename T>
37 struct alignas(sizeof(T) * T_SIZE) Complex {
38 Complex() = default;
39 ~Complex() = default;
40
41 Complex(const Complex<T> &other) noexcept = default;
42 Complex(Complex<T> &&other) noexcept = default;
43
44 Complex &operator=(const Complex<T> &other) noexcept = default;
45 Complex &operator=(Complex<T> &&other) noexcept = default;
46
real_Complex47 HOST_DEVICE inline constexpr Complex(const T &real, const T &imag = T()) : real_(real), imag_(imag) {}
48
49 template <typename U>
ComplexComplex50 inline explicit constexpr Complex(const std::complex<U> &other) : Complex(other.real(), other.imag()) {}
51 template <typename U>
52 inline explicit constexpr operator std::complex<U>() const {
53 return std::complex<U>(std::complex<T>(real(), imag()));
54 }
55
ComplexComplex56 HOST_DEVICE inline explicit constexpr Complex(const float16 &real) : real_(static_cast<T>(real)), imag_(T()) {}
57 #if defined(__CUDACC__)
58 template <typename U>
ComplexComplex59 HOST_DEVICE inline explicit Complex(const thrust::complex<U> &other) : real_(other.real()), imag_(other.imag()) {}
60
61 template <typename U>
62 HOST_DEVICE inline HOST_DEVICE explicit operator thrust::complex<U>() const {
63 return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
64 }
65 #endif
66 template <typename U = T>
ComplexComplex67 HOST_DEVICE inline explicit Complex(const Complex<U> &other)
68 : real_(static_cast<T>(other.real())), imag_(static_cast<T>(other.imag())) {}
69
70 HOST_DEVICE inline explicit operator bool() const { return static_cast<bool>(real_) || static_cast<bool>(imag_); }
71 HOST_DEVICE inline explicit operator signed char() const { return static_cast<signed char>(real_); }
72 HOST_DEVICE inline explicit operator unsigned char() const { return static_cast<unsigned char>(real_); }
73 HOST_DEVICE inline explicit operator double() const { return static_cast<double>(real_); }
74 HOST_DEVICE inline explicit operator float() const { return static_cast<float>(real_); }
int16_tComplex75 HOST_DEVICE inline explicit operator int16_t() const { return static_cast<int16_t>(real_); }
uint16_tComplex76 HOST_DEVICE inline explicit operator uint16_t() const { return static_cast<uint16_t>(real_); }
int32_tComplex77 HOST_DEVICE inline explicit operator int32_t() const { return static_cast<int32_t>(real_); }
uint32_tComplex78 HOST_DEVICE inline explicit operator uint32_t() const { return static_cast<uint32_t>(real_); }
int64_tComplex79 HOST_DEVICE inline explicit operator int64_t() const { return static_cast<int64_t>(real_); }
uint64_tComplex80 HOST_DEVICE inline explicit operator uint64_t() const { return static_cast<uint64_t>(real_); }
81 #if defined(__CUDACC__)
halfComplex82 HOST_DEVICE inline explicit operator half() const { return static_cast<half>(real_); }
83 #else
float16Complex84 inline explicit operator float16() const { return static_cast<float16>(real_); }
85 #endif
86
87 HOST_DEVICE inline Complex<T> &operator=(const T &real) {
88 real_ = real;
89 imag_ = T();
90 return *this;
91 }
92
93 HOST_DEVICE inline Complex<T> &operator+=(const T &real) {
94 real_ += real;
95 return *this;
96 }
97
98 HOST_DEVICE inline Complex<T> &operator-=(const T &real) {
99 real_ -= real;
100 return *this;
101 }
102
103 HOST_DEVICE inline Complex<T> &operator*=(const T &real) {
104 real_ *= real;
105 imag_ *= real;
106 return *this;
107 }
108
109 // Note: check division by zero before use it.
110 HOST_DEVICE inline Complex<T> &operator/=(const T &real) {
111 real_ /= real;
112 imag_ /= real;
113 return *this;
114 }
115
116 template <typename U>
117 HOST_DEVICE inline Complex<T> &operator=(const Complex<U> &z) {
118 real_ = z.real();
119 imag_ = z.imag();
120 return *this;
121 }
122 template <typename U>
123 HOST_DEVICE inline Complex<T> &operator+=(const Complex<U> &z) {
124 real_ += z.real();
125 imag_ += z.imag();
126 return *this;
127 }
128 template <typename U>
129 HOST_DEVICE inline Complex<T> &operator-=(const Complex<U> &z) {
130 real_ -= z.real();
131 imag_ -= z.imag();
132 return *this;
133 }
134 template <typename U>
135 HOST_DEVICE inline Complex<T> &operator*=(const Complex<U> &z);
136
137 // Note: check division by zero before use it.
138 template <typename U>
139 HOST_DEVICE inline Complex<T> &operator/=(const Complex<U> &z);
140
realComplex141 HOST_DEVICE inline constexpr T real() const { return real_; }
imagComplex142 HOST_DEVICE inline constexpr T imag() const { return imag_; }
realComplex143 HOST_DEVICE inline void real(T val) { real_ = val; }
imagComplex144 HOST_DEVICE inline void imag(T val) { imag_ = val; }
145
146 private:
147 T real_;
148 T imag_;
149 };
150
151 template <typename T>
152 template <typename U>
153 HOST_DEVICE inline Complex<T> &Complex<T>::operator*=(const Complex<U> &z) {
154 const T real = real_ * z.real() - imag_ * z.imag();
155 imag_ = real_ * z.imag() + imag_ * z.real();
156 real_ = real;
157 return *this;
158 }
159
160 // Note: check division by zero before use it.
161 template <typename T>
162 template <typename U>
163 HOST_DEVICE inline Complex<T> &Complex<T>::operator/=(const Complex<U> &z) {
164 T a = real_;
165 T b = imag_;
166 U c = z.real();
167 U d = z.imag();
168 auto denominator = c * c + d * d;
169 real_ = (a * c + b * d) / denominator;
170 imag_ = (b * c - a * d) / denominator;
171 return *this;
172 }
173
174 template <typename T>
175 HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const Complex<T> &rhs) {
176 Complex<T> result = lhs;
177 result += rhs;
178 return result;
179 }
180
181 template <typename T>
182 HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const T &rhs) {
183 Complex<T> result = lhs;
184 result += rhs;
185 return result;
186 }
187
188 template <typename T>
189 HOST_DEVICE inline Complex<T> operator+(const T &lhs, const Complex<T> &rhs) {
190 Complex<T> result = rhs;
191 result += lhs;
192 return result;
193 }
194
195 template <typename T>
196 HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const Complex<T> &rhs) {
197 Complex<T> result = lhs;
198 result -= rhs;
199 return result;
200 }
201
202 template <typename T>
203 HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const T &rhs) {
204 Complex<T> result = lhs;
205 result -= rhs;
206 return result;
207 }
208
209 template <typename T>
210 HOST_DEVICE inline Complex<T> operator-(const T &lhs, const Complex<T> &rhs) {
211 Complex<T> result(lhs, -rhs.imag());
212 result -= rhs.real();
213 return result;
214 }
215
216 template <typename T>
217 HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const Complex<T> &rhs) {
218 Complex<T> result = lhs;
219 result *= rhs;
220 return result;
221 }
222
223 template <typename T>
224 HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const T &rhs) {
225 Complex<T> result = lhs;
226 result *= rhs;
227 return result;
228 }
229
230 template <typename T>
231 HOST_DEVICE inline Complex<T> operator*(const T &lhs, const Complex<T> &rhs) {
232 Complex<T> result = rhs;
233 result *= lhs;
234 return result;
235 }
236
237 // Note: check division by zero before use it.
238 template <typename T>
239 HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const Complex<T> &rhs) {
240 Complex<T> result = lhs;
241 result /= rhs;
242 return result;
243 }
244
245 // Note: check division by zero before use it.
246 template <typename T>
247 HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const T &rhs) {
248 Complex<T> result = lhs;
249 result /= rhs;
250 return result;
251 }
252
253 // Note: check division by zero before use it.
254 template <typename T>
255 HOST_DEVICE inline Complex<T> operator/(const T &lhs, const Complex<T> &rhs) {
256 Complex<T> result = lhs;
257 result /= rhs;
258 return result;
259 }
260
261 template <typename T>
262 HOST_DEVICE inline Complex<T> operator+(const Complex<T> &z) {
263 return z;
264 }
265
266 template <typename T>
267 HOST_DEVICE inline Complex<T> operator-(const Complex<T> &z) {
268 return Complex<T>(-z.real(), -z.imag());
269 }
270
271 template <typename T>
272 HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const Complex<T> &rhs) {
273 return lhs.real() == rhs.real() && lhs.imag() == rhs.imag();
274 }
275
276 template <typename T>
277 HOST_DEVICE inline bool operator==(const T &lhs, const Complex<T> &rhs) {
278 return lhs == rhs.real() && rhs.imag() == 0;
279 }
280
281 template <typename T>
282 HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const T &rhs) {
283 return lhs.real() == rhs && lhs.imag() == 0;
284 }
285
286 template <typename T>
287 HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const Complex<T> &rhs) {
288 return !(lhs == rhs);
289 }
290
291 template <typename T>
292 HOST_DEVICE inline bool operator!=(const T &lhs, const Complex<T> &rhs) {
293 return !(lhs == rhs);
294 }
295
296 template <typename T>
297 HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const T &rhs) {
298 return !(lhs == rhs);
299 }
300
301 template <typename T>
302 inline std::ostream &operator<<(std::ostream &os, const Complex<T> &v) {
303 return (os << std::noshowpos << v.real() << std::showpos << v.imag() << 'j');
304 }
305
306 template <typename T>
tan(const Complex<T> & z)307 HOST_DEVICE inline Complex<T> tan(const Complex<T> &z) {
308 Complex<T> result;
309 #if defined(__CUDACC__)
310 auto thrust_result = thrust::tan(thrust::complex<T>(z));
311 result.real(thrust_result.real());
312 result.imag(thrust_result.imag());
313 #else
314 result(std::tan(std::complex<T>(z)));
315 #endif
316 return result;
317 }
318
319 template <typename T>
sin(const Complex<T> & z)320 HOST_DEVICE inline Complex<T> sin(const Complex<T> &z) {
321 Complex<T> result;
322 #if defined(__CUDACC__)
323 auto thrust_result = thrust::sin(thrust::complex<T>(z));
324 result.real(thrust_result.real());
325 result.imag(thrust_result.imag());
326 #else
327 result(std::sin(std::complex<T>(z)));
328 #endif
329 return result;
330 }
331
332 template <typename T>
cos(const Complex<T> & z)333 HOST_DEVICE inline Complex<T> cos(const Complex<T> &z) {
334 Complex<T> result;
335 #if defined(__CUDACC__)
336 auto thrust_result = thrust::cos(thrust::complex<T>(z));
337 result.real(thrust_result.real());
338 result.imag(thrust_result.imag());
339 #else
340 result(std::cos(std::complex<T>(z)));
341 #endif
342 return result;
343 }
344
345 template <typename T>
acos(const Complex<T> & z)346 HOST_DEVICE inline Complex<T> acos(const Complex<T> &z) {
347 Complex<T> result;
348 #if defined(__CUDACC__)
349 auto thrust_result = thrust::acos(thrust::complex<T>(z));
350 result.real(thrust_result.real());
351 result.imag(thrust_result.imag());
352 #else
353 result(std::acos(std::complex<T>(z)));
354 #endif
355 return result;
356 }
357
358 template <typename T>
acosh(const Complex<T> & z)359 HOST_DEVICE inline Complex<T> acosh(const Complex<T> &z) {
360 Complex<T> result;
361 #if defined(__CUDACC__)
362 auto thrust_result = thrust::acosh(thrust::complex<T>(z));
363 result.real(thrust_result.real());
364 result.imag(thrust_result.imag());
365 #else
366 result(std::acosh(std::complex<T>(z)));
367 #endif
368 return result;
369 }
370
371 template <typename T>
asin(const Complex<T> & z)372 HOST_DEVICE inline Complex<T> asin(const Complex<T> &z) {
373 Complex<T> result;
374 #if defined(__CUDACC__)
375 auto thrust_result = thrust::asin(thrust::complex<T>(z));
376 result.real(thrust_result.real());
377 result.imag(thrust_result.imag());
378 #else
379 result(std::asin(std::complex<T>(z)));
380 #endif
381 return result;
382 }
383
384 template <typename T>
asinh(const Complex<T> & z)385 HOST_DEVICE inline Complex<T> asinh(const Complex<T> &z) {
386 Complex<T> result;
387 #if defined(__CUDACC__)
388 auto thrust_result = thrust::asinh(thrust::complex<T>(z));
389 result.real(thrust_result.real());
390 result.imag(thrust_result.imag());
391 #else
392 result(std::asinh(std::complex<T>(z)));
393 #endif
394 return result;
395 }
396
397 template <typename T>
atan(const Complex<T> & z)398 HOST_DEVICE inline Complex<T> atan(const Complex<T> &z) {
399 Complex<T> result;
400 #if defined(__CUDACC__)
401 auto thrust_result = thrust::atan(thrust::complex<T>(z));
402 result.real(thrust_result.real());
403 result.imag(thrust_result.imag());
404 #else
405 result(std::tan(std::complex<T>(z)));
406 #endif
407 return result;
408 }
409
410 template <typename T>
atanh(const Complex<T> & z)411 HOST_DEVICE inline Complex<T> atanh(const Complex<T> &z) {
412 Complex<T> result;
413 #if defined(__CUDACC__)
414 auto thrust_result = thrust::atanh(thrust::complex<T>(z));
415 result.real(thrust_result.real());
416 result.imag(thrust_result.imag());
417 #else
418 result(std::tan(std::complex<T>(z)));
419 #endif
420 return result;
421 }
422
423 template <typename T>
conj(const Complex<T> & z)424 HOST_DEVICE inline Complex<T> conj(const Complex<T> &z) {
425 Complex<T> result;
426 #if defined(__CUDACC__)
427 auto thrust_result = thrust::conj(thrust::complex<T>(z));
428 result.real(thrust_result.real());
429 result.imag(thrust_result.imag());
430 #else
431 result(std::conj(std::complex<T>(z)));
432 #endif
433 return result;
434 }
435
436 template <typename T>
sqrt(const Complex<T> & z)437 HOST_DEVICE inline Complex<T> sqrt(const Complex<T> &z) {
438 Complex<T> result;
439 #if defined(__CUDACC__)
440 auto thrust_result = thrust::sqrt(thrust::complex<T>(z));
441 result.real(thrust_result.real());
442 result.imag(thrust_result.imag());
443 #else
444 result(std::sqrt(std::complex<T>(z)));
445 #endif
446 return result;
447 }
448
449 template <typename T>
tanh(const Complex<T> & z)450 HOST_DEVICE inline Complex<T> tanh(const Complex<T> &z) {
451 Complex<T> result;
452 #if defined(__CUDACC__)
453 auto thrust_result = thrust::tanh(thrust::complex<T>(z));
454 result.real(thrust_result.real());
455 result.imag(thrust_result.imag());
456 #else
457 result(std::tanh(std::complex<T>(z)));
458 #endif
459 return result;
460 }
461
462 template <typename T>
abs(const Complex<T> & z)463 HOST_DEVICE inline T abs(const Complex<T> &z) {
464 #if defined(__CUDACC__)
465 return thrust::abs(thrust::complex<T>(z));
466 #else
467 return std::abs(std::complex<T>(z));
468 #endif
469 }
470
471 template <typename T>
log(const Complex<T> & z)472 HOST_DEVICE inline Complex<T> log(const Complex<T> &z) {
473 Complex<T> result;
474 #if defined(__CUDACC__)
475 auto thrust_result = thrust::log(thrust::complex<T>(z));
476 result.real(thrust_result.real());
477 result.imag(thrust_result.imag());
478 #else
479 result(std::log(std::complex<T>(z)));
480 #endif
481 return result;
482 }
483
484 template <typename T>
exp(const Complex<T> & z)485 HOST_DEVICE inline Complex<T> exp(const Complex<T> &z) {
486 Complex<T> result;
487 #if defined(__CUDACC__)
488 auto thrust_result = thrust::exp(thrust::complex<T>(z));
489 result.real(thrust_result.real());
490 result.imag(thrust_result.imag());
491 #else
492 result(std::exp(std::complex<T>(z)));
493 #endif
494 return result;
495 }
496
497 template <typename T>
cosh(const Complex<T> & z)498 HOST_DEVICE inline Complex<T> cosh(const Complex<T> &z) {
499 Complex<T> result;
500 #if defined(__CUDACC__)
501 auto thrust_result = thrust::cosh(thrust::complex<T>(z));
502 result.real(thrust_result.real());
503 result.imag(thrust_result.imag());
504 #else
505 result(std::cosh(std::complex<T>(z)));
506 #endif
507 return result;
508 }
509
510 template <typename T>
sinh(const Complex<T> & z)511 HOST_DEVICE inline Complex<T> sinh(const Complex<T> &z) {
512 Complex<T> result;
513 #if defined(__CUDACC__)
514 auto thrust_result = thrust::sinh(thrust::complex<T>(z));
515 result.real(thrust_result.real());
516 result.imag(thrust_result.imag());
517 #else
518 result(std::sinh(std::complex<T>(z)));
519 #endif
520 return result;
521 }
522
523 template <typename T>
isfinite(const Complex<T> & z)524 HOST_DEVICE inline bool isfinite(const Complex<T> &z) {
525 return std::isfinite(z.real()) || std::isfinite(z.imag());
526 }
527 } // namespace utils
528 } // namespace mindspore
529
530 template <typename T>
531 using Complex = mindspore::utils::Complex<T>;
532 namespace std {
533 template <typename T>
534 class numeric_limits<mindspore::utils::Complex<T>> : public numeric_limits<T> {};
535 } // namespace std
536 #endif // MINDSPORE_CCSRC_UTILS_COPLEX_H_
537