/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CORE_BASE_FLOAT16_H_
#define MINDSPORE_CORE_BASE_FLOAT16_H_

#if defined(ENABLE_ARM32) || defined(ENABLE_ARM64)
// Built for lite and ARM
#include <arm_neon.h>

using float16 = float16_t;

#else
#include <cmath>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <ostream>
#include <limits>
#include <functional>

// Implement Float16 for mindspore, inspired by Eigen::half.
namespace mindspore {
/// @brief A binary16 value stored as its raw 16-bit pattern.
///
/// Every arithmetic operation converts both operands to float32, computes in
/// float32 and rounds the result back to binary16 (round to nearest even).
class Float16 {
 public:
  /// Mask selecting exponent and mantissa bits (everything except the sign).
  static constexpr uint16_t value_mask = 0x7fff;
  /// Canonical quiet NaN bit pattern.
  static constexpr uint16_t nan_value = 0x7e00;
  /// Positive infinity bit pattern.
  static constexpr uint16_t inf_value = 0x7c00;
  /// Bit pattern of 1.0; used when constructing from `true`.
  static constexpr uint16_t true_value = 0x3c00;

  /// Bit view of a float32. The conversions below do not read through this
  /// union (reading an inactive union member is undefined behaviour in ISO
  /// C++); it is kept so code that names Float16::Union32 still compiles.
  union Union32 {
    uint32_t u;
    float f;
  };

  /// Leaves the stored bits uninitialized, like a built-in float.
  Float16() = default;
  ~Float16() = default;

  Float16(const Float16 &other) noexcept = default;
  Float16(Float16 &&other) noexcept = default;

  Float16 &operator=(const Float16 &other) noexcept = default;
  Float16 &operator=(Float16 &&other) noexcept = default;

  /// @brief Builds a Float16 directly from a raw binary16 bit pattern.
  static Float16 FromRaw(uint16_t v) {
    Float16 f;
    f.value_ = v;
    return f;
  }

  /// @brief Converts from float32, rounding to nearest even.
  explicit Float16(float f) : value_(FromFloat32(f)) {}
  /// @brief `true` becomes 1.0 and `false` becomes +0.0.
  explicit Float16(bool b) : value_(b ? true_value : 0) {}
  /// @brief Converts any value that is explicitly convertible to float.
  template <typename T>
  explicit Float16(const T &v) : value_(FromFloat32(static_cast<float>(v))) {}

  /// @brief Returns the raw binary16 bit pattern.
  uint16_t int_value() const { return value_; }

  /// True for every value except +0 and -0 (NaN converts to true).
  explicit operator bool() const { return (value_ & value_mask) != 0; }
  explicit operator float() const { return ToFloat32(*this); }
  explicit operator double() const { return static_cast<double>(ToFloat32(*this)); }
  // The integer conversions truncate the float32 value toward zero. As with any
  // float-to-integer cast, NaN/Inf or out-of-range values are undefined behaviour.
  explicit operator int8_t() const { return static_cast<int8_t>(ToFloat32(*this)); }
  explicit operator uint8_t() const { return static_cast<uint8_t>(ToFloat32(*this)); }
  explicit operator int16_t() const { return static_cast<int16_t>(ToFloat32(*this)); }
  explicit operator uint16_t() const { return static_cast<uint16_t>(ToFloat32(*this)); }
  explicit operator int32_t() const { return static_cast<int32_t>(ToFloat32(*this)); }
  explicit operator uint32_t() const { return static_cast<uint32_t>(ToFloat32(*this)); }
  explicit operator int64_t() const { return static_cast<int64_t>(ToFloat32(*this)); }
  explicit operator uint64_t() const { return static_cast<uint64_t>(ToFloat32(*this)); }

