/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CORE_BASE_FLOAT16_H_
#define MINDSPORE_CORE_BASE_FLOAT16_H_

#include <type_traits>
#if defined(ENABLE_ARM32) || defined(ENABLE_ARM64)
// Built for lite and ARM
#include <arm_neon.h>

using float16 = float16_t;

#else
#include <cmath>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <ostream>
#include <limits>
#include <functional>

// Implement Float16 for mindspore, inspired by Eigen::half.
namespace mindspore {
/// \brief Software half-precision value stored as a raw 16-bit pattern.
///
/// Arithmetic is carried out by converting the operands to float, computing in
/// float and converting the result back to 16 bits. All conversions to and
/// from other types are explicit.
class Float16 {
 public:
  /// Mask selecting every bit except the sign bit.
  static constexpr uint16_t value_mask = 0x7fff;
  /// Bit pattern returned when a NaN float is converted.
  static constexpr uint16_t nan_value = 0x7e00;
  /// Bit pattern of positive infinity (all exponent bits set, mantissa zero).
  static constexpr uint16_t inf_value = 0x7c00;
  /// Bit pattern used for `true`; it is the encoding of 1.0.
  static constexpr uint16_t true_value = 0x3c00;

  /// Reinterprets the same 32 bits either as an integer or as a float.
  union Union32 {
    uint32_t u;
    float f;
  };

  /// Defaulted on purpose: the stored bits are left uninitialised unless the
  /// object is value-initialised.
  Float16() = default;
  ~Float16() = default;

  Float16(const Float16 &other) noexcept = default;
  Float16(Float16 &&other) noexcept = default;

  Float16 &operator=(const Float16 &other) noexcept = default;
  Float16 &operator=(Float16 &&other) noexcept = default;

  /// \brief Build a value directly from its 16-bit encoding (no conversion).
  static Float16 FromRaw(uint16_t v) {
    Float16 f;
    f.value_ = v;
    return f;
  }

  /// Converts from float, rounding to the nearest representable value.
  explicit Float16(float f) : value_(FromFloat32(f)) {}
  /// `true` becomes 1.0, `false` becomes +0.
  explicit Float16(bool b) : value_(b ? true_value : 0) {}
  /// Any other source type is first cast to float and then converted.
  template <typename T>
  explicit Float16(const T &v) : value_(FromFloat32(static_cast<float>(v))) {}

  /// \brief Returns the raw 16-bit encoding.
  uint16_t int_value() const { return value_; }

  /// True for everything except +0 and -0 (the sign bit is ignored).
  explicit operator bool() const { return (value_ & value_mask) != 0; }
  explicit operator float() const { return ToFloat32(*this); }
  explicit operator double() const { return static_cast<double>(ToFloat32(*this)); }
  // The integer conversions below go through float and then static_cast to the
  // target type.
  explicit operator int8_t() const { return static_cast<int8_t>(ToFloat32(*this)); }
  explicit operator uint8_t() const { return static_cast<uint8_t>(ToFloat32(*this)); }
  explicit operator int16_t() const { return static_cast<int16_t>(ToFloat32(*this)); }
  explicit operator uint16_t() const { return static_cast<uint16_t>(ToFloat32(*this)); }
  explicit operator int32_t() const { return static_cast<int32_t>(ToFloat32(*this)); }
  explicit operator uint32_t() const { return static_cast<uint32_t>(ToFloat32(*this)); }
  explicit operator int64_t() const { return static_cast<int64_t>(ToFloat32(*this)); }
  explicit operator uint64_t() const { return static_cast<uint64_t>(ToFloat32(*this)); }

