• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_UTILS_COPLEX_H_
17 #define MINDSPORE_CCSRC_UTILS_COPLEX_H_
18 
19 #include <complex>
20 #include <limits>
21 #ifdef ENABLE_GPU
22 #include <thrust/complex.h>
23 #include <cublas_v2.h>
24 #endif
25 #include "base/float16.h"
// HOST_DEVICE annotates functions that must be callable from both host and
// device code when this header is compiled by nvcc (__CUDACC__ is defined);
// in every other compiler it expands to nothing.
#ifdef __CUDACC__
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif
31 
32 namespace mindspore {
33 namespace utils {
34 // Implement Complex for mindspore, inspired by std::complex.
35 template <typename T>
36 struct alignas(sizeof(T) * 2) Complex {
37   Complex() = default;
38   ~Complex() = default;
39 
40   Complex(const Complex<T> &other) noexcept = default;
41   Complex(Complex<T> &&other) noexcept = default;
42 
43   Complex &operator=(const Complex<T> &other) noexcept = default;
44   Complex &operator=(Complex<T> &&other) noexcept = default;
45 
real_Complex46   HOST_DEVICE inline constexpr Complex(const T &real, const T &imag = T()) : real_(real), imag_(imag) {}
47 
48   template <typename U>
ComplexComplex49   inline explicit constexpr Complex(const std::complex<U> &other) : Complex(other.real(), other.imag()) {}
50   template <typename U>
51   inline explicit constexpr operator std::complex<U>() const {
52     return std::complex<U>(std::complex<T>(real(), imag()));
53   }
54 
ComplexComplex55   HOST_DEVICE inline explicit constexpr Complex(const float16 &real) : real_(static_cast<T>(real)), imag_(T()) {}
56 #if defined(__CUDACC__)
57   template <typename U>
ComplexComplex58   HOST_DEVICE inline explicit Complex(const thrust::complex<U> &other) : real_(other.real()), imag_(other.imag()) {}
59 
60   template <typename U>
61   HOST_DEVICE inline HOST_DEVICE explicit operator thrust::complex<U>() const {
62     return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
63   }
64 #endif
65   template <typename U = T>
ComplexComplex66   HOST_DEVICE inline explicit Complex(const Complex<U> &other)
67       : real_(static_cast<T>(other.real())), imag_(static_cast<T>(other.imag())) {}
68 
69   HOST_DEVICE inline explicit operator bool() const { return static_cast<bool>(real_) || static_cast<bool>(imag_); }
70   HOST_DEVICE inline explicit operator signed char() const { return static_cast<signed char>(real_); }
71   HOST_DEVICE inline explicit operator unsigned char() const { return static_cast<unsigned char>(real_); }
72   HOST_DEVICE inline explicit operator double() const { return static_cast<double>(real_); }
73   HOST_DEVICE inline explicit operator float() const { return static_cast<float>(real_); }
int16_tComplex74   HOST_DEVICE inline explicit operator int16_t() const { return static_cast<int16_t>(real_); }
uint16_tComplex75   HOST_DEVICE inline explicit operator uint16_t() const { return static_cast<uint16_t>(real_); }
int32_tComplex76   HOST_DEVICE inline explicit operator int32_t() const { return static_cast<int32_t>(real_); }
uint32_tComplex77   HOST_DEVICE inline explicit operator uint32_t() const { return static_cast<uint32_t>(real_); }
int64_tComplex78   HOST_DEVICE inline explicit operator int64_t() const { return static_cast<int64_t>(real_); }
uint64_tComplex79   HOST_DEVICE inline explicit operator uint64_t() const { return static_cast<uint64_t>(real_); }
80 #if defined(__CUDACC__)
halfComplex81   HOST_DEVICE inline explicit operator half() const { return static_cast<half>(real_); }
82 #else
float16Complex83   inline explicit operator float16() const { return static_cast<float16>(real_); }
84 #endif
85 
86   HOST_DEVICE inline Complex<T> &operator=(const T &real) {
87     real_ = real;
88     imag_ = T();
89     return *this;
90   }
91 
92   HOST_DEVICE inline Complex<T> &operator+=(const T &real) {
93     real_ += real;
94     return *this;
95   }
96 
97   HOST_DEVICE inline Complex<T> &operator-=(const T &real) {
98     real_ -= real;
99     return *this;
100   }
101 
102   HOST_DEVICE inline Complex<T> &operator*=(const T &real) {
103     real_ *= real;
104     imag_ *= real;
105     return *this;
106   }
107 
108   // Note: check division by zero before use it.
109   HOST_DEVICE inline Complex<T> &operator/=(const T &real) {
110     real_ /= real;
111     imag_ /= real;
112     return *this;
113   }
114 
115   template <typename U>
116   HOST_DEVICE inline Complex<T> &operator=(const Complex<U> &z) {
117     real_ = z.real();
118     imag_ = z.imag();
119     return *this;
120   }
121   template <typename U>
122   HOST_DEVICE inline Complex<T> &operator+=(const Complex<U> &z) {
123     real_ += z.real();
124     imag_ += z.imag();
125     return *this;
126   }
127   template <typename U>
128   HOST_DEVICE inline Complex<T> &operator-=(const Complex<U> &z) {
129     real_ -= z.real();
130     imag_ -= z.imag();
131     return *this;
132   }
133   template <typename U>
134   HOST_DEVICE inline Complex<T> &operator*=(const Complex<U> &z);
135 
136   // Note: check division by zero before use it.
137   template <typename U>
138   HOST_DEVICE inline Complex<T> &operator/=(const Complex<U> &z);
139 
realComplex140   HOST_DEVICE inline constexpr T real() const { return real_; }
imagComplex141   HOST_DEVICE inline constexpr T imag() const { return imag_; }
realComplex142   HOST_DEVICE inline void real(T val) { real_ = val; }
imagComplex143   HOST_DEVICE inline void imag(T val) { imag_ = val; }
144 
145  private:
146   T real_;
147   T imag_;
148 };
149 
150 template <typename T>
151 template <typename U>
152 HOST_DEVICE inline Complex<T> &Complex<T>::operator*=(const Complex<U> &z) {
153   const T real = real_ * z.real() - imag_ * z.imag();
154   imag_ = real_ * z.imag() + imag_ * z.real();
155   real_ = real;
156   return *this;
157 }
158 
159 // Note: check division by zero before use it.
160 template <typename T>
161 template <typename U>
162 HOST_DEVICE inline Complex<T> &Complex<T>::operator/=(const Complex<U> &z) {
163   T a = real_;
164   T b = imag_;
165   U c = z.real();
166   U d = z.imag();
167   auto denominator = c * c + d * d;
168   real_ = (a * c + b * d) / denominator;
169   imag_ = (b * c - a * d) / denominator;
170   return *this;
171 }
172 
173 template <typename T>
174 HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const Complex<T> &rhs) {
175   Complex<T> result = lhs;
176   result += rhs;
177   return result;
178 }
179 
180 template <typename T>
181 HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const T &rhs) {
182   Complex<T> result = lhs;
183   result += rhs;
184   return result;
185 }
186 
187 template <typename T>
188 HOST_DEVICE inline Complex<T> operator+(const T &lhs, const Complex<T> &rhs) {
189   Complex<T> result = rhs;
190   result += lhs;
191   return result;
192 }
193 
194 template <typename T>
195 HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const Complex<T> &rhs) {
196   Complex<T> result = lhs;
197   result -= rhs;
198   return result;
199 }
200 
201 template <typename T>
202 HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const T &rhs) {
203   Complex<T> result = lhs;
204   result -= rhs;
205   return result;
206 }
207 
208 template <typename T>
209 HOST_DEVICE inline Complex<T> operator-(const T &lhs, const Complex<T> &rhs) {
210   Complex<T> result(lhs, -rhs.imag());
211   result -= rhs.real();
212   return result;
213 }
214 
215 template <typename T>
216 HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const Complex<T> &rhs) {
217   Complex<T> result = lhs;
218   result *= rhs;
219   return result;
220 }
221 
222 template <typename T>
223 HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const T &rhs) {
224   Complex<T> result = lhs;
225   result *= rhs;
226   return result;
227 }
228 
229 template <typename T>
230 HOST_DEVICE inline Complex<T> operator*(const T &lhs, const Complex<T> &rhs) {
231   Complex<T> result = rhs;
232   result *= lhs;
233   return result;
234 }
235 
236 // Note: check division by zero before use it.
237 template <typename T>
238 HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const Complex<T> &rhs) {
239   Complex<T> result = lhs;
240   result /= rhs;
241   return result;
242 }
243 
244 // Note: check division by zero before use it.
245 template <typename T>
246 HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const T &rhs) {
247   Complex<T> result = lhs;
248   result /= rhs;
249   return result;
250 }
251 
252 // Note: check division by zero before use it.
253 template <typename T>
254 HOST_DEVICE inline Complex<T> operator/(const T &lhs, const Complex<T> &rhs) {
255   Complex<T> result = lhs;
256   result /= rhs;
257   return result;
258 }
259 
260 template <typename T>
261 HOST_DEVICE inline Complex<T> operator+(const Complex<T> &z) {
262   return z;
263 }
264 
265 template <typename T>
266 HOST_DEVICE inline Complex<T> operator-(const Complex<T> &z) {
267   return Complex<T>(-z.real(), -z.imag());
268 }
269 
270 template <typename T>
271 HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const Complex<T> &rhs) {
272   return lhs.real() == rhs.real() && lhs.imag() == rhs.imag();
273 }
274 
275 template <typename T>
276 HOST_DEVICE inline bool operator==(const T &lhs, const Complex<T> &rhs) {
277   return lhs == rhs.real() && rhs.imag() == 0;
278 }
279 
280 template <typename T>
281 HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const T &rhs) {
282   return lhs.real() == rhs && lhs.imag() == 0;
283 }
284 
285 template <typename T>
286 HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const Complex<T> &rhs) {
287   return !(lhs == rhs);
288 }
289 
290 template <typename T>
291 HOST_DEVICE inline bool operator!=(const T &lhs, const Complex<T> &rhs) {
292   return !(lhs == rhs);
293 }
294 
295 template <typename T>
296 HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const T &rhs) {
297   return !(lhs == rhs);
298 }
299 
300 template <typename T>
301 inline std::ostream &operator<<(std::ostream &os, const Complex<T> &v) {
302   return (os << std::noshowpos << v.real() << std::showpos << v.imag() << 'j');
303 }
304 
305 template <typename T>
abs(const Complex<T> & z)306 HOST_DEVICE inline T abs(const Complex<T> &z) {
307 #if defined(__CUDACC__)
308   return thrust::abs(thrust::complex<T>(z));
309 #else
310   return std::abs(std::complex<T>(z));
311 #endif
312 }
313 }  // namespace utils
314 }  // namespace mindspore
315 
316 template <typename T>
317 using Complex = mindspore::utils::Complex<T>;
318 
319 namespace std {
320 
321 template <typename T>
322 class numeric_limits<mindspore::utils::Complex<T>> : public numeric_limits<T> {};
323 
324 }  // namespace std
325 
326 #endif  // MINDSPORE_CCSRC_UTILS_COPLEX_H_
327