  Float16 &operator+=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) + ToFloat32(b));
    return *this;
  }

  Float16 &operator-=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) - ToFloat32(b));
    return *this;
  }

  Float16 &operator*=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) * ToFloat32(b));
    return *this;
  }

  Float16 &operator/=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) / ToFloat32(b));
    return *this;
  }

  /// @brief Converts a binary16 value to float32 (exact: every half value is representable).
  static float ToFloat32(Float16 f16) {
    constexpr uint32_t magic = 113u << 23;  // float32 bits of 2^-14, the smallest normal half.
    constexpr uint32_t exponent_adjust = (127u - 15u) << 23;  // Rebias exponent from 15 to 127.
    constexpr uint32_t inf_extra_exp_adjust = (128u - 16u) << 23;  // Moves Inf/NaN exponent to 255.
    constexpr uint32_t zero_extra_exp_adjust = 1u << 23;
    constexpr uint32_t sign_mask = 0x8000;
    constexpr uint32_t shifted_exp = 0x7c00u << 13;  // Exponent mask after shift.
    constexpr unsigned int exponent_bits = 13;
    constexpr unsigned int sign_bit_shift = 16;

    // Exponent/mantissa bits moved into their float32 positions.
    uint32_t bits = static_cast<uint32_t>(f16.value_ & value_mask) << exponent_bits;
    // Just the exponent.
    const uint32_t exp_bits = shifted_exp & bits;
    bits += exponent_adjust;
    // Handle exponent special cases.
    if (exp_bits == shifted_exp) {
      // Inf/NaN, extra exp adjust.
      bits += inf_extra_exp_adjust;
    } else if (exp_bits == 0) {
      // Zero/subnormal: add an implicit leading 1 at 2^-14, then subtract 2^-14
      // in float arithmetic so the hardware renormalizes the result.
      bits += zero_extra_exp_adjust;
      bits = FloatToBits(BitsToFloat(bits) - BitsToFloat(magic));
    }
    // Set sign bit.
    bits |= static_cast<uint32_t>(f16.value_ & sign_mask) << sign_bit_shift;
    return BitsToFloat(bits);
  }

 private:
  static_assert(sizeof(float) == sizeof(uint32_t), "Float16 requires a 32-bit float");

  /// Returns the object representation of a float32 without union type punning.
  static uint32_t FloatToBits(float f) {
    uint32_t bits = 0;
    std::memcpy(&bits, &f, sizeof(bits));
    return bits;
  }

  /// Builds a float32 from its object representation without union type punning.
  static float BitsToFloat(uint32_t bits) {
    float f = 0.0f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
  }

  /// Converts float32 to binary16 bits with round-to-nearest-even.
  /// Overflow becomes +/-Inf and NaN becomes the canonical quiet NaN.
  static uint16_t FromFloat32(float f32) {
    constexpr uint32_t magic = 113u << 23;  // float32 bits of 2^-14.
    constexpr uint32_t f32infty = 255u << 23;
    constexpr uint32_t f16max = (127u + 16u) << 23;  // 2^16: first magnitude that overflows half.
    constexpr uint32_t denorm_magic = ((127u - 15u) + (23u - 10u) + 1u) << 23;
    constexpr unsigned int exponent_bits = 13;
    constexpr unsigned int sign_bit_shift = 16;
    constexpr uint32_t sign_mask = 0x80000000u;
    // Unsigned wrap-around is intended: rebias exponent 127 -> 15, plus the
    // half-ULP-minus-one rounding bias.
    constexpr uint32_t rounding_bias_part1 = (static_cast<uint32_t>(15 - 127) << 23) + 0xfffu;

    uint32_t bits = FloatToBits(f32);
    const uint32_t sign = bits & sign_mask;
    bits ^= sign;
    uint16_t result = 0;

    // NOTE all the integer compares in this function can be safely
    // compiled into signed compares since all operands are below
    // 0x80000000. Important if you want fast straight SSE2 code
    // (since there's no unsigned PCMPGTD).
    if (bits >= f16max) {
      // Result is Inf or NaN (all exponent bits set).
      result = (bits > f32infty) ? nan_value : inf_value;
    } else if (bits < magic) {
      // (De)normalized number or zero; resulting FP16 is subnormal or zero.
      // Use a magic value to align our 10 mantissa bits at the bottom of
      // the float. As long as FP addition is round-to-nearest-even this
      // just works.
      const float aligned = BitsToFloat(bits) + BitsToFloat(denorm_magic);
      // And one integer subtract of the bias later, we have our final float!
      result = static_cast<uint16_t>(FloatToBits(aligned) - denorm_magic);
    } else {
      // Resulting mantissa is odd.
      const uint32_t mant_odd = (bits >> exponent_bits) & 1u;
      // Update exponent, rounding bias part 1.
      bits += rounding_bias_part1;
      // Rounding bias part 2 (ties to even).
      bits += mant_odd;
      // Take the bits!
      result = static_cast<uint16_t>(bits >> exponent_bits);
    }
    // Set sign bit.
    result = static_cast<uint16_t>(result | (sign >> sign_bit_shift));
    return result;
  }

  uint16_t value_;
};

