• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CORE_BASE_FLOAT16_H_
17 #define MINDSPORE_CORE_BASE_FLOAT16_H_
18 
#if defined(ENABLE_ARM32) || defined(ENABLE_ARM64)
// Built for lite and ARM
#include <arm_neon.h>

using float16 = float16_t;

#else
#include <climits>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <ostream>
32 
// Software implementation of IEEE 754 binary16 (half precision) for MindSpore, modeled on Eigen::half.
34 namespace mindspore {
class Float16 {
 public:
  // Mask selecting the exponent and mantissa bits, i.e. everything except the sign bit.
  static constexpr uint16_t value_mask = 0x7fff;
  // Canonical quiet-NaN encoding produced when a float NaN is converted.
  static constexpr uint16_t nan_value = 0x7e00;
  // Encoding of +infinity (exponent bits all set, mantissa zero).
  static constexpr uint16_t inf_value = 0x7c00;
  // Encoding of 1.0, used by Float16(true).
  static constexpr uint16_t true_value = 0x3c00;

  // Kept for source compatibility with code that names Float16::Union32. The
  // conversions below do not read through it: reading the inactive member of a
  // union is undefined behavior in C++, so bits are copied with std::memcpy instead.
  union Union32 {
    uint32_t u;
    float f;
  };

  // Trivial default constructor: like a built-in float, the value is left
  // uninitialized until it is assigned.
  Float16() = default;
  ~Float16() = default;

  Float16(const Float16 &other) noexcept = default;
  Float16(Float16 &&other) noexcept = default;

  Float16 &operator=(const Float16 &other) noexcept = default;
  Float16 &operator=(Float16 &&other) noexcept = default;

  // Creates a value directly from its raw binary16 bit pattern (no rounding).
  static Float16 FromRaw(uint16_t v) {
    Float16 f;
    f.value_ = v;
    return f;
  }

  // Rounds a float to the nearest half-precision value.
  explicit Float16(float f) : value_(FromFloat32(f)) {}
  // true maps to 1.0 and false maps to +0.0.
  explicit Float16(bool b) : value_(b ? true_value : 0) {}
  // Any other type is first converted to float, then rounded to half precision.
  template <typename T>
  explicit Float16(const T &v) : value_(FromFloat32(static_cast<float>(v))) {}

  // Returns the raw binary16 bit pattern.
  uint16_t int_value() const { return value_; }

  // +0.0 and -0.0 are false; every other value (including NaN) is true.
  explicit operator bool() const { return (value_ & value_mask) != 0; }
  // Widening to float or double is exact: every binary16 value is representable in float.
  explicit operator float() const { return ToFloat32(*this); }
  explicit operator double() const { return static_cast<double>(ToFloat32(*this)); }
  // Integer conversions go through float followed by static_cast, which truncates toward
  // zero; values outside the target range are undefined behavior, exactly as for a float.
  explicit operator int8_t() const { return static_cast<int8_t>(ToFloat32(*this)); }
  explicit operator uint8_t() const { return static_cast<uint8_t>(ToFloat32(*this)); }
  explicit operator int16_t() const { return static_cast<int16_t>(ToFloat32(*this)); }
  explicit operator uint16_t() const { return static_cast<uint16_t>(ToFloat32(*this)); }
  explicit operator int32_t() const { return static_cast<int32_t>(ToFloat32(*this)); }
  explicit operator uint32_t() const { return static_cast<uint32_t>(ToFloat32(*this)); }
  explicit operator int64_t() const { return static_cast<int64_t>(ToFloat32(*this)); }
  explicit operator uint64_t() const { return static_cast<uint64_t>(ToFloat32(*this)); }

  // Compound assignment computes in float and rounds the result back to half precision.
  Float16 &operator+=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) + ToFloat32(b));
    return *this;
  }

  Float16 &operator-=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) - ToFloat32(b));
    return *this;
  }

  Float16 &operator*=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) * ToFloat32(b));
    return *this;
  }

  Float16 &operator/=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) / ToFloat32(b));
    return *this;
  }

  // Converts a binary16 value to float exactly.
  // The exponent/mantissa bits are shifted into float position and the exponent is
  // rebiased from 15 to 127. Inf/NaN receive an extra adjustment so their exponent
  // becomes all ones. Zero/subnormal inputs are renormalized with one float subtraction.
  // The sign bit is copied through unchanged.
  static float ToFloat32(Float16 f16) {
    constexpr uint32_t magic_bits = 113u << 23;  // Bit pattern of 2^-14.
    constexpr uint32_t exponent_adjust = (127u - 15u) << 23;
    constexpr uint32_t inf_extra_exp_adjust = (128u - 16u) << 23;
    constexpr uint32_t zero_extra_exp_adjust = 1u << 23;
    constexpr uint32_t sign_mask = 0x8000;
    constexpr uint32_t shifted_exp = 0x7c00u << 13;  // Exponent mask after the shift.
    constexpr unsigned int exponent_bits = 13;
    constexpr unsigned int sign_bit_shift = 16;

    // Exponent/mantissa bits moved into float position.
    uint32_t bits = static_cast<uint32_t>(f16.value_ & value_mask) << exponent_bits;
    // Just the (shifted) exponent field.
    const uint32_t exponent = shifted_exp & bits;
    bits += exponent_adjust;
    if (exponent == shifted_exp) {
      // Inf/NaN: extra exponent adjustment.
      bits += inf_extra_exp_adjust;
    } else if (exponent == 0) {
      // Zero/subnormal: extra exponent adjustment, then renormalize.
      bits += zero_extra_exp_adjust;
      bits = FloatToBits(BitsToFloat(bits) - BitsToFloat(magic_bits));
    }
    // Set the sign bit.
    bits |= static_cast<uint32_t>(f16.value_ & sign_mask) << sign_bit_shift;
    return BitsToFloat(bits);
  }

 private:
  static_assert(sizeof(float) == sizeof(uint32_t), "Float16 bit manipulation requires a 32-bit float");

  // Returns the bit pattern of a float without union type punning.
  static uint32_t FloatToBits(float value) {
    uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));
    return bits;
  }

  // Returns the float whose bit pattern is `bits`, without union type punning.
  static float BitsToFloat(uint32_t bits) {
    float value;
    std::memcpy(&value, &bits, sizeof(value));
    return value;
  }

  // Rounds a float to binary16 using round-to-nearest-even.
  // NaN becomes the canonical quiet NaN (nan_value). Magnitudes that round above the
  // largest finite half (65504) become infinity. The sign bit is preserved in all cases.
  // The subnormal path relies on float addition rounding to nearest-even, which is the
  // default floating-point rounding mode.
  static uint16_t FromFloat32(float f32) {
    constexpr uint32_t magic = 113u << 23;                   // 2^-14: smallest normal half.
    constexpr uint32_t f32infty_bits = 255u << 23;           // +infinity as float.
    constexpr uint32_t f16max_bits = (127u + 16u) << 23;     // 2^16: first magnitude that is always inf.
    constexpr uint32_t denorm_magic_bits = ((127u - 15u) + (23u - 10u) + 1u) << 23;  // 0.5f.
    constexpr unsigned int exponent_bits = 13;
    constexpr unsigned int sign_bit_shift = 16;
    constexpr uint32_t sign_mask = 0x80000000u;
    // Rebias the exponent (127 -> 15) and add just-below-half of the dropped mantissa bits.
    // The unsigned wrap-around of (15 - 127) is intended.
    constexpr uint32_t rounding_bias_part1 = (static_cast<uint32_t>(15 - 127) << 23) + 0xfffu;

    uint32_t bits = FloatToBits(f32);
    const uint32_t sign = bits & sign_mask;
    bits ^= sign;
    uint16_t result = 0;

    // All operands below are < 0x80000000, so these compares could also be done signed
    // (useful for SSE2 code, which has no unsigned PCMPGTD).
    if (bits >= f16max_bits) {
      // Result is Inf or NaN (all exponent bits set).
      result = (bits > f32infty_bits) ? nan_value : inf_value;
    } else if (bits < magic) {
      // Result is subnormal or zero. Adding 0.5f aligns the 10 mantissa bits at the bottom
      // of the float. The FP addition performs the rounding; subtracting the bias bits
      // then leaves the half encoding.
      const float aligned = BitsToFloat(bits) + BitsToFloat(denorm_magic_bits);
      result = static_cast<uint16_t>(FloatToBits(aligned) - denorm_magic_bits);
    } else {
      // Normal result: ties go to the even mantissa, so add 1 when the kept mantissa is odd.
      const uint32_t mant_odd = (bits >> exponent_bits) & 1u;
      bits += rounding_bias_part1;
      bits += mant_odd;
      result = static_cast<uint16_t>(bits >> exponent_bits);
    }
    // Set the sign bit.
    result |= static_cast<uint16_t>(sign >> sign_bit_shift);
    return result;
  }

  uint16_t value_;
};
179 
180 inline Float16 operator+(const Float16 &a, const Float16 &b) {
181   return Float16(static_cast<float>(a) + static_cast<float>(b));
182 }
183 
184 inline Float16 operator*(const Float16 &a, const Float16 &b) {
185   return Float16(static_cast<float>(a) * static_cast<float>(b));
186 }
187 
188 inline Float16 operator-(const Float16 &a, const Float16 &b) {
189   return Float16(static_cast<float>(a) - static_cast<float>(b));
190 }
191 
192 inline Float16 operator/(const Float16 &a, const Float16 &b) {
193   return Float16(static_cast<float>(a) / static_cast<float>(b));
194 }
195 
196 // Division by an size_t. Do it in full float precision to avoid
197 // accuracy issues in converting the denominator to float16.
198 inline Float16 operator/(const Float16 &a, size_t b) { return Float16(static_cast<float>(a) / static_cast<float>(b)); }
199 
200 inline Float16 operator-(const Float16 &a) {
201   constexpr uint16_t sign_mask = 0x8000;
202   return Float16::FromRaw(a.int_value() ^ sign_mask);
203 }
204 
205 inline bool operator==(const Float16 &a, const Float16 &b) {
206   return std::equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
207 }
208 
209 inline bool operator!=(const Float16 &a, const Float16 &b) {
210   return std::not_equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
211 }
212 
213 inline bool operator<(const Float16 &a, const Float16 &b) { return static_cast<float>(a) < static_cast<float>(b); }
214 inline bool operator<=(const Float16 &a, const Float16 &b) { return static_cast<float>(a) <= static_cast<float>(b); }
215 inline bool operator>(const Float16 &a, const Float16 &b) { return static_cast<float>(a) > static_cast<float>(b); }
216 inline bool operator>=(const Float16 &a, const Float16 &b) { return static_cast<float>(a) >= static_cast<float>(b); }
217 
218 inline std::ostream &operator<<(std::ostream &os, const Float16 &v) { return (os << static_cast<float>(v)); }
219 
220 }  // namespace mindspore
221 
222 using float16 = mindspore::Float16;
223 
224 namespace std {
225 template <>
226 struct hash<float16> {
227   std::size_t operator()(const float16 &f16) const noexcept { return static_cast<std::size_t>(f16.int_value()); }
228 };
229 
230 template <>
231 struct numeric_limits<float16> {
232   static constexpr bool is_specialized = true;
233   static constexpr bool is_signed = true;
234   static constexpr bool is_integer = false;
235   static constexpr bool is_exact = false;
236   static constexpr bool has_infinity = true;
237   static constexpr bool has_quiet_NaN = true;
238   static constexpr bool has_signaling_NaN = true;
239   static constexpr std::float_denorm_style has_denorm = std::denorm_present;
240   static constexpr bool has_denorm_loss = false;
241   static constexpr std::float_round_style round_style = std::round_to_nearest;
242   static constexpr bool is_iec559 = false;
243   static constexpr bool is_bounded = false;
244   static constexpr bool is_modulo = false;
245   static constexpr int digits = 11;
246   static constexpr int digits10 = 3;
247   static constexpr int max_digits10 = 5;
248   static constexpr int radix = 2;
249   static constexpr int min_exponent = -13;
250   static constexpr int min_exponent10 = -4;
251   static constexpr int max_exponent = 16;
252   static constexpr int max_exponent10 = 4;
253   static constexpr bool traps = true;
254   static constexpr bool tinyness_before = false;
255 
256   static constexpr uint16_t raw_min = 0x400;
257   static constexpr uint16_t raw_max = 0x7bff;
258   static constexpr uint16_t raw_lowest = 0xfbff;
259   static constexpr uint16_t raw_epsilon = 0x0800;
260   static constexpr float round_error_value = 0.5;
261 
262   static float16(min)() noexcept { return float16::FromRaw(raw_min); }
263   static float16(max)() noexcept { return float16::FromRaw(raw_max); }
264   static float16 lowest() noexcept { return float16::FromRaw(raw_lowest); }
265   static float16 epsilon() noexcept { return float16::FromRaw(raw_epsilon); }
266   static float16 round_error() noexcept { return float16(round_error_value); }
267   static float16 infinity() noexcept { return float16::FromRaw(float16::inf_value); }
268   static float16 quiet_NaN() noexcept { return float16::FromRaw(float16::nan_value); }
269   static float16 signaling_NaN() noexcept { return float16::FromRaw(float16::nan_value); }
270   static float16 denorm_min() noexcept { return float16::FromRaw(1); }
271 };
272 
273 // If std::numeric_limits<T> is specialized, should also specialize
274 // std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
275 // std::numeric_limits<const volatile T>
276 // https://stackoverflow.com/a/16519653/
277 template <>
278 struct numeric_limits<const mindspore::Float16> : numeric_limits<mindspore::Float16> {};
279 template <>
280 struct numeric_limits<volatile mindspore::Float16> : numeric_limits<mindspore::Float16> {};
281 template <>
282 struct numeric_limits<const volatile mindspore::Float16> : numeric_limits<mindspore::Float16> {};
283 }  // namespace std
284 
285 // Implements standard math functions for float16.
286 inline bool(isinf)(const float16 &a) { return (a.int_value() & float16::value_mask) == float16::inf_value; }
287 inline bool(isnan)(const float16 &a) { return (a.int_value() & float16::value_mask) > float16::inf_value; }
288 inline bool(isfinite)(const float16 &a) { return !(isinf(a)) && !(isnan(a)); }
289 inline float16 abs(const float16 &a) { return float16::FromRaw(a.int_value() & float16::value_mask); }
290 inline float16 exp(const float16 &a) { return float16(::expf(static_cast<float>(a))); }
291 inline float16 log(const float16 &a) { return float16(::logf(static_cast<float>(a))); }
292 inline float16 log1p(const float16 &a) { return float16(::log1pf(static_cast<float>(a))); }
293 inline float16 log10(const float16 &a) { return float16(::log10f(static_cast<float>(a))); }
294 inline float16 sqrt(const float16 &a) { return float16(::sqrtf(static_cast<float>(a))); }
295 inline float16 sin(const float16 &a) { return float16(::sinf(static_cast<float>(a))); }
296 inline float16 cos(const float16 &a) { return float16(::cosf(static_cast<float>(a))); }
297 inline float16 tan(const float16 &a) { return float16(::tanf(static_cast<float>(a))); }
298 inline float16 tanh(const float16 &a) { return float16(::tanhf(static_cast<float>(a))); }
299 inline float16 floor(const float16 &a) { return float16(::floorf(static_cast<float>(a))); }
300 inline float16 ceil(const float16 &a) { return float16(::ceilf(static_cast<float>(a))); }
301 inline float16(min)(const float16 &a, const float16 &b) { return b < a ? b : a; }
302 inline float16(max)(const float16 &a, const float16 &b) { return a < b ? b : a; }
303 inline float16 pow(const float16 &a, const float16 &b) {
304   return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
305 }
306 
307 #endif  // ENABLE_ARM32 || ENABLE_ARM64
308 
309 inline float half_to_float(float16 h) { return static_cast<float>(h); }
310 
311 #endif  // MINDSPORE_CORE_BASE_FLOAT16_H_
312