  Float16 &operator+=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) + ToFloat32(b));
    return *this;
  }

  Float16 &operator-=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) - ToFloat32(b));
    return *this;
  }

  Float16 &operator*=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) * ToFloat32(b));
    return *this;
  }

  Float16 &operator/=(const Float16 &b) {
    value_ = FromFloat32(ToFloat32(*this) / ToFloat32(b));
    return *this;
  }

  /// \brief Widen a 16-bit value to float.
  /// \param f16 value to convert.
  /// \return the float with the same sign and magnitude as `f16`.
  static float ToFloat32(const Float16 &f16) {
    constexpr uint32_t mu_value = 113 << 23;
    Union32 magic;
    magic.u = mu_value;
    constexpr uint32_t exponent_adjust = ((127 - 15) << 23);
    constexpr uint32_t inf_extra_exp_adjust = ((128 - 16) << 23);
    constexpr uint32_t zero_extra_exp_adjust = (1 << 23);
    constexpr uint32_t sign_mask = 0x8000;
    constexpr unsigned int shifted_exp = (0x7c00 << 13);  // Exponent mask after shift.
    constexpr unsigned int exponent_bits = 13;
    constexpr unsigned int sign_bit_shift = 16;
    // Exponent/mantissa bits.
    Union32 f32;
    f32.u = (static_cast<uint32_t>(f16.value_ & value_mask) << exponent_bits);
    // Just the exponent.
    unsigned int exp = (shifted_exp & f32.u);
    f32.u += exponent_adjust;
    // Handle exponent special cases.
    if (exp == shifted_exp) {
      // Inf/NaN, extra exp adjust.
      f32.u += inf_extra_exp_adjust;
    } else if (exp == 0) {
      // Zero/Denormal, extra exp adjust and renormalize.
      f32.u += zero_extra_exp_adjust;
      f32.f -= magic.f;
    }
    // Set sign bit.
    f32.u |= ((f16.value_ & sign_mask) << sign_bit_shift);
    return f32.f;
  }

 private:
  /// \brief Narrow a float to the 16-bit encoding.
  ///
  /// Magnitudes at or above the overflow threshold become infinity, or
  /// `nan_value` when the input is a NaN. The sign bit of the input is copied
  /// to the result in every case.
  static uint16_t FromFloat32(float f32) {
    constexpr uint32_t magic = {113 << 23};
    constexpr uint32_t f32infty_value = 255 << 23;
    Union32 f32infty;
    f32infty.u = f32infty_value;
    constexpr uint32_t f16max_value = (127 + 16) << 23;
    Union32 f16max;
    f16max.u = f16max_value;
    constexpr uint32_t denorm_magic_value = ((127 - 15) + (23 - 10) + 1) << 23;
    Union32 denorm_magic;
    denorm_magic.u = denorm_magic_value;
    constexpr unsigned int exponent_bits = 13;
    constexpr unsigned int sign_bit_shift = 16;
    constexpr unsigned int sign_mask = 0x80000000u;
    constexpr uint32_t rounding_bias_part1 = (static_cast<unsigned int>(15 - 127) << 23) + 0xfff;

    Union32 f;
    f.f = f32;
    unsigned int sign = f.u & sign_mask;
    f.u ^= sign;
    uint16_t result = 0;

    // NOTE all the integer compares in this function can be safely
    // compiled into signed compares since all operands are below
    // 0x80000000. Important if you want fast straight SSE2 code
    // (since there's no unsigned PCMPGTD).
    if (f.u >= f16max.u) {
      // Result is Inf or NaN (all exponent bits set).
      result = (f.u > f32infty.u) ? nan_value : inf_value;
    } else if (f.u < magic) {
      // (De)normalized number or zero; resulting FP16 is subnormal or zero.
      // Use a magic value to align our 10 mantissa bits at the bottom of
      // the float. As long as FP addition is round-to-nearest-even this
      // just works.
      f.f += denorm_magic.f;
      // And one integer subtract of the bias later, we have our final float!
      result = static_cast<uint16_t>(f.u - denorm_magic.u);
    } else {
      // Resulting mantissa is odd.
      unsigned int mant_odd = (f.u >> exponent_bits) & 1;
      // Update exponent, rounding bias part 1;
      f.u += rounding_bias_part1;
      // Rounding bias part 2;
      f.u += mant_odd;
      // Take the bits!
      result = static_cast<uint16_t>(f.u >> exponent_bits);
    }
    // Set sign bit.
    result |= static_cast<uint16_t>(sign >> sign_bit_shift);
    return result;
  }

  uint16_t value_;
};

