/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CORE_BASE_BFLOAT16_H_
17 #define MINDSPORE_CORE_BASE_BFLOAT16_H_
18 
19 #include <type_traits>
20 #include <cmath>
21 #include <climits>
22 #include <cstdint>
23 #include <cstring>
24 #include <ostream>
25 #include <limits>
26 #include <functional>
27 #include "third_party/securec/include/securec.h"
28 
29 // Implement BFloat16 for mindspore, inspired by Eigen::half.
30 namespace mindspore {
31 class BFloat16 {
32  public:
33   static constexpr uint16_t value_mask = 0x7fff;
34   static constexpr uint16_t inf_value = 0x7f80;
35   static constexpr uint16_t nan_value = 0x7fc0;
36   static constexpr uint16_t true_value = 0x3c00;
37   static constexpr uint32_t f32_inf_value = 0x7f800000;
38 
39   BFloat16() = default;
40   ~BFloat16() = default;
41 
42   BFloat16(const BFloat16 &other) noexcept = default;
43   BFloat16(BFloat16 &&other) noexcept = default;
44 
45   BFloat16 &operator=(const BFloat16 &other) noexcept = default;
46   BFloat16 &operator=(BFloat16 &&other) noexcept = default;
47 
FromRaw(uint16_t v)48   static BFloat16 FromRaw(uint16_t v) {
49     BFloat16 f;
50     f.value_ = v;
51     return f;
52   }
53 
BFloat16(float f)54   explicit BFloat16(float f) : value_(FromFloat32(f)) {}
BFloat16(bool b)55   explicit BFloat16(bool b) : value_(b ? true_value : 0) {}
56   template <typename T>
BFloat16(const T & v)57   explicit BFloat16(const T &v) : value_(FromFloat32(static_cast<float>(v))) {}
58 
int_value()59   uint16_t int_value() const { return value_; }
60 
61   template <typename T>
T()62   explicit operator T() const {
63     return static_cast<T>(ToFloat32(*this));
64   }
65 
66   explicit operator bool() const { return (value_ & value_mask) != 0; }
67   explicit operator float() const { return ToFloat32(*this); }
68 
69   BFloat16 &operator+=(const BFloat16 &b) {
70     value_ = FromFloat32(ToFloat32(*this) + ToFloat32(b));
71     return *this;
72   }
73 
74   BFloat16 &operator-=(const BFloat16 &b) {
75     value_ = FromFloat32(ToFloat32(*this) - ToFloat32(b));
76     return *this;
77   }
78 
79   BFloat16 &operator*=(const BFloat16 &b) {
80     value_ = FromFloat32(ToFloat32(*this) * ToFloat32(b));
81     return *this;
82   }
83 
84   BFloat16 &operator/=(const BFloat16 &b) {
85     value_ = FromFloat32(ToFloat32(*this) / ToFloat32(b));
86     return *this;
87   }
88 
ToFloat32(const BFloat16 & bf16)89   static float ToFloat32(const BFloat16 &bf16) {
90     // We should use memcpy in order to respect the strict aliasing rule.
91     float f32 = 0;
92     uint32_t f32_tmp = bf16.int_value();
93     f32_tmp <<= 16;
94     auto ret_code = memcpy_s(&f32, sizeof(f32), &f32_tmp, sizeof(f32_tmp));
95     if (ret_code != 0) {
96       return f32_inf_value;
97     }
98     return f32;
99   }
100 
101  private:
FromFloat32(float f32)102   static uint16_t FromFloat32(float f32) {
103     if (std::isnan(f32)) {
104       return nan_value;
105     } else {
106       union {
107         uint32_t U32;
108         float F32;
109       };
110       F32 = f32;
111       uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
112       return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
113     }
114   }
115 
116   uint16_t value_;
117 };
118 
119 inline BFloat16 operator+(const BFloat16 &a, const BFloat16 &b) {
120   return BFloat16(static_cast<float>(a) + static_cast<float>(b));
121 }
122 
123 inline BFloat16 operator*(const BFloat16 &a, const BFloat16 &b) {
124   return BFloat16(static_cast<float>(a) * static_cast<float>(b));
125 }
126 
127 inline BFloat16 operator-(const BFloat16 &a, const BFloat16 &b) {
128   return BFloat16(static_cast<float>(a) - static_cast<float>(b));
129 }
130 
131 inline BFloat16 operator/(const BFloat16 &a, const BFloat16 &b) {
132   return BFloat16(static_cast<float>(a) / static_cast<float>(b));
133 }
134 
135 // Division by an size_t. Do it in full float precision to avoid
136 // accuracy issues in converting the denominator to bfloat16.
137 inline BFloat16 operator/(const BFloat16 &a, size_t b) {
138   return BFloat16(static_cast<float>(a) / static_cast<float>(b));
139 }
140 
141 inline BFloat16 operator-(const BFloat16 &a) {
142   constexpr uint16_t sign_mask = 0x8000;
143   return BFloat16::FromRaw(a.int_value() ^ sign_mask);
144 }
145 
146 inline bool operator==(const BFloat16 &a, const BFloat16 &b) {
147   return std::equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
148 }
149 
150 inline bool operator!=(const BFloat16 &a, const BFloat16 &b) {
151   return std::not_equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
152 }
153 
154 inline bool operator<(const BFloat16 &a, const BFloat16 &b) { return static_cast<float>(a) < static_cast<float>(b); }
155 inline bool operator<=(const BFloat16 &a, const BFloat16 &b) { return static_cast<float>(a) <= static_cast<float>(b); }
156 inline bool operator>(const BFloat16 &a, const BFloat16 &b) { return static_cast<float>(a) > static_cast<float>(b); }
157 inline bool operator>=(const BFloat16 &a, const BFloat16 &b) { return static_cast<float>(a) >= static_cast<float>(b); }
158 
159 inline std::ostream &operator<<(std::ostream &os, const BFloat16 &v) { return (os << static_cast<float>(v)); }
160 
161 }  // namespace mindspore
162 
// Global-namespace alias for mindspore::BFloat16, used by the std specializations and math helpers below.
using bfloat16 = mindspore::BFloat16;
164 
165 namespace std {
166 template <>
167 struct hash<bfloat16> {
168   std::size_t operator()(const bfloat16 &bf16) const noexcept { return static_cast<std::size_t>(bf16.int_value()); }
169 };
170 
171 template <>
172 struct is_floating_point<bfloat16> : public std::true_type {};
173 
174 template <>
175 struct is_signed<bfloat16> : public std::true_type {};
176 
177 template <>
178 struct numeric_limits<bfloat16> {
179   static constexpr bool is_specialized = true;
180   static constexpr bool is_signed = true;
181   static constexpr bool is_integer = false;
182   static constexpr bool is_exact = false;
183   static constexpr bool has_infinity = true;
184   static constexpr bool has_quiet_NaN = true;
185   static constexpr bool has_signaling_NaN = true;
186   static constexpr std::float_denorm_style has_denorm = numeric_limits<float>::has_denorm;
187   static constexpr bool has_denorm_loss = numeric_limits<float>::has_denorm_loss;
188   static constexpr std::float_round_style round_style = numeric_limits<float>::round_style;
189   static constexpr bool is_iec559 = false;
190   static constexpr bool is_bounded = true;
191   static constexpr bool is_modulo = false;
192   static constexpr int digits = 8;
193   static constexpr int digits10 = 2;
194   static constexpr int max_digits10 = 4;
195   static constexpr int radix = 2;
196   static constexpr int min_exponent = -125;
197   static constexpr int min_exponent10 = -37;
198   static constexpr int max_exponent = 128;
199   static constexpr int max_exponent10 = 38;
200   static constexpr bool traps = numeric_limits<float>::traps;
201   static constexpr bool tinyness_before = numeric_limits<float>::tinyness_before;
202 
203   static constexpr uint16_t raw_min = 0x0080;
204   static constexpr uint16_t raw_max = 0x7f7f;
205   static constexpr uint16_t raw_lowest = 0xff7f;
206   static constexpr uint16_t raw_epsilon = 0x3c00;
207   static constexpr uint16_t raw_round_error = 0x3f00;
208   static constexpr uint16_t raw_infinity = 0x7f80;
209   static constexpr uint16_t raw_quiet_nan = 0x7fc0;
210   static constexpr uint16_t raw_signaling_nan = 0x7f80;
211   static constexpr uint16_t raw_denorm_min = 0x0001;
212 
213   static bfloat16(min)() noexcept { return bfloat16::FromRaw(raw_min); }
214   static bfloat16(max)() noexcept { return bfloat16::FromRaw(raw_max); }
215   static bfloat16 lowest() noexcept { return bfloat16::FromRaw(raw_lowest); }
216   static bfloat16 epsilon() noexcept { return bfloat16::FromRaw(raw_epsilon); }
217   static bfloat16 round_error() noexcept { return bfloat16::FromRaw(raw_round_error); }
218   static bfloat16 infinity() noexcept { return bfloat16::FromRaw(raw_infinity); }
219   static bfloat16 quiet_NaN() noexcept { return bfloat16::FromRaw(raw_quiet_nan); }
220   static bfloat16 signaling_NaN() noexcept { return bfloat16::FromRaw(raw_signaling_nan); }
221   static bfloat16 denorm_min() noexcept { return bfloat16::FromRaw(raw_denorm_min); }
222 };
223 
224 // If std::numeric_limits<T> is specialized, should also specialize
225 // std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
226 // std::numeric_limits<const volatile T>
227 // https://stackoverflow.com/a/16519653/
228 template <>
229 struct numeric_limits<const mindspore::BFloat16> : private numeric_limits<mindspore::BFloat16> {};
230 template <>
231 struct numeric_limits<volatile mindspore::BFloat16> : private numeric_limits<mindspore::BFloat16> {};
232 template <>
233 struct numeric_limits<const volatile mindspore::BFloat16> : private numeric_limits<mindspore::BFloat16> {};
234 }  // namespace std
235 
236 // Implements standard math functions for bfloat16.
237 inline bool(isinf)(const bfloat16 &a) { return (a.int_value() & bfloat16::value_mask) == bfloat16::inf_value; }
238 inline bool(isnan)(const bfloat16 &a) { return (a.int_value() & bfloat16::value_mask) > bfloat16::inf_value; }
239 inline bool(isfinite)(const bfloat16 &a) { return !(isinf(a)) && !(isnan(a)); }
240 inline bfloat16 abs(const bfloat16 &a) { return bfloat16::FromRaw(a.int_value() & bfloat16::value_mask); }
241 inline bfloat16 exp(const bfloat16 &a) { return bfloat16(::expf(static_cast<float>(a))); }
242 inline bfloat16 log(const bfloat16 &a) { return bfloat16(::logf(static_cast<float>(a))); }
243 inline bfloat16 log1p(const bfloat16 &a) { return bfloat16(::log1pf(static_cast<float>(a))); }
244 inline bfloat16 log10(const bfloat16 &a) { return bfloat16(::log10f(static_cast<float>(a))); }
245 inline bfloat16 sqrt(const bfloat16 &a) { return bfloat16(::sqrtf(static_cast<float>(a))); }
246 inline bfloat16 sin(const bfloat16 &a) { return bfloat16(::sinf(static_cast<float>(a))); }
247 inline bfloat16 cos(const bfloat16 &a) { return bfloat16(::cosf(static_cast<float>(a))); }
248 inline bfloat16 tan(const bfloat16 &a) { return bfloat16(::tanf(static_cast<float>(a))); }
249 inline bfloat16 tanh(const bfloat16 &a) { return bfloat16(::tanhf(static_cast<float>(a))); }
250 inline bfloat16 floor(const bfloat16 &a) { return bfloat16(::floorf(static_cast<float>(a))); }
251 inline bfloat16 ceil(const bfloat16 &a) { return bfloat16(::ceilf(static_cast<float>(a))); }
252 inline bfloat16(min)(const bfloat16 &a, const bfloat16 &b) { return b < a ? b : a; }
253 inline bfloat16(max)(const bfloat16 &a, const bfloat16 &b) { return a < b ? b : a; }
254 inline bfloat16 pow(const bfloat16 &a, const bfloat16 &b) {
255   return bfloat16(::powf(static_cast<float>(a), static_cast<float>(b)));
256 }
257 
258 inline float bfloat_to_float(const bfloat16 &h) { return static_cast<float>(h); }
259 
260 #endif  // MINDSPORE_CORE_BASE_BFLOAT16_H_
261