1 /** 2 * Copyright 2021-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CORE_BASE_COMPLEX_STORAGE_H_ 17 #define MINDSPORE_CORE_BASE_COMPLEX_STORAGE_H_ 18 19 #include "base/float16.h" 20 #include "base/bfloat16.h" 21 #include "utils/ms_utils.h" 22 23 namespace mindspore { 24 constexpr auto kComplexValueUnit = 2; 25 26 template <typename T> 27 struct alignas(sizeof(T) * kComplexValueUnit) ComplexStorage { 28 T real_; 29 T imag_; 30 31 ComplexStorage() = default; 32 ~ComplexStorage() = default; 33 34 ComplexStorage(const ComplexStorage<T> &other) noexcept = default; 35 ComplexStorage(ComplexStorage<T> &&other) noexcept = default; 36 37 ComplexStorage &operator=(const ComplexStorage<T> &other) noexcept = default; 38 ComplexStorage &operator=(ComplexStorage<T> &&other) noexcept = default; 39 real_ComplexStorage40 inline constexpr ComplexStorage(const T &real, const T &imag = T()) : real_(real), imag_(imag) {} 41 #ifndef ENABLE_ARM ComplexStorageComplexStorage42 inline explicit constexpr ComplexStorage(const float16 &real) : real_(static_cast<T>(real)), imag_(T()) {} ComplexStorageComplexStorage43 inline explicit constexpr ComplexStorage(const bfloat16 &real) : real_(static_cast<T>(real)), imag_(T()) {} 44 #endif 45 template <typename U = T> ComplexStorageComplexStorage46 explicit ComplexStorage(const std::enable_if_t<std::is_same<U, float>::value, ComplexStorage<double>> &other) 47 : real_(other.real_), imag_(other.imag_) {} 48 49 template <typename U = T> ComplexStorageComplexStorage50 explicit ComplexStorage(const std::enable_if_t<std::is_same<U, double>::value, ComplexStorage<float>> &other) 51 : real_(other.real_), imag_(other.imag_) {} 52 53 inline explicit operator bool() const { return static_cast<bool>(real_) || static_cast<bool>(imag_); } 54 inline explicit operator signed char() const { return static_cast<signed char>(real_); } 55 inline explicit operator unsigned char() const { return static_cast<unsigned char>(real_); } 56 inline explicit operator double() const { return static_cast<double>(real_); } 57 inline explicit operator float() const { return static_cast<float>(real_); } int16_tComplexStorage58 inline explicit operator int16_t() const { return static_cast<int16_t>(real_); } uint16_tComplexStorage59 inline explicit operator uint16_t() const { return static_cast<uint16_t>(real_); } int32_tComplexStorage60 inline explicit operator int32_t() const { return static_cast<int32_t>(real_); } uint32_tComplexStorage61 inline explicit operator uint32_t() const { return static_cast<uint32_t>(real_); } int64_tComplexStorage62 inline explicit operator int64_t() const { return static_cast<int64_t>(real_); } uint64_tComplexStorage63 inline explicit operator uint64_t() const { return static_cast<uint64_t>(real_); } float16ComplexStorage64 inline explicit operator float16() const { return static_cast<float16>(real_); } bfloat16ComplexStorage65 inline explicit operator bfloat16() const { return static_cast<bfloat16>(real_); } 66 }; 67 68 template <typename T> 69 inline bool operator==(const ComplexStorage<T> &lhs, const ComplexStorage<T> &rhs) { 70 if constexpr (std::is_same_v<T, double>) { 71 return common::IsDoubleEqual(lhs.real_, rhs.real_) && common::IsDoubleEqual(lhs.imag_, rhs.imag_); 72 } else if constexpr (std::is_same_v<T, float>) { 73 return common::IsFloatEqual(lhs.real_, rhs.real_) && common::IsFloatEqual(lhs.imag_, rhs.imag_); 74 } 75 return (lhs.real_ == rhs.real_) && (lhs.imag_ == rhs.imag_); 76 } 77 78 template <typename T> 79 inline std::ostream &operator<<(std::ostream &os, const ComplexStorage<T> &v) { 80 return (os << std::noshowpos << v.real_ << std::showpos << v.imag_ << 'j'); 81 } 82 83 } // namespace mindspore 84 85 #endif // MINDSPORE_CORE_BASE_COMPLEX_STORAGE_H_ 86