inline Float16 operator+(const Float16 &a, const Float16 &b) {
  return Float16(static_cast<float>(a) + static_cast<float>(b));
}

inline Float16 operator*(const Float16 &a, const Float16 &b) {
  return Float16(static_cast<float>(a) * static_cast<float>(b));
}

inline Float16 operator-(const Float16 &a, const Float16 &b) {
  return Float16(static_cast<float>(a) - static_cast<float>(b));
}

inline Float16 operator/(const Float16 &a, const Float16 &b) {
  return Float16(static_cast<float>(a) / static_cast<float>(b));
}

// Division by a size_t. Do it in full float precision to avoid
// accuracy issues in converting the denominator to float16.
inline Float16 operator/(const Float16 &a, size_t b) { return Float16(static_cast<float>(a) / static_cast<float>(b)); }

// Negation only flips the sign bit of the encoding.
inline Float16 operator-(const Float16 &a) {
  constexpr uint16_t sign_mask = 0x8000;
  return Float16::FromRaw(static_cast<uint16_t>(a.int_value() ^ sign_mask));
}

// Comparisons are performed on the float values of both operands.
inline bool operator==(const Float16 &a, const Float16 &b) {
  return std::equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
}

inline bool operator!=(const Float16 &a, const Float16 &b) {
  return std::not_equal_to<float>()(static_cast<float>(a), static_cast<float>(b));
}

inline bool operator<(const Float16 &a, const Float16 &b) { return static_cast<float>(a) < static_cast<float>(b); }
inline bool operator<=(const Float16 &a, const Float16 &b) { return static_cast<float>(a) <= static_cast<float>(b); }
inline bool operator>(const Float16 &a, const Float16 &b) { return static_cast<float>(a) > static_cast<float>(b); }
inline bool operator>=(const Float16 &a, const Float16 &b) { return static_cast<float>(a) >= static_cast<float>(b); }

// Streams the value as a float.
inline std::ostream &operator<<(std::ostream &os, const Float16 &v) { return (os << static_cast<float>(v)); }

}  // namespace mindspore

using float16 = mindspore::Float16;

namespace std {
template <>
struct hash<float16> {
  std::size_t operator()(const float16 &f16) const noexcept {
    const uint16_t bits = f16.int_value();
    // +0 (0x0000) and -0 (0x8000) compare equal through operator==, so they
    // must produce the same hash; every other value hashes to its raw bits.
    if ((bits & float16::value_mask) == 0) {
      return 0;
    }
    return static_cast<std::size_t>(bits);
  }
};

// NOTE(review): the C++ standard restricts user specializations of the
// <type_traits> templates; these two are kept because existing callers may
// depend on them - confirm before removing.
template <>
struct is_floating_point<float16> : public std::true_type {};

template <>
struct is_signed<float16> : public std::true_type {};

template <>
struct numeric_limits<float16> {
  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = true;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool has_infinity = true;
  static constexpr bool has_quiet_NaN = true;
  static constexpr bool has_signaling_NaN = true;
  static constexpr std::float_denorm_style has_denorm = std::denorm_present;
  static constexpr bool has_denorm_loss = false;
  static constexpr std::float_round_style round_style = std::round_to_nearest;
  static constexpr bool is_iec559 = false;
  // The type has a finite set of representable values, so it is bounded.
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;
  static constexpr int digits = 11;
  static constexpr int digits10 = 3;
  static constexpr int max_digits10 = 5;
  static constexpr int radix = 2;
  static constexpr int min_exponent = -13;
  static constexpr int min_exponent10 = -4;
  static constexpr int max_exponent = 16;
  static constexpr int max_exponent10 = 4;
  static constexpr bool traps = true;
  static constexpr bool tinyness_before = false;

