• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_UTILS_COPLEX_H_
17 #define MINDSPORE_CCSRC_UTILS_COPLEX_H_
18 
#include <cmath>
#include <complex>
#include <limits>
#include <ostream>

#ifdef ENABLE_GPU
#include <thrust/complex.h>
#include <cublas_v2.h>
#endif
#include "base/float16.h"
// HOST_DEVICE marks a function as callable from both host and device code when the translation
// unit is compiled by nvcc (__CUDACC__ defined); for an ordinary host compiler it expands to nothing.
#ifdef __CUDACC__
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif
31 
32 namespace mindspore {
33 namespace utils {
34 // Implement Complex for mindspore, inspired by std::complex.
35 constexpr int T_SIZE = 2;
36 template <typename T>
37 struct alignas(sizeof(T) * T_SIZE) Complex {
38   Complex() = default;
39   ~Complex() = default;
40 
41   Complex(const Complex<T> &other) noexcept = default;
42   Complex(Complex<T> &&other) noexcept = default;
43 
44   Complex &operator=(const Complex<T> &other) noexcept = default;
45   Complex &operator=(Complex<T> &&other) noexcept = default;
46 
real_Complex47   HOST_DEVICE inline constexpr Complex(const T &real, const T &imag = T()) : real_(real), imag_(imag) {}
48 
49   template <typename U>
ComplexComplex50   inline explicit constexpr Complex(const std::complex<U> &other) : Complex(other.real(), other.imag()) {}
51   template <typename U>
52   inline explicit constexpr operator std::complex<U>() const {
53     return std::complex<U>(std::complex<T>(real(), imag()));
54   }
55 
ComplexComplex56   HOST_DEVICE inline explicit constexpr Complex(const float16 &real) : real_(static_cast<T>(real)), imag_(T()) {}
57 #if defined(__CUDACC__)
58   template <typename U>
ComplexComplex59   HOST_DEVICE inline explicit Complex(const thrust::complex<U> &other) : real_(other.real()), imag_(other.imag()) {}
60 
61   template <typename U>
62   HOST_DEVICE inline HOST_DEVICE explicit operator thrust::complex<U>() const {
63     return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
64   }
65 #endif
66   template <typename U = T>
ComplexComplex67   HOST_DEVICE inline explicit Complex(const Complex<U> &other)
68       : real_(static_cast<T>(other.real())), imag_(static_cast<T>(other.imag())) {}
69 
70   HOST_DEVICE inline explicit operator bool() const { return static_cast<bool>(real_) || static_cast<bool>(imag_); }
71   HOST_DEVICE inline explicit operator signed char() const { return static_cast<signed char>(real_); }
72   HOST_DEVICE inline explicit operator unsigned char() const { return static_cast<unsigned char>(real_); }
73   HOST_DEVICE inline explicit operator double() const { return static_cast<double>(real_); }
74   HOST_DEVICE inline explicit operator float() const { return static_cast<float>(real_); }
int16_tComplex75   HOST_DEVICE inline explicit operator int16_t() const { return static_cast<int16_t>(real_); }
uint16_tComplex76   HOST_DEVICE inline explicit operator uint16_t() const { return static_cast<uint16_t>(real_); }
int32_tComplex77   HOST_DEVICE inline explicit operator int32_t() const { return static_cast<int32_t>(real_); }
uint32_tComplex78   HOST_DEVICE inline explicit operator uint32_t() const { return static_cast<uint32_t>(real_); }
int64_tComplex79   HOST_DEVICE inline explicit operator int64_t() const { return static_cast<int64_t>(real_); }
uint64_tComplex80   HOST_DEVICE inline explicit operator uint64_t() const { return static_cast<uint64_t>(real_); }
81 #if defined(__CUDACC__)
halfComplex82   HOST_DEVICE inline explicit operator half() const { return static_cast<half>(real_); }
83 #else
float16Complex84   inline explicit operator float16() const { return static_cast<float16>(real_); }
85 #endif
86 
87   HOST_DEVICE inline Complex<T> &operator=(const T &real) {
88     real_ = real;
89     imag_ = T();
90     return *this;
91   }
92 
93   HOST_DEVICE inline Complex<T> &operator+=(const T &real) {
94     real_ += real;
95     return *this;
96   }
97 
98   HOST_DEVICE inline Complex<T> &operator-=(const T &real) {
99     real_ -= real;
100     return *this;
101   }
102 
103   HOST_DEVICE inline Complex<T> &operator*=(const T &real) {
104     real_ *= real;
105     imag_ *= real;
106     return *this;
107   }
108 
109   // Note: check division by zero before use it.
110   HOST_DEVICE inline Complex<T> &operator/=(const T &real) {
111     real_ /= real;
112     imag_ /= real;
113     return *this;
114   }
115 
116   template <typename U>
117   HOST_DEVICE inline Complex<T> &operator=(const Complex<U> &z) {
118     real_ = z.real();
119     imag_ = z.imag();
120     return *this;
121   }
122   template <typename U>
123   HOST_DEVICE inline Complex<T> &operator+=(const Complex<U> &z) {
124     real_ += z.real();
125     imag_ += z.imag();
126     return *this;
127   }
128   template <typename U>
129   HOST_DEVICE inline Complex<T> &operator-=(const Complex<U> &z) {
130     real_ -= z.real();
131     imag_ -= z.imag();
132     return *this;
133   }
134   template <typename U>
135   HOST_DEVICE inline Complex<T> &operator*=(const Complex<U> &z);
136 
137   // Note: check division by zero before use it.
138   template <typename U>
139   HOST_DEVICE inline Complex<T> &operator/=(const Complex<U> &z);
140 
realComplex141   HOST_DEVICE inline constexpr T real() const { return real_; }
imagComplex142   HOST_DEVICE inline constexpr T imag() const { return imag_; }
realComplex143   HOST_DEVICE inline void real(T val) { real_ = val; }
imagComplex144   HOST_DEVICE inline void imag(T val) { imag_ = val; }
145 
146  private:
147   T real_;
148   T imag_;
149 };
150 
151 template <typename T>
152 template <typename U>
153 HOST_DEVICE inline Complex<T> &Complex<T>::operator*=(const Complex<U> &z) {
154   const T real = real_ * z.real() - imag_ * z.imag();
155   imag_ = real_ * z.imag() + imag_ * z.real();
156   real_ = real;
157   return *this;
158 }
159 
160 // Note: check division by zero before use it.
161 template <typename T>
162 template <typename U>
163 HOST_DEVICE inline Complex<T> &Complex<T>::operator/=(const Complex<U> &z) {
164   T a = real_;
165   T b = imag_;
166   U c = z.real();
167   U d = z.imag();
168   auto denominator = c * c + d * d;
169   real_ = (a * c + b * d) / denominator;
170   imag_ = (b * c - a * d) / denominator;
171   return *this;
172 }
173 
174 template <typename T>
175 HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const Complex<T> &rhs) {
176   Complex<T> result = lhs;
177   result += rhs;
178   return result;
179 }
180 
181 template <typename T>
182 HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const T &rhs) {
183   Complex<T> result = lhs;
184   result += rhs;
185   return result;
186 }
187 
188 template <typename T>
189 HOST_DEVICE inline Complex<T> operator+(const T &lhs, const Complex<T> &rhs) {
190   Complex<T> result = rhs;
191   result += lhs;
192   return result;
193 }
194 
195 template <typename T>
196 HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const Complex<T> &rhs) {
197   Complex<T> result = lhs;
198   result -= rhs;
199   return result;
200 }
201 
202 template <typename T>
203 HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const T &rhs) {
204   Complex<T> result = lhs;
205   result -= rhs;
206   return result;
207 }
208 
209 template <typename T>
210 HOST_DEVICE inline Complex<T> operator-(const T &lhs, const Complex<T> &rhs) {
211   Complex<T> result(lhs, -rhs.imag());
212   result -= rhs.real();
213   return result;
214 }
215 
216 template <typename T>
217 HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const Complex<T> &rhs) {
218   Complex<T> result = lhs;
219   result *= rhs;
220   return result;
221 }
222 
223 template <typename T>
224 HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const T &rhs) {
225   Complex<T> result = lhs;
226   result *= rhs;
227   return result;
228 }
229 
230 template <typename T>
231 HOST_DEVICE inline Complex<T> operator*(const T &lhs, const Complex<T> &rhs) {
232   Complex<T> result = rhs;
233   result *= lhs;
234   return result;
235 }
236 
237 // Note: check division by zero before use it.
238 template <typename T>
239 HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const Complex<T> &rhs) {
240   Complex<T> result = lhs;
241   result /= rhs;
242   return result;
243 }
244 
245 // Note: check division by zero before use it.
246 template <typename T>
247 HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const T &rhs) {
248   Complex<T> result = lhs;
249   result /= rhs;
250   return result;
251 }
252 
253 // Note: check division by zero before use it.
254 template <typename T>
255 HOST_DEVICE inline Complex<T> operator/(const T &lhs, const Complex<T> &rhs) {
256   Complex<T> result = lhs;
257   result /= rhs;
258   return result;
259 }
260 
261 template <typename T>
262 HOST_DEVICE inline Complex<T> operator+(const Complex<T> &z) {
263   return z;
264 }
265 
266 template <typename T>
267 HOST_DEVICE inline Complex<T> operator-(const Complex<T> &z) {
268   return Complex<T>(-z.real(), -z.imag());
269 }
270 
271 template <typename T>
272 HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const Complex<T> &rhs) {
273   return lhs.real() == rhs.real() && lhs.imag() == rhs.imag();
274 }
275 
276 template <typename T>
277 HOST_DEVICE inline bool operator==(const T &lhs, const Complex<T> &rhs) {
278   return lhs == rhs.real() && rhs.imag() == 0;
279 }
280 
281 template <typename T>
282 HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const T &rhs) {
283   return lhs.real() == rhs && lhs.imag() == 0;
284 }
285 
286 template <typename T>
287 HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const Complex<T> &rhs) {
288   return !(lhs == rhs);
289 }
290 
291 template <typename T>
292 HOST_DEVICE inline bool operator!=(const T &lhs, const Complex<T> &rhs) {
293   return !(lhs == rhs);
294 }
295 
296 template <typename T>
297 HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const T &rhs) {
298   return !(lhs == rhs);
299 }
300 
301 template <typename T>
302 inline std::ostream &operator<<(std::ostream &os, const Complex<T> &v) {
303   return (os << std::noshowpos << v.real() << std::showpos << v.imag() << 'j');
304 }
305 
306 template <typename T>
tan(const Complex<T> & z)307 HOST_DEVICE inline Complex<T> tan(const Complex<T> &z) {
308   Complex<T> result;
309 #if defined(__CUDACC__)
310   auto thrust_result = thrust::tan(thrust::complex<T>(z));
311   result.real(thrust_result.real());
312   result.imag(thrust_result.imag());
313 #else
314   result(std::tan(std::complex<T>(z)));
315 #endif
316   return result;
317 }
318 
319 template <typename T>
sin(const Complex<T> & z)320 HOST_DEVICE inline Complex<T> sin(const Complex<T> &z) {
321   Complex<T> result;
322 #if defined(__CUDACC__)
323   auto thrust_result = thrust::sin(thrust::complex<T>(z));
324   result.real(thrust_result.real());
325   result.imag(thrust_result.imag());
326 #else
327   result(std::sin(std::complex<T>(z)));
328 #endif
329   return result;
330 }
331 
332 template <typename T>
cos(const Complex<T> & z)333 HOST_DEVICE inline Complex<T> cos(const Complex<T> &z) {
334   Complex<T> result;
335 #if defined(__CUDACC__)
336   auto thrust_result = thrust::cos(thrust::complex<T>(z));
337   result.real(thrust_result.real());
338   result.imag(thrust_result.imag());
339 #else
340   result(std::cos(std::complex<T>(z)));
341 #endif
342   return result;
343 }
344 
345 template <typename T>
acos(const Complex<T> & z)346 HOST_DEVICE inline Complex<T> acos(const Complex<T> &z) {
347   Complex<T> result;
348 #if defined(__CUDACC__)
349   auto thrust_result = thrust::acos(thrust::complex<T>(z));
350   result.real(thrust_result.real());
351   result.imag(thrust_result.imag());
352 #else
353   result(std::acos(std::complex<T>(z)));
354 #endif
355   return result;
356 }
357 
358 template <typename T>
acosh(const Complex<T> & z)359 HOST_DEVICE inline Complex<T> acosh(const Complex<T> &z) {
360   Complex<T> result;
361 #if defined(__CUDACC__)
362   auto thrust_result = thrust::acosh(thrust::complex<T>(z));
363   result.real(thrust_result.real());
364   result.imag(thrust_result.imag());
365 #else
366   result(std::acosh(std::complex<T>(z)));
367 #endif
368   return result;
369 }
370 
371 template <typename T>
asin(const Complex<T> & z)372 HOST_DEVICE inline Complex<T> asin(const Complex<T> &z) {
373   Complex<T> result;
374 #if defined(__CUDACC__)
375   auto thrust_result = thrust::asin(thrust::complex<T>(z));
376   result.real(thrust_result.real());
377   result.imag(thrust_result.imag());
378 #else
379   result(std::asin(std::complex<T>(z)));
380 #endif
381   return result;
382 }
383 
384 template <typename T>
asinh(const Complex<T> & z)385 HOST_DEVICE inline Complex<T> asinh(const Complex<T> &z) {
386   Complex<T> result;
387 #if defined(__CUDACC__)
388   auto thrust_result = thrust::asinh(thrust::complex<T>(z));
389   result.real(thrust_result.real());
390   result.imag(thrust_result.imag());
391 #else
392   result(std::asinh(std::complex<T>(z)));
393 #endif
394   return result;
395 }
396 
397 template <typename T>
atan(const Complex<T> & z)398 HOST_DEVICE inline Complex<T> atan(const Complex<T> &z) {
399   Complex<T> result;
400 #if defined(__CUDACC__)
401   auto thrust_result = thrust::atan(thrust::complex<T>(z));
402   result.real(thrust_result.real());
403   result.imag(thrust_result.imag());
404 #else
405   result(std::tan(std::complex<T>(z)));
406 #endif
407   return result;
408 }
409 
410 template <typename T>
atanh(const Complex<T> & z)411 HOST_DEVICE inline Complex<T> atanh(const Complex<T> &z) {
412   Complex<T> result;
413 #if defined(__CUDACC__)
414   auto thrust_result = thrust::atanh(thrust::complex<T>(z));
415   result.real(thrust_result.real());
416   result.imag(thrust_result.imag());
417 #else
418   result(std::tan(std::complex<T>(z)));
419 #endif
420   return result;
421 }
422 
423 template <typename T>
conj(const Complex<T> & z)424 HOST_DEVICE inline Complex<T> conj(const Complex<T> &z) {
425   Complex<T> result;
426 #if defined(__CUDACC__)
427   auto thrust_result = thrust::conj(thrust::complex<T>(z));
428   result.real(thrust_result.real());
429   result.imag(thrust_result.imag());
430 #else
431   result(std::conj(std::complex<T>(z)));
432 #endif
433   return result;
434 }
435 
436 template <typename T>
sqrt(const Complex<T> & z)437 HOST_DEVICE inline Complex<T> sqrt(const Complex<T> &z) {
438   Complex<T> result;
439 #if defined(__CUDACC__)
440   auto thrust_result = thrust::sqrt(thrust::complex<T>(z));
441   result.real(thrust_result.real());
442   result.imag(thrust_result.imag());
443 #else
444   result(std::sqrt(std::complex<T>(z)));
445 #endif
446   return result;
447 }
448 
449 template <typename T>
tanh(const Complex<T> & z)450 HOST_DEVICE inline Complex<T> tanh(const Complex<T> &z) {
451   Complex<T> result;
452 #if defined(__CUDACC__)
453   auto thrust_result = thrust::tanh(thrust::complex<T>(z));
454   result.real(thrust_result.real());
455   result.imag(thrust_result.imag());
456 #else
457   result(std::tanh(std::complex<T>(z)));
458 #endif
459   return result;
460 }
461 
462 template <typename T>
abs(const Complex<T> & z)463 HOST_DEVICE inline T abs(const Complex<T> &z) {
464 #if defined(__CUDACC__)
465   return thrust::abs(thrust::complex<T>(z));
466 #else
467   return std::abs(std::complex<T>(z));
468 #endif
469 }
470 
471 template <typename T>
log(const Complex<T> & z)472 HOST_DEVICE inline Complex<T> log(const Complex<T> &z) {
473   Complex<T> result;
474 #if defined(__CUDACC__)
475   auto thrust_result = thrust::log(thrust::complex<T>(z));
476   result.real(thrust_result.real());
477   result.imag(thrust_result.imag());
478 #else
479   result(std::log(std::complex<T>(z)));
480 #endif
481   return result;
482 }
483 
484 template <typename T>
exp(const Complex<T> & z)485 HOST_DEVICE inline Complex<T> exp(const Complex<T> &z) {
486   Complex<T> result;
487 #if defined(__CUDACC__)
488   auto thrust_result = thrust::exp(thrust::complex<T>(z));
489   result.real(thrust_result.real());
490   result.imag(thrust_result.imag());
491 #else
492   result(std::exp(std::complex<T>(z)));
493 #endif
494   return result;
495 }
496 
497 template <typename T>
cosh(const Complex<T> & z)498 HOST_DEVICE inline Complex<T> cosh(const Complex<T> &z) {
499   Complex<T> result;
500 #if defined(__CUDACC__)
501   auto thrust_result = thrust::cosh(thrust::complex<T>(z));
502   result.real(thrust_result.real());
503   result.imag(thrust_result.imag());
504 #else
505   result(std::cosh(std::complex<T>(z)));
506 #endif
507   return result;
508 }
509 
510 template <typename T>
sinh(const Complex<T> & z)511 HOST_DEVICE inline Complex<T> sinh(const Complex<T> &z) {
512   Complex<T> result;
513 #if defined(__CUDACC__)
514   auto thrust_result = thrust::sinh(thrust::complex<T>(z));
515   result.real(thrust_result.real());
516   result.imag(thrust_result.imag());
517 #else
518   result(std::sinh(std::complex<T>(z)));
519 #endif
520   return result;
521 }
522 
523 template <typename T>
isfinite(const Complex<T> & z)524 HOST_DEVICE inline bool isfinite(const Complex<T> &z) {
525   return std::isfinite(z.real()) || std::isfinite(z.imag());
526 }
527 }  // namespace utils
528 }  // namespace mindspore
529 
530 template <typename T>
531 using Complex = mindspore::utils::Complex<T>;
532 namespace std {
533 template <typename T>
534 class numeric_limits<mindspore::utils::Complex<T>> : public numeric_limits<T> {};
535 }  // namespace std
536 #endif  // MINDSPORE_CCSRC_UTILS_COPLEX_H_
537