1 /** 2 * Copyright 2020-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CORE_BASE_BFLOAT16_H_ 17 #define MINDSPORE_CORE_BASE_BFLOAT16_H_ 18 19 #include <type_traits> 20 #include <cmath> 21 #include <climits> 22 #include <cstdint> 23 #include <cstring> 24 #include <ostream> 25 #include <limits> 26 #include <functional> 27 #include "third_party/securec/include/securec.h" 28 29 // Implement BFloat16 for mindspore, inspired by Eigen::half. 30 namespace mindspore { 31 class BFloat16 { 32 public: 33 static constexpr uint16_t value_mask = 0x7fff; 34 static constexpr uint16_t inf_value = 0x7f80; 35 static constexpr uint16_t nan_value = 0x7fc0; 36 static constexpr uint16_t true_value = 0x3c00; 37 static constexpr uint32_t f32_inf_value = 0x7f800000; 38 39 BFloat16() = default; 40 ~BFloat16() = default; 41 42 BFloat16(const BFloat16 &other) noexcept = default; 43 BFloat16(BFloat16 &&other) noexcept = default; 44 45 BFloat16 &operator=(const BFloat16 &other) noexcept = default; 46 BFloat16 &operator=(BFloat16 &&other) noexcept = default; 47 FromRaw(uint16_t v)48 static BFloat16 FromRaw(uint16_t v) { 49 BFloat16 f; 50 f.value_ = v; 51 return f; 52 } 53 BFloat16(float f)54 explicit BFloat16(float f) : value_(FromFloat32(f)) {} BFloat16(bool b)55 explicit BFloat16(bool b) : value_(b ? true_value : 0) {} 56 template <typename T> BFloat16(const T & v)57 explicit BFloat16(const T &v) : value_(FromFloat32(static_cast<float>(v))) {} 58 int_value()59 uint16_t int_value() const { return value_; } 60 61 template <typename T> T()62 explicit operator T() const { 63 return static_cast<T>(ToFloat32(*this)); 64 } 65 66 explicit operator bool() const { return (value_ & value_mask) != 0; } 67 explicit operator float() const { return ToFloat32(*this); } 68 69 BFloat16 &operator+=(const BFloat16 &b) { 70 value_ = FromFloat32(ToFloat32(*this) + ToFloat32(b)); 71 return *this; 72 } 73 74 BFloat16 &operator-=(const BFloat16 &b) { 75 value_ = FromFloat32(ToFloat32(*this) - ToFloat32(b)); 76 return *this; 77 } 78 79 BFloat16 &operator*=(const BFloat16 &b) { 80 value_ = FromFloat32(ToFloat32(*this) * ToFloat32(b)); 81 return *this; 82 } 83 84 BFloat16 &operator/=(const BFloat16 &b) { 85 value_ = FromFloat32(ToFloat32(*this) / ToFloat32(b)); 86 return *this; 87 } 88 ToFloat32(const BFloat16 & bf16)89 static float ToFloat32(const BFloat16 &bf16) { 90 // We should use memcpy in order to respect the strict aliasing rule. 91 float f32 = 0; 92 uint32_t f32_tmp = bf16.int_value(); 93 f32_tmp <<= 16; 94 auto ret_code = memcpy_s(&f32, sizeof(f32), &f32_tmp, sizeof(f32_tmp)); 95 if (ret_code != 0) { 96 return f32_inf_value; 97 } 98 return f32; 99 } 100 101 private: FromFloat32(float f32)102 static uint16_t FromFloat32(float f32) { 103 if (std::isnan(f32)) { 104 return nan_value; 105 } else { 106 union { 107 uint32_t U32; 108 float F32; 109 }; 110 F32 = f32; 111 uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); 112 return static_cast<uint16_t>((U32 + rounding_bias) >> 16); 113 } 114 } 115 116 uint16_t value_; 117 }; 118 119 inline BFloat16 operator+(const BFloat16 &a, const BFloat16 &b) { 120 return BFloat16(static_cast<float>(a) + static_cast<float>(b)); 121 } 122 123 inline BFloat16 operator*(const BFloat16 &a, const BFloat16 &b) { 124 return BFloat16(static_cast<float>(a) * static_cast<float>(b)); 125 } 126 127 inline BFloat16 operator-(const BFloat16 &a, const BFloat16 &b) { 128 return BFloat16(static_cast<float>(a) - static_cast<float>(b)); 129 } 130 131 inline BFloat16 operator/(const BFloat16 &a, const BFloat16 &b) { 132 return BFloat16(static_cast<float>(a) / static_cast<float>(b)); 133 } 134 135 // Division by an size_t. Do it in full float precision to avoid 136 // accuracy issues in converting the denominator to bfloat16. 137 inline BFloat16 operator/(const BFloat16 &a, size_t b) { 138 return BFloat16(static_cast<float>(a) / static_cast<float>(b)); 139 } 140 141 inline BFloat16 operator-(const BFloat16 &a) { 142 constexpr uint16_t sign_mask = 0x8000; 143 return BFloat16::FromRaw(a.int_value() ^ sign_mask); 144 } 145 146 inline bool operator==(const BFloat16 &a, const BFloat16 &b) { 147 return std::equal_to<float>()(static_cast<float>(a), static_cast<float>(b)); 148 } 149 150 inline bool operator!=(const BFloat16 &a, const BFloat16 &b) { 151 return std::not_equal_to<float>()(static_cast<float>(a), static_cast<float>(b)); 152 } 153 154 inline bool operator<(const BFloat16 &a, const BFloat16 &b) { return static_cast<float>(a) < static_cast<float>(b); } 155 inline bool operator<=(const BFloat16 &a, const BFloat16 &b) { return static_cast<float>(a) <= static_cast<float>(b); } 156 inline bool operator>(const BFloat16 &a, const BFloat16 &b) { return static_cast<float>(a) > static_cast<float>(b); } 157 inline bool operator>=(const BFloat16 &a, const BFloat16 &b) { return static_cast<float>(a) >= static_cast<float>(b); } 158 159 inline std::ostream &operator<<(std::ostream &os, const BFloat16 &v) { return (os << static_cast<float>(v)); } 160 161 } // namespace mindspore 162 163 using bfloat16 = mindspore::BFloat16; 164 165 namespace std { 166 template <> 167 struct hash<bfloat16> { 168 std::size_t operator()(const bfloat16 &bf16) const noexcept { return static_cast<std::size_t>(bf16.int_value()); } 169 }; 170 171 template <> 172 struct is_floating_point<bfloat16> : public std::true_type {}; 173 174 template <> 175 struct is_signed<bfloat16> : public std::true_type {}; 176 177 template <> 178 struct numeric_limits<bfloat16> { 179 static constexpr bool is_specialized = true; 180 static constexpr bool is_signed = true; 181 static constexpr bool is_integer = false; 182 static constexpr bool is_exact = false; 183 static constexpr bool has_infinity = true; 184 static constexpr bool has_quiet_NaN = true; 185 static constexpr bool has_signaling_NaN = true; 186 static constexpr std::float_denorm_style has_denorm = numeric_limits<float>::has_denorm; 187 static constexpr bool has_denorm_loss = numeric_limits<float>::has_denorm_loss; 188 static constexpr std::float_round_style round_style = numeric_limits<float>::round_style; 189 static constexpr bool is_iec559 = false; 190 static constexpr bool is_bounded = true; 191 static constexpr bool is_modulo = false; 192 static constexpr int digits = 8; 193 static constexpr int digits10 = 2; 194 static constexpr int max_digits10 = 4; 195 static constexpr int radix = 2; 196 static constexpr int min_exponent = -125; 197 static constexpr int min_exponent10 = -37; 198 static constexpr int max_exponent = 128; 199 static constexpr int max_exponent10 = 38; 200 static constexpr bool traps = numeric_limits<float>::traps; 201 static constexpr bool tinyness_before = numeric_limits<float>::tinyness_before; 202 203 static constexpr uint16_t raw_min = 0x0080; 204 static constexpr uint16_t raw_max = 0x7f7f; 205 static constexpr uint16_t raw_lowest = 0xff7f; 206 static constexpr uint16_t raw_epsilon = 0x3c00; 207 static constexpr uint16_t raw_round_error = 0x3f00; 208 static constexpr uint16_t raw_infinity = 0x7f80; 209 static constexpr uint16_t raw_quiet_nan = 0x7fc0; 210 static constexpr uint16_t raw_signaling_nan = 0x7f80; 211 static constexpr uint16_t raw_denorm_min = 0x0001; 212 213 static bfloat16(min)() noexcept { return bfloat16::FromRaw(raw_min); } 214 static bfloat16(max)() noexcept { return bfloat16::FromRaw(raw_max); } 215 static bfloat16 lowest() noexcept { return bfloat16::FromRaw(raw_lowest); } 216 static bfloat16 epsilon() noexcept { return bfloat16::FromRaw(raw_epsilon); } 217 static bfloat16 round_error() noexcept { return bfloat16::FromRaw(raw_round_error); } 218 static bfloat16 infinity() noexcept { return bfloat16::FromRaw(raw_infinity); } 219 static bfloat16 quiet_NaN() noexcept { return bfloat16::FromRaw(raw_quiet_nan); } 220 static bfloat16 signaling_NaN() noexcept { return bfloat16::FromRaw(raw_signaling_nan); } 221 static bfloat16 denorm_min() noexcept { return bfloat16::FromRaw(raw_denorm_min); } 222 }; 223 224 // If std::numeric_limits<T> is specialized, should also specialize 225 // std::numeric_limits<const T>, std::numeric_limits<volatile T>, and 226 // std::numeric_limits<const volatile T> 227 // https://stackoverflow.com/a/16519653/ 228 template <> 229 struct numeric_limits<const mindspore::BFloat16> : private numeric_limits<mindspore::BFloat16> {}; 230 template <> 231 struct numeric_limits<volatile mindspore::BFloat16> : private numeric_limits<mindspore::BFloat16> {}; 232 template <> 233 struct numeric_limits<const volatile mindspore::BFloat16> : private numeric_limits<mindspore::BFloat16> {}; 234 } // namespace std 235 236 // Implements standard math functions for bfloat16. 237 inline bool(isinf)(const bfloat16 &a) { return (a.int_value() & bfloat16::value_mask) == bfloat16::inf_value; } 238 inline bool(isnan)(const bfloat16 &a) { return (a.int_value() & bfloat16::value_mask) > bfloat16::inf_value; } 239 inline bool(isfinite)(const bfloat16 &a) { return !(isinf(a)) && !(isnan(a)); } 240 inline bfloat16 abs(const bfloat16 &a) { return bfloat16::FromRaw(a.int_value() & bfloat16::value_mask); } 241 inline bfloat16 exp(const bfloat16 &a) { return bfloat16(::expf(static_cast<float>(a))); } 242 inline bfloat16 log(const bfloat16 &a) { return bfloat16(::logf(static_cast<float>(a))); } 243 inline bfloat16 log1p(const bfloat16 &a) { return bfloat16(::log1pf(static_cast<float>(a))); } 244 inline bfloat16 log10(const bfloat16 &a) { return bfloat16(::log10f(static_cast<float>(a))); } 245 inline bfloat16 sqrt(const bfloat16 &a) { return bfloat16(::sqrtf(static_cast<float>(a))); } 246 inline bfloat16 sin(const bfloat16 &a) { return bfloat16(::sinf(static_cast<float>(a))); } 247 inline bfloat16 cos(const bfloat16 &a) { return bfloat16(::cosf(static_cast<float>(a))); } 248 inline bfloat16 tan(const bfloat16 &a) { return bfloat16(::tanf(static_cast<float>(a))); } 249 inline bfloat16 tanh(const bfloat16 &a) { return bfloat16(::tanhf(static_cast<float>(a))); } 250 inline bfloat16 floor(const bfloat16 &a) { return bfloat16(::floorf(static_cast<float>(a))); } 251 inline bfloat16 ceil(const bfloat16 &a) { return bfloat16(::ceilf(static_cast<float>(a))); } 252 inline bfloat16(min)(const bfloat16 &a, const bfloat16 &b) { return b < a ? b : a; } 253 inline bfloat16(max)(const bfloat16 &a, const bfloat16 &b) { return a < b ? b : a; } 254 inline bfloat16 pow(const bfloat16 &a, const bfloat16 &b) { 255 return bfloat16(::powf(static_cast<float>(a), static_cast<float>(b))); 256 } 257 258 inline float bfloat_to_float(const bfloat16 &h) { return static_cast<float>(h); } 259 260 #endif // MINDSPORE_CORE_BASE_BFLOAT16_H_ 261