inline Float16 operator+(const Float16 &a, const Float16 &b) {
  return Float16(static_cast<float>(a) + static_cast<float>(b));
}

inline Float16 operator*(const Float16 &a, const Float16 &b) {
  return Float16(static_cast<float>(a) * static_cast<float>(b));
}

inline Float16 operator-(const Float16 &a, const Float16 &b) {
  return Float16(static_cast<float>(a) - static_cast<float>(b));
}

inline Float16 operator/(const Float16 &a, const Float16 &b) {
  return Float16(static_cast<float>(a) / static_cast<float>(b));
}

// Division by a size_t. Do it in full float precision to avoid
// accuracy issues in converting the denominator to float16.
inline Float16 operator/(const Float16 &a, std::size_t b) {
  return Float16(static_cast<float>(a) / static_cast<float>(b));
}

/// Negation flips only the sign bit, so it is exact (including for NaN/Inf/zero).
inline Float16 operator-(const Float16 &a) {
  constexpr uint16_t sign_mask = 0x8000;
  return Float16::FromRaw(static_cast<uint16_t>(a.int_value() ^ sign_mask));
}

// Comparisons are done on the float32 values, so +0 == -0 and NaN compares unequal to everything.
inline bool operator==(const Float16 &a, const Float16 &b) {
  return std::equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
}

inline bool operator!=(const Float16 &a, const Float16 &b) {
  return std::not_equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
}

inline bool operator<(const Float16 &a, const Float16 &b) { return static_cast<float>(a) < static_cast<float>(b); }
inline bool operator<=(const Float16 &a, const Float16 &b) { return static_cast<float>(a) <= static_cast<float>(b); }
inline bool operator>(const Float16 &a, const Float16 &b) { return static_cast<float>(a) > static_cast<float>(b); }
inline bool operator>=(const Float16 &a, const Float16 &b) { return static_cast<float>(a) >= static_cast<float>(b); }

inline std::ostream &operator<<(std::ostream &os, const Float16 &v) { return (os << static_cast<float>(v)); }

}  // namespace mindspore

using float16 = mindspore::Float16;

namespace std {
/// Hashes the bit pattern. operator== treats +0 and -0 as equal, so the sign
/// of zero is cleared first: equal keys must produce equal hashes.
template <>
struct hash<float16> {
  std::size_t operator()(const float16 &f16) const noexcept {
    const uint16_t bits = f16.int_value();
    const uint16_t canonical = ((bits & float16::value_mask) == 0) ? uint16_t{0} : bits;
    return static_cast<std::size_t>(canonical);
  }
};

template <>
struct numeric_limits<float16> {
  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = true;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool has_infinity = true;
  static constexpr bool has_quiet_NaN = true;
  static constexpr bool has_signaling_NaN = true;
  static constexpr std::float_denorm_style has_denorm = std::denorm_present;
  static constexpr bool has_denorm_loss = false;
  static constexpr std::float_round_style round_style = std::round_to_nearest;
  static constexpr bool is_iec559 = false;
  // The set of representable values is finite.
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;
  static constexpr int digits = 11;
  static constexpr int digits10 = 3;
  static constexpr int max_digits10 = 5;
  static constexpr int radix = 2;
  static constexpr int min_exponent = -13;
  static constexpr int min_exponent10 = -4;
  static constexpr int max_exponent = 16;
  static constexpr int max_exponent10 = 4;
  static constexpr bool traps = true;
  static constexpr bool tinyness_before = false;