  static constexpr uint16_t raw_min = 0x400;
  static constexpr uint16_t raw_max = 0x7bff;
  static constexpr uint16_t raw_lowest = 0xfbff;
  // 2^-10: the distance between 1.0 (0x3c00) and the next value (0x3c01).
  static constexpr uint16_t raw_epsilon = 0x1400;
  static constexpr float round_error_value = 0.5;

  static float16(min)() noexcept { return float16::FromRaw(raw_min); }
  static float16(max)() noexcept { return float16::FromRaw(raw_max); }
  static float16 lowest() noexcept { return float16::FromRaw(raw_lowest); }
  static float16 epsilon() noexcept { return float16::FromRaw(raw_epsilon); }
  static float16 round_error() noexcept { return float16(round_error_value); }
  static float16 infinity() noexcept { return float16::FromRaw(float16::inf_value); }
  static float16 quiet_NaN() noexcept { return float16::FromRaw(float16::nan_value); }
  // NOTE(review): returns the same bit pattern as quiet_NaN() - confirm
  // whether a distinct signaling pattern is wanted.
  static float16 signaling_NaN() noexcept { return float16::FromRaw(float16::nan_value); }
  static float16 denorm_min() noexcept { return float16::FromRaw(1); }
};

// If std::numeric_limits<T> is specialized, should also specialize
// std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
// std::numeric_limits<const volatile T>
// https://stackoverflow.com/a/16519653/
// Inheritance is public so that the members stay reachable through the
// cv-qualified specializations.
template <>
struct numeric_limits<const mindspore::Float16> : public numeric_limits<mindspore::Float16> {};
template <>
struct numeric_limits<volatile mindspore::Float16> : public numeric_limits<mindspore::Float16> {};
template <>
struct numeric_limits<const volatile mindspore::Float16> : public numeric_limits<mindspore::Float16> {};
}  // namespace std

// Implements standard math functions for float16.
// The classification helpers and abs work on the raw bits; the remaining
// functions compute in float and convert the result back.
inline bool(isinf)(const float16 &a) { return (a.int_value() & float16::value_mask) == float16::inf_value; }
inline bool(isnan)(const float16 &a) { return (a.int_value() & float16::value_mask) > float16::inf_value; }
inline bool(isfinite)(const float16 &a) { return !(isinf(a)) && !(isnan(a)); }
inline float16 abs(const float16 &a) { return float16::FromRaw(a.int_value() & float16::value_mask); }
inline float16 exp(const float16 &a) { return float16(::expf(static_cast<float>(a))); }
inline float16 log(const float16 &a) { return float16(::logf(static_cast<float>(a))); }
inline float16 log1p(const float16 &a) { return float16(::log1pf(static_cast<float>(a))); }
inline float16 log10(const float16 &a) { return float16(::log10f(static_cast<float>(a))); }
inline float16 sqrt(const float16 &a) { return float16(::sqrtf(static_cast<float>(a))); }
inline float16 sin(const float16 &a) { return float16(::sinf(static_cast<float>(a))); }
inline float16 cos(const float16 &a) { return float16(::cosf(static_cast<float>(a))); }
inline float16 tan(const float16 &a) { return float16(::tanf(static_cast<float>(a))); }
inline float16 tanh(const float16 &a) { return float16(::tanhf(static_cast<float>(a))); }
inline float16 floor(const float16 &a) { return float16(::floorf(static_cast<float>(a))); }
inline float16 ceil(const float16 &a) { return float16(::ceilf(static_cast<float>(a))); }
inline float16(min)(const float16 &a, const float16 &b) { return b < a ? b : a; }
inline float16(max)(const float16 &a, const float16 &b) { return a < b ? b : a; }
inline float16 pow(const float16 &a, const float16 &b) {
  return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
}

#endif  // ENABLE_ARM32 || ENABLE_ARM64

// Available on every build flavour: widens a float16 to float.
inline float half_to_float(const float16 &h) { return static_cast<float>(h); }

#endif  // MINDSPORE_CORE_BASE_FLOAT16_H_