• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CORE_BASE_FLOAT16_H_
17 #define MINDSPORE_CORE_BASE_FLOAT16_H_
18 
19 #include <type_traits>
20 #if defined(ENABLE_ARM32) || defined(ENABLE_ARM64)
21 // Built for lite and ARM
22 #include <arm_neon.h>
23 
24 using float16 = float16_t;
25 
26 #else
#include <climits>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <ostream>
33 
34 // Implement Float16 for mindspore, inspired by Eigen::half.
35 namespace mindspore {
/// IEEE 754 binary16 ("half precision") value stored as its raw 16-bit
/// pattern. Arithmetic widens both operands to float, computes in single
/// precision and rounds the result back to half precision (round to nearest,
/// ties to even).
class Float16 {
 public:
  /// Mask that clears the sign bit, leaving exponent and mantissa.
  static constexpr uint16_t value_mask = 0x7fff;
  /// Quiet-NaN pattern produced whenever a float NaN is converted.
  static constexpr uint16_t nan_value = 0x7e00;
  /// Positive infinity: all exponent bits set, zero mantissa.
  static constexpr uint16_t inf_value = 0x7c00;
  /// Bit pattern of 1.0, used when constructing from `true`.
  static constexpr uint16_t true_value = 0x3c00;

  /// Retained for source compatibility with code that names Float16::Union32.
  /// The conversions below deliberately do not use it: reading a union member
  /// other than the one last written is undefined behavior in C++.
  union Union32 {
    uint32_t u;
    float f;
  };

  // Like a built-in float, a default-constructed Float16 is left
  // uninitialized; assign it before reading.
  Float16() = default;
  ~Float16() = default;

  Float16(const Float16 &other) noexcept = default;
  Float16(Float16 &&other) noexcept = default;

  Float16 &operator=(const Float16 &other) noexcept = default;
  Float16 &operator=(Float16 &&other) noexcept = default;

  /// Builds a Float16 directly from a raw binary16 bit pattern (no rounding).
  static Float16 FromRaw(uint16_t v) {
    Float16 f;
    f.value_ = v;
    return f;
  }

  /// Rounds a float to the nearest half-precision value.
  explicit Float16(float f) : value_(FromFloat32(f)) {}
  /// `true` becomes 1.0, `false` becomes +0.
  explicit Float16(bool b) : value_(b ? true_value : 0) {}
  /// Any other arithmetic type is converted through float first.
  template <typename T>
  explicit Float16(const T &v) : value_(FromFloat32(static_cast<float>(v))) {}

  /// Returns the raw binary16 bit pattern.
  uint16_t int_value() const { return value_; }

  /// False only for +0 and -0 (NaN is truthy, as for float).
  explicit operator bool() const { return (value_ & value_mask) != 0; }
  explicit operator float() const { return ToFloat32(*this); }
  explicit operator double() const { return static_cast<double>(ToFloat32(*this)); }
  // Integer conversions go through float and inherit float->integer cast rules.
  explicit operator int8_t() const { return static_cast<int8_t>(ToFloat32(*this)); }
  explicit operator uint8_t() const { return static_cast<uint8_t>(ToFloat32(*this)); }
  explicit operator int16_t() const { return static_cast<int16_t>(ToFloat32(*this)); }
  explicit operator uint16_t() const { return static_cast<uint16_t>(ToFloat32(*this)); }
  explicit operator int32_t() const { return static_cast<int32_t>(ToFloat32(*this)); }
  explicit operator uint32_t() const { return static_cast<uint32_t>(ToFloat32(*this)); }
  explicit operator int64_t() const { return static_cast<int64_t>(ToFloat32(*this)); }
  explicit operator uint64_t() const { return static_cast<uint64_t>(ToFloat32(*this)); }

  // Compound assignment: compute in float, round once back to half.
  Float16 &operator+=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) + ToFloat32(b));
    return *this;
  }

  Float16 &operator-=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) - ToFloat32(b));
    return *this;
  }

  Float16 &operator*=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) * ToFloat32(b));
    return *this;
  }

  Float16 &operator/=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) / ToFloat32(b));
    return *this;
  }

  /// Widens a half value to float. Every binary16 value (subnormals,
  /// infinities, NaNs included) is exactly representable as a float.
  static float ToFloat32(const Float16 &f16) {
    constexpr uint32_t mu_value = 113u << 23;                     // float 2^-14
    constexpr uint32_t exponent_adjust = (127u - 15u) << 23;      // rebias 15 -> 127
    constexpr uint32_t inf_extra_exp_adjust = (128u - 16u) << 23;  // push exponent to 255
    constexpr uint32_t zero_extra_exp_adjust = 1u << 23;
    constexpr uint32_t sign_mask = 0x8000;
    constexpr uint32_t shifted_exp = 0x7c00u << 13;  // half exponent mask after the shift
    constexpr unsigned int exponent_bits = 13;       // 23 - 10 mantissa bits
    constexpr unsigned int sign_bit_shift = 16;

    // Move exponent and mantissa bits into their float positions.
    uint32_t bits = static_cast<uint32_t>(f16.value_ & value_mask) << exponent_bits;
    const uint32_t exp = shifted_exp & bits;
    bits += exponent_adjust;
    if (exp == shifted_exp) {
      // Inf/NaN: lift the exponent the rest of the way to all ones.
      bits += inf_extra_exp_adjust;
    } else if (exp == 0) {
      // Zero/subnormal: the bits now encode (1.m) * 2^-14; subtracting 2^-14
      // in float arithmetic leaves exactly 0.m * 2^-14, renormalized.
      bits += zero_extra_exp_adjust;
      bits = FloatToBits(BitsToFloat(bits) - BitsToFloat(mu_value));
    }
    bits |= static_cast<uint32_t>(f16.value_ & sign_mask) << sign_bit_shift;
    return BitsToFloat(bits);
  }

 private:
  static_assert(sizeof(float) == sizeof(uint32_t), "Float16 requires a 32-bit float");

  // Bit-exact float <-> uint32_t reinterpretation. std::memcpy is the
  // well-defined way to type-pun in C++; compilers lower it to a plain move.
  static uint32_t FloatToBits(float f) {
    uint32_t u = 0;
    std::memcpy(&u, &f, sizeof(u));
    return u;
  }

  static float BitsToFloat(uint32_t u) {
    float f = 0.0f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
  }

  /// Rounds a float to binary16 bits (round to nearest, ties to even).
  /// Any NaN becomes nan_value and overflow becomes infinity; the sign is kept.
  static uint16_t FromFloat32(float f32) {
    constexpr uint32_t magic = 113u << 23;           // float 2^-14: smallest normal half
    constexpr uint32_t f32infty = 255u << 23;        // float +inf
    constexpr uint32_t f16max = (127u + 16u) << 23;  // float 2^16: always overflows
    // float 0.5: its ulp (2^-24) equals the ulp of half subnormals.
    constexpr uint32_t denorm_magic = ((127u - 15u) + (23u - 10u) + 1u) << 23;
    constexpr unsigned int exponent_bits = 13;
    constexpr unsigned int sign_bit_shift = 16;
    constexpr uint32_t sign_mask = 0x80000000u;
    // Rebias the exponent (127 -> 15, wrapping mod 2^32) and add just under
    // half an ulp of the 13 bits that will be dropped.
    constexpr uint32_t rounding_bias_part1 = (static_cast<uint32_t>(15 - 127) << 23) + 0xfff;

    uint32_t bits = FloatToBits(f32);
    const uint32_t sign = bits & sign_mask;
    bits ^= sign;
    uint16_t result = 0;

    // All integer compares below involve operands under 0x80000000, so they
    // could equally be compiled as signed compares (useful for SSE2).
    if (bits >= f16max) {
      // Result is Inf or NaN (all exponent bits set).
      result = (bits > f32infty) ? nan_value : inf_value;
    } else if (bits < magic) {
      // Result is subnormal or zero. Adding 0.5 makes the FPU round the value
      // to a multiple of 2^-24 and leaves the 10 mantissa bits at the bottom.
      const float aligned = BitsToFloat(bits) + BitsToFloat(denorm_magic);
      result = static_cast<uint16_t>(FloatToBits(aligned) - denorm_magic);
    } else {
      // Normal result. Adding mant_odd turns the bias into ties-to-even.
      const uint32_t mant_odd = (bits >> exponent_bits) & 1u;
      bits += rounding_bias_part1;
      bits += mant_odd;
      result = static_cast<uint16_t>(bits >> exponent_bits);
    }
    result |= static_cast<uint16_t>(sign >> sign_bit_shift);
    return result;
  }

  uint16_t value_;
};
188 
189 inline Float16 operator+(const Float16 &a, const Float16 &b) {
190   return Float16(static_cast<float>(a) + static_cast<float>(b));
191 }
192 
193 inline Float16 operator*(const Float16 &a, const Float16 &b) {
194   return Float16(static_cast<float>(a) * static_cast<float>(b));
195 }
196 
197 inline Float16 operator-(const Float16 &a, const Float16 &b) {
198   return Float16(static_cast<float>(a) - static_cast<float>(b));
199 }
200 
201 inline Float16 operator/(const Float16 &a, const Float16 &b) {
202   return Float16(static_cast<float>(a) / static_cast<float>(b));
203 }
204 
205 // Division by an size_t. Do it in full float precision to avoid
206 // accuracy issues in converting the denominator to float16.
207 inline Float16 operator/(const Float16 &a, size_t b) { return Float16(static_cast<float>(a) / static_cast<float>(b)); }
208 
209 inline Float16 operator-(const Float16 &a) {
210   constexpr uint16_t sign_mask = 0x8000;
211   return Float16::FromRaw(a.int_value() ^ sign_mask);
212 }
213 
214 inline bool operator==(const Float16 &a, const Float16 &b) {
215   return std::equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
216 }
217 
218 inline bool operator!=(const Float16 &a, const Float16 &b) {
219   return std::not_equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
220 }
221 
222 inline bool operator<(const Float16 &a, const Float16 &b) { return static_cast<float>(a) < static_cast<float>(b); }
223 inline bool operator<=(const Float16 &a, const Float16 &b) { return static_cast<float>(a) <= static_cast<float>(b); }
224 inline bool operator>(const Float16 &a, const Float16 &b) { return static_cast<float>(a) > static_cast<float>(b); }
225 inline bool operator>=(const Float16 &a, const Float16 &b) { return static_cast<float>(a) >= static_cast<float>(b); }
226 
227 inline std::ostream &operator<<(std::ostream &os, const Float16 &v) { return (os << static_cast<float>(v)); }
228 
229 }  // namespace mindspore
230 
231 using float16 = mindspore::Float16;
232 
233 namespace std {
234 template <>
235 struct hash<float16> {
236   std::size_t operator()(const float16 &f16) const noexcept { return static_cast<std::size_t>(f16.int_value()); }
237 };
238 
239 template <>
240 struct is_floating_point<float16> : public std::true_type {};
241 
242 template <>
243 struct is_signed<float16> : public std::true_type {};
244 
245 template <>
246 struct numeric_limits<float16> {
247   static constexpr bool is_specialized = true;
248   static constexpr bool is_signed = true;
249   static constexpr bool is_integer = false;
250   static constexpr bool is_exact = false;
251   static constexpr bool has_infinity = true;
252   static constexpr bool has_quiet_NaN = true;
253   static constexpr bool has_signaling_NaN = true;
254   static constexpr std::float_denorm_style has_denorm = std::denorm_present;
255   static constexpr bool has_denorm_loss = false;
256   static constexpr std::float_round_style round_style = std::round_to_nearest;
257   static constexpr bool is_iec559 = false;
258   static constexpr bool is_bounded = false;
259   static constexpr bool is_modulo = false;
260   static constexpr int digits = 11;
261   static constexpr int digits10 = 3;
262   static constexpr int max_digits10 = 5;
263   static constexpr int radix = 2;
264   static constexpr int min_exponent = -13;
265   static constexpr int min_exponent10 = -4;
266   static constexpr int max_exponent = 16;
267   static constexpr int max_exponent10 = 4;
268   static constexpr bool traps = true;
269   static constexpr bool tinyness_before = false;
270 
271   static constexpr uint16_t raw_min = 0x400;
272   static constexpr uint16_t raw_max = 0x7bff;
273   static constexpr uint16_t raw_lowest = 0xfbff;
274   static constexpr uint16_t raw_epsilon = 0x0800;
275   static constexpr float round_error_value = 0.5;
276 
277   static float16(min)() noexcept { return float16::FromRaw(raw_min); }
278   static float16(max)() noexcept { return float16::FromRaw(raw_max); }
279   static float16 lowest() noexcept { return float16::FromRaw(raw_lowest); }
280   static float16 epsilon() noexcept { return float16::FromRaw(raw_epsilon); }
281   static float16 round_error() noexcept { return float16(round_error_value); }
282   static float16 infinity() noexcept { return float16::FromRaw(float16::inf_value); }
283   static float16 quiet_NaN() noexcept { return float16::FromRaw(float16::nan_value); }
284   static float16 signaling_NaN() noexcept { return float16::FromRaw(float16::nan_value); }
285   static float16 denorm_min() noexcept { return float16::FromRaw(1); }
286 };
287 
288 // If std::numeric_limits<T> is specialized, should also specialize
289 // std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
290 // std::numeric_limits<const volatile T>
291 // https://stackoverflow.com/a/16519653/
292 template <>
293 struct numeric_limits<const mindspore::Float16> : private numeric_limits<mindspore::Float16> {};
294 template <>
295 struct numeric_limits<volatile mindspore::Float16> : private numeric_limits<mindspore::Float16> {};
296 template <>
297 struct numeric_limits<const volatile mindspore::Float16> : private numeric_limits<mindspore::Float16> {};
298 }  // namespace std
299 
300 // Implements standard math functions for float16.
301 inline bool(isinf)(const float16 &a) { return (a.int_value() & float16::value_mask) == float16::inf_value; }
302 inline bool(isnan)(const float16 &a) { return (a.int_value() & float16::value_mask) > float16::inf_value; }
303 inline bool(isfinite)(const float16 &a) { return !(isinf(a)) && !(isnan(a)); }
304 inline float16 abs(const float16 &a) { return float16::FromRaw(a.int_value() & float16::value_mask); }
305 inline float16 exp(const float16 &a) { return float16(::expf(static_cast<float>(a))); }
306 inline float16 log(const float16 &a) { return float16(::logf(static_cast<float>(a))); }
307 inline float16 log1p(const float16 &a) { return float16(::log1pf(static_cast<float>(a))); }
308 inline float16 log10(const float16 &a) { return float16(::log10f(static_cast<float>(a))); }
309 inline float16 sqrt(const float16 &a) { return float16(::sqrtf(static_cast<float>(a))); }
310 inline float16 sin(const float16 &a) { return float16(::sinf(static_cast<float>(a))); }
311 inline float16 cos(const float16 &a) { return float16(::cosf(static_cast<float>(a))); }
312 inline float16 tan(const float16 &a) { return float16(::tanf(static_cast<float>(a))); }
313 inline float16 tanh(const float16 &a) { return float16(::tanhf(static_cast<float>(a))); }
314 inline float16 floor(const float16 &a) { return float16(::floorf(static_cast<float>(a))); }
315 inline float16 ceil(const float16 &a) { return float16(::ceilf(static_cast<float>(a))); }
316 inline float16(min)(const float16 &a, const float16 &b) { return b < a ? b : a; }
317 inline float16(max)(const float16 &a, const float16 &b) { return a < b ? b : a; }
318 inline float16 pow(const float16 &a, const float16 &b) {
319   return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
320 }
321 
322 #endif  // ENABLE_ARM32 || ENABLE_ARM64
323 
324 inline float half_to_float(const float16 &h) { return static_cast<float>(h); }
325 
326 #endif  // MINDSPORE_CORE_BASE_FLOAT16_H_
327