  static constexpr uint16_t raw_min = 0x400;         // 2^-14, smallest positive normal.
  static constexpr uint16_t raw_max = 0x7bff;        // 65504.
  static constexpr uint16_t raw_lowest = 0xfbff;     // -65504.
  static constexpr uint16_t raw_epsilon = 0x1400;    // 2^-10: distance from 1.0 to the next value.
  static constexpr uint16_t raw_signaling_nan = 0x7d00;  // Quiet bit clear, nonzero payload.
  static constexpr float round_error_value = 0.5;

  static float16(min)() noexcept { return float16::FromRaw(raw_min); }
  static float16(max)() noexcept { return float16::FromRaw(raw_max); }
  static float16 lowest() noexcept { return float16::FromRaw(raw_lowest); }
  static float16 epsilon() noexcept { return float16::FromRaw(raw_epsilon); }
  static float16 round_error() noexcept { return float16(round_error_value); }
  static float16 infinity() noexcept { return float16::FromRaw(float16::inf_value); }
  static float16 quiet_NaN() noexcept { return float16::FromRaw(float16::nan_value); }
  static float16 signaling_NaN() noexcept { return float16::FromRaw(raw_signaling_nan); }
  static float16 denorm_min() noexcept { return float16::FromRaw(1); }
};

// If std::numeric_limits<T> is specialized, should also specialize
// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
// std::numeric_limits<const volatile T>
// https://stackoverflow.com/a/16519653/
template <>
struct numeric_limits<const mindspore::Float16> : numeric_limits<mindspore::Float16> {};
template <>
struct numeric_limits<volatile mindspore::Float16> : numeric_limits<mindspore::Float16> {};
template <>
struct numeric_limits<const volatile mindspore::Float16> : numeric_limits<mindspore::Float16> {};
}  // namespace std

// Implements standard math functions for float16.
// The parenthesized names keep C library function-like macros from expanding.
inline bool(isinf)(const float16 &a) { return (a.int_value() & float16::value_mask) == float16::inf_value; }
inline bool(isnan)(const float16 &a) { return (a.int_value() & float16::value_mask) > float16::inf_value; }
inline bool(isfinite)(const float16 &a) { return !(isinf)(a) && !(isnan)(a); }
inline float16 abs(const float16 &a) { return float16::FromRaw(a.int_value() & float16::value_mask); }
inline float16 exp(const float16 &a) { return float16(::expf(static_cast<float>(a))); }
inline float16 log(const float16 &a) { return float16(::logf(static_cast<float>(a))); }
inline float16 log1p(const float16 &a) { return float16(::log1pf(static_cast<float>(a))); }
inline float16 log10(const float16 &a) { return float16(::log10f(static_cast<float>(a))); }
inline float16 sqrt(const float16 &a) { return float16(::sqrtf(static_cast<float>(a))); }
inline float16 sin(const float16 &a) { return float16(::sinf(static_cast<float>(a))); }
inline float16 cos(const float16 &a) { return float16(::cosf(static_cast<float>(a))); }
inline float16 tan(const float16 &a) { return float16(::tanf(static_cast<float>(a))); }
inline float16 tanh(const float16 &a) { return float16(::tanhf(static_cast<float>(a))); }
inline float16 floor(const float16 &a) { return float16(::floorf(static_cast<float>(a))); }
inline float16 ceil(const float16 &a) { return float16(::ceilf(static_cast<float>(a))); }
inline float16(min)(const float16 &a, const float16 &b) { return b < a ? b : a; }
inline float16(max)(const float16 &a, const float16 &b) { return a < b ? b : a; }
inline float16 pow(const float16 &a, const float16 &b) {
  return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
}

#endif  // ENABLE_ARM32 || ENABLE_ARM64

/// Widens a float16 (either implementation above) to float.
inline float half_to_float(float16 h) { return static_cast<float>(h); }

#endif  // MINDSPORE_CORE_BASE_FLOAT16_H_