1 // Copyright (c) 2015-2016 The Khronos Group Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef LIBSPIRV_UTIL_HEX_FLOAT_H_ 16 #define LIBSPIRV_UTIL_HEX_FLOAT_H_ 17 18 #include <cassert> 19 #include <cctype> 20 #include <cmath> 21 #include <cstdint> 22 #include <iomanip> 23 #include <limits> 24 #include <sstream> 25 26 #include "bitutils.h" 27 28 namespace spvutils { 29 30 class Float16 { 31 public: Float16(uint16_t v)32 Float16(uint16_t v) : val(v) {} 33 Float16() = default; isNan(const Float16 & val)34 static bool isNan(const Float16& val) { 35 return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) != 0); 36 } 37 // Returns true if the given value is any kind of infinity. isInfinity(const Float16 & val)38 static bool isInfinity(const Float16& val) { 39 return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) == 0); 40 } Float16(const Float16 & other)41 Float16(const Float16& other) { val = other.val; } get_value()42 uint16_t get_value() const { return val; } 43 44 // Returns the maximum normal value. max()45 static Float16 max() { return Float16(0x7bff); } 46 // Returns the lowest normal value. lowest()47 static Float16 lowest() { return Float16(0xfbff); } 48 49 private: 50 uint16_t val; 51 }; 52 53 // To specialize this type, you must override uint_type to define 54 // an unsigned integer that can fit your floating point type. 55 // You must also add a isNan function that returns true if 56 // a value is Nan. 
57 template <typename T> 58 struct FloatProxyTraits { 59 using uint_type = void; 60 }; 61 62 template <> 63 struct FloatProxyTraits<float> { 64 using uint_type = uint32_t; 65 static bool isNan(float f) { return std::isnan(f); } 66 // Returns true if the given value is any kind of infinity. 67 static bool isInfinity(float f) { return std::isinf(f); } 68 // Returns the maximum normal value. 69 static float max() { return std::numeric_limits<float>::max(); } 70 // Returns the lowest normal value. 71 static float lowest() { return std::numeric_limits<float>::lowest(); } 72 }; 73 74 template <> 75 struct FloatProxyTraits<double> { 76 using uint_type = uint64_t; 77 static bool isNan(double f) { return std::isnan(f); } 78 // Returns true if the given value is any kind of infinity. 79 static bool isInfinity(double f) { return std::isinf(f); } 80 // Returns the maximum normal value. 81 static double max() { return std::numeric_limits<double>::max(); } 82 // Returns the lowest normal value. 83 static double lowest() { return std::numeric_limits<double>::lowest(); } 84 }; 85 86 template <> 87 struct FloatProxyTraits<Float16> { 88 using uint_type = uint16_t; 89 static bool isNan(Float16 f) { return Float16::isNan(f); } 90 // Returns true if the given value is any kind of infinity. 91 static bool isInfinity(Float16 f) { return Float16::isInfinity(f); } 92 // Returns the maximum normal value. 93 static Float16 max() { return Float16::max(); } 94 // Returns the lowest normal value. 95 static Float16 lowest() { return Float16::lowest(); } 96 }; 97 98 // Since copying a floating point number (especially if it is NaN) 99 // does not guarantee that bits are preserved, this class lets us 100 // store the type and use it as a float when necessary. 101 template <typename T> 102 class FloatProxy { 103 public: 104 using uint_type = typename FloatProxyTraits<T>::uint_type; 105 106 // Since this is to act similar to the normal floats, 107 // do not initialize the data by default. 
108 FloatProxy() = default; 109 110 // Intentionally non-explicit. This is a proxy type so 111 // implicit conversions allow us to use it more transparently. 112 FloatProxy(T val) { data_ = BitwiseCast<uint_type>(val); } 113 114 // Intentionally non-explicit. This is a proxy type so 115 // implicit conversions allow us to use it more transparently. 116 FloatProxy(uint_type val) { data_ = val; } 117 118 // This is helpful to have and is guaranteed not to stomp bits. 119 FloatProxy<T> operator-() const { 120 return static_cast<uint_type>(data_ ^ 121 (uint_type(0x1) << (sizeof(T) * 8 - 1))); 122 } 123 124 // Returns the data as a floating point value. 125 T getAsFloat() const { return BitwiseCast<T>(data_); } 126 127 // Returns the raw data. 128 uint_type data() const { return data_; } 129 130 // Returns true if the value represents any type of NaN. 131 bool isNan() { return FloatProxyTraits<T>::isNan(getAsFloat()); } 132 // Returns true if the value represents any type of infinity. 133 bool isInfinity() { return FloatProxyTraits<T>::isInfinity(getAsFloat()); } 134 135 // Returns the maximum normal value. 136 static FloatProxy<T> max() { 137 return FloatProxy<T>(FloatProxyTraits<T>::max()); 138 } 139 // Returns the lowest normal value. 140 static FloatProxy<T> lowest() { 141 return FloatProxy<T>(FloatProxyTraits<T>::lowest()); 142 } 143 144 private: 145 uint_type data_; 146 }; 147 148 template <typename T> 149 bool operator==(const FloatProxy<T>& first, const FloatProxy<T>& second) { 150 return first.data() == second.data(); 151 } 152 153 // Reads a FloatProxy value as a normal float from a stream. 154 template <typename T> 155 std::istream& operator>>(std::istream& is, FloatProxy<T>& value) { 156 T float_val; 157 is >> float_val; 158 value = FloatProxy<T>(float_val); 159 return is; 160 } 161 162 // This is an example traits. It is not meant to be used in practice, but will 163 // be the default for any non-specialized type. 
// Primary template: every alias is void and every constant is 0, so a type
// must provide its own specialization to describe a real encoding.
template <typename T>
struct HexFloatTraits {
  // Integer type that can store this hex-float.
  using uint_type = void;
  // Signed integer type that can store this hex-float.
  using int_type = void;
  // The numerical type that this HexFloat represents.
  using underlying_type = void;
  // The type needed to construct the underlying type.
  using native_type = void;
  // The number of bits that are actually relevant in the uint_type.
  // This allows us to deal with, for example, 24-bit values in a 32-bit
  // integer.
  static const uint32_t num_used_bits = 0;
  // Number of bits that represent the exponent.
  static const uint32_t num_exponent_bits = 0;
  // Number of bits that represent the fractional part.
  static const uint32_t num_fraction_bits = 0;
  // The bias of the exponent. (How much we need to subtract from the stored
  // value to get the correct value.)
  static const uint32_t exponent_bias = 0;
};

// Traits for IEEE float.
// 1 sign bit, 8 exponent bits, 23 fractional bits.
template <>
struct HexFloatTraits<FloatProxy<float>> {
  using uint_type = uint32_t;
  using int_type = int32_t;
  using underlying_type = FloatProxy<float>;
  using native_type = float;
  static const uint_type num_used_bits = 32;
  static const uint_type num_exponent_bits = 8;
  static const uint_type num_fraction_bits = 23;
  static const uint_type exponent_bias = 127;
};

// Traits for IEEE double.
// 1 sign bit, 11 exponent bits, 52 fractional bits.
203 template <> 204 struct HexFloatTraits<FloatProxy<double>> { 205 using uint_type = uint64_t; 206 using int_type = int64_t; 207 using underlying_type = FloatProxy<double>; 208 using native_type = double; 209 static const uint_type num_used_bits = 64; 210 static const uint_type num_exponent_bits = 11; 211 static const uint_type num_fraction_bits = 52; 212 static const uint_type exponent_bias = 1023; 213 }; 214 215 // Traits for IEEE half. 216 // 1 sign bit, 5 exponent bits, 10 fractional bits. 217 template <> 218 struct HexFloatTraits<FloatProxy<Float16>> { 219 using uint_type = uint16_t; 220 using int_type = int16_t; 221 using underlying_type = uint16_t; 222 using native_type = uint16_t; 223 static const uint_type num_used_bits = 16; 224 static const uint_type num_exponent_bits = 5; 225 static const uint_type num_fraction_bits = 10; 226 static const uint_type exponent_bias = 15; 227 }; 228 229 enum class round_direction { 230 kToZero, 231 kToNearestEven, 232 kToPositiveInfinity, 233 kToNegativeInfinity, 234 max = kToNegativeInfinity 235 }; 236 237 // Template class that houses a floating pointer number. 238 // It exposes a number of constants based on the provided traits to 239 // assist in interpreting the bits of the value. 240 template <typename T, typename Traits = HexFloatTraits<T>> 241 class HexFloat { 242 public: 243 using uint_type = typename Traits::uint_type; 244 using int_type = typename Traits::int_type; 245 using underlying_type = typename Traits::underlying_type; 246 using native_type = typename Traits::native_type; 247 248 explicit HexFloat(T f) : value_(f) {} 249 250 T value() const { return value_; } 251 void set_value(T f) { value_ = f; } 252 253 // These are all written like this because it is convenient to have 254 // compile-time constants for all of these values. 255 256 // Pass-through values to save typing. 
257 static const uint32_t num_used_bits = Traits::num_used_bits; 258 static const uint32_t exponent_bias = Traits::exponent_bias; 259 static const uint32_t num_exponent_bits = Traits::num_exponent_bits; 260 static const uint32_t num_fraction_bits = Traits::num_fraction_bits; 261 262 // Number of bits to shift left to set the highest relevant bit. 263 static const uint32_t top_bit_left_shift = num_used_bits - 1; 264 // How many nibbles (hex characters) the fractional part takes up. 265 static const uint32_t fraction_nibbles = (num_fraction_bits + 3) / 4; 266 // If the fractional part does not fit evenly into a hex character (4-bits) 267 // then we have to left-shift to get rid of leading 0s. This is the amount 268 // we have to shift (might be 0). 269 static const uint32_t num_overflow_bits = 270 fraction_nibbles * 4 - num_fraction_bits; 271 272 // The representation of the fraction, not the actual bits. This 273 // includes the leading bit that is usually implicit. 274 static const uint_type fraction_represent_mask = 275 spvutils::SetBits<uint_type, 0, 276 num_fraction_bits + num_overflow_bits>::get; 277 278 // The topmost bit in the nibble-aligned fraction. 279 static const uint_type fraction_top_bit = 280 uint_type(1) << (num_fraction_bits + num_overflow_bits - 1); 281 282 // The least significant bit in the exponent, which is also the bit 283 // immediately to the left of the significand. 284 static const uint_type first_exponent_bit = uint_type(1) 285 << (num_fraction_bits); 286 287 // The mask for the encoded fraction. It does not include the 288 // implicit bit. 289 static const uint_type fraction_encode_mask = 290 spvutils::SetBits<uint_type, 0, num_fraction_bits>::get; 291 292 // The bit that is used as a sign. 293 static const uint_type sign_mask = uint_type(1) << top_bit_left_shift; 294 295 // The bits that represent the exponent. 
296 static const uint_type exponent_mask = 297 spvutils::SetBits<uint_type, num_fraction_bits, num_exponent_bits>::get; 298 299 // How far left the exponent is shifted. 300 static const uint32_t exponent_left_shift = num_fraction_bits; 301 302 // How far from the right edge the fraction is shifted. 303 static const uint32_t fraction_right_shift = 304 static_cast<uint32_t>(sizeof(uint_type) * 8) - num_fraction_bits; 305 306 // The maximum representable unbiased exponent. 307 static const int_type max_exponent = 308 (exponent_mask >> num_fraction_bits) - exponent_bias; 309 // The minimum representable exponent for normalized numbers. 310 static const int_type min_exponent = -static_cast<int_type>(exponent_bias); 311 312 // Returns the bits associated with the value. 313 uint_type getBits() const { return spvutils::BitwiseCast<uint_type>(value_); } 314 315 // Returns the bits associated with the value, without the leading sign bit. 316 uint_type getUnsignedBits() const { 317 return static_cast<uint_type>(spvutils::BitwiseCast<uint_type>(value_) & 318 ~sign_mask); 319 } 320 321 // Returns the bits associated with the exponent, shifted to start at the 322 // lsb of the type. 323 const uint_type getExponentBits() const { 324 return static_cast<uint_type>((getBits() & exponent_mask) >> 325 num_fraction_bits); 326 } 327 328 // Returns the exponent in unbiased form. This is the exponent in the 329 // human-friendly form. 330 const int_type getUnbiasedExponent() const { 331 return static_cast<int_type>(getExponentBits() - exponent_bias); 332 } 333 334 // Returns just the significand bits from the value. 335 const uint_type getSignificandBits() const { 336 return getBits() & fraction_encode_mask; 337 } 338 339 // If the number was normalized, returns the unbiased exponent. 340 // If the number was denormal, normalize the exponent first. 
341 const int_type getUnbiasedNormalizedExponent() const { 342 if ((getBits() & ~sign_mask) == 0) { // special case if everything is 0 343 return 0; 344 } 345 int_type exp = getUnbiasedExponent(); 346 if (exp == min_exponent) { // We are in denorm land. 347 uint_type significand_bits = getSignificandBits(); 348 while ((significand_bits & (first_exponent_bit >> 1)) == 0) { 349 significand_bits = static_cast<uint_type>(significand_bits << 1); 350 exp = static_cast<int_type>(exp - 1); 351 } 352 significand_bits &= fraction_encode_mask; 353 } 354 return exp; 355 } 356 357 // Returns the signficand after it has been normalized. 358 const uint_type getNormalizedSignificand() const { 359 int_type unbiased_exponent = getUnbiasedNormalizedExponent(); 360 uint_type significand = getSignificandBits(); 361 for (int_type i = unbiased_exponent; i <= min_exponent; ++i) { 362 significand = static_cast<uint_type>(significand << 1); 363 } 364 significand &= fraction_encode_mask; 365 return significand; 366 } 367 368 // Returns true if this number represents a negative value. 369 bool isNegative() const { return (getBits() & sign_mask) != 0; } 370 371 // Sets this HexFloat from the individual components. 372 // Note this assumes EVERY significand is normalized, and has an implicit 373 // leading one. This means that the only way that this method will set 0, 374 // is if you set a number so denormalized that it underflows. 375 // Do not use this method with raw bits extracted from a subnormal number, 376 // since subnormals do not have an implicit leading 1 in the significand. 377 // The significand is also expected to be in the 378 // lowest-most num_fraction_bits of the uint_type. 379 // The exponent is expected to be unbiased, meaning an exponent of 380 // 0 actually means 0. 381 // If underflow_round_up is set, then on underflow, if a number is non-0 382 // and would underflow, we round up to the smallest denorm. 
383 void setFromSignUnbiasedExponentAndNormalizedSignificand( 384 bool negative, int_type exponent, uint_type significand, 385 bool round_denorm_up) { 386 bool significand_is_zero = significand == 0; 387 388 if (exponent <= min_exponent) { 389 // If this was denormalized, then we have to shift the bit on, meaning 390 // the significand is not zero. 391 significand_is_zero = false; 392 significand |= first_exponent_bit; 393 significand = static_cast<uint_type>(significand >> 1); 394 } 395 396 while (exponent < min_exponent) { 397 significand = static_cast<uint_type>(significand >> 1); 398 ++exponent; 399 } 400 401 if (exponent == min_exponent) { 402 if (significand == 0 && !significand_is_zero && round_denorm_up) { 403 significand = static_cast<uint_type>(0x1); 404 } 405 } 406 407 uint_type new_value = 0; 408 if (negative) { 409 new_value = static_cast<uint_type>(new_value | sign_mask); 410 } 411 exponent = static_cast<int_type>(exponent + exponent_bias); 412 assert(exponent >= 0); 413 414 // put it all together 415 exponent = static_cast<uint_type>((exponent << exponent_left_shift) & 416 exponent_mask); 417 significand = static_cast<uint_type>(significand & fraction_encode_mask); 418 new_value = static_cast<uint_type>(new_value | (exponent | significand)); 419 value_ = BitwiseCast<T>(new_value); 420 } 421 422 // Increments the significand of this number by the given amount. 423 // If this would spill the significand into the implicit bit, 424 // carry is set to true and the significand is shifted to fit into 425 // the correct location, otherwise carry is set to false. 426 // All significands and to_increment are assumed to be within the bounds 427 // for a valid significand. 
428 static uint_type incrementSignificand(uint_type significand, 429 uint_type to_increment, bool* carry) { 430 significand = static_cast<uint_type>(significand + to_increment); 431 *carry = false; 432 if (significand & first_exponent_bit) { 433 *carry = true; 434 // The implicit 1-bit will have carried, so we should zero-out the 435 // top bit and shift back. 436 significand = static_cast<uint_type>(significand & ~first_exponent_bit); 437 significand = static_cast<uint_type>(significand >> 1); 438 } 439 return significand; 440 } 441 442 // These exist because MSVC throws warnings on negative right-shifts 443 // even if they are not going to be executed. Eg: 444 // constant_number < 0? 0: constant_number 445 // These convert the negative left-shifts into right shifts. 446 447 template <int_type N, typename enable = void> 448 struct negatable_left_shift { 449 static uint_type val(uint_type val) { 450 return static_cast<uint_type>(val >> -N); 451 } 452 }; 453 454 template <int_type N> 455 struct negatable_left_shift<N, typename std::enable_if<N >= 0>::type> { 456 static uint_type val(uint_type val) { 457 return static_cast<uint_type>(val << N); 458 } 459 }; 460 461 template <int_type N, typename enable = void> 462 struct negatable_right_shift { 463 static uint_type val(uint_type val) { 464 return static_cast<uint_type>(val << -N); 465 } 466 }; 467 468 template <int_type N> 469 struct negatable_right_shift<N, typename std::enable_if<N >= 0>::type> { 470 static uint_type val(uint_type val) { 471 return static_cast<uint_type>(val >> N); 472 } 473 }; 474 475 // Returns the significand, rounded to fit in a significand in 476 // other_T. This is shifted so that the most significant 477 // bit of the rounded number lines up with the most significant bit 478 // of the returned significand. 
479 template <typename other_T> 480 typename other_T::uint_type getRoundedNormalizedSignificand( 481 round_direction dir, bool* carry_bit) { 482 using other_uint_type = typename other_T::uint_type; 483 static const int_type num_throwaway_bits = 484 static_cast<int_type>(num_fraction_bits) - 485 static_cast<int_type>(other_T::num_fraction_bits); 486 487 static const uint_type last_significant_bit = 488 (num_throwaway_bits < 0) 489 ? 0 490 : negatable_left_shift<num_throwaway_bits>::val(1u); 491 static const uint_type first_rounded_bit = 492 (num_throwaway_bits < 1) 493 ? 0 494 : negatable_left_shift<num_throwaway_bits - 1>::val(1u); 495 496 static const uint_type throwaway_mask_bits = 497 num_throwaway_bits > 0 ? num_throwaway_bits : 0; 498 static const uint_type throwaway_mask = 499 spvutils::SetBits<uint_type, 0, throwaway_mask_bits>::get; 500 501 *carry_bit = false; 502 other_uint_type out_val = 0; 503 uint_type significand = getNormalizedSignificand(); 504 // If we are up-casting, then we just have to shift to the right location. 505 if (num_throwaway_bits <= 0) { 506 out_val = static_cast<other_uint_type>(significand); 507 uint_type shift_amount = static_cast<uint_type>(-num_throwaway_bits); 508 out_val = static_cast<other_uint_type>(out_val << shift_amount); 509 return out_val; 510 } 511 512 // If every non-representable bit is 0, then we don't have any casting to 513 // do. 514 if ((significand & throwaway_mask) == 0) { 515 return static_cast<other_uint_type>( 516 negatable_right_shift<num_throwaway_bits>::val(significand)); 517 } 518 519 bool round_away_from_zero = false; 520 // We actually have to narrow the significand here, so we have to follow the 521 // rounding rules. 
522 switch (dir) { 523 case round_direction::kToZero: 524 break; 525 case round_direction::kToPositiveInfinity: 526 round_away_from_zero = !isNegative(); 527 break; 528 case round_direction::kToNegativeInfinity: 529 round_away_from_zero = isNegative(); 530 break; 531 case round_direction::kToNearestEven: 532 // Have to round down, round bit is 0 533 if ((first_rounded_bit & significand) == 0) { 534 break; 535 } 536 if (((significand & throwaway_mask) & ~first_rounded_bit) != 0) { 537 // If any subsequent bit of the rounded portion is non-0 then we round 538 // up. 539 round_away_from_zero = true; 540 break; 541 } 542 // We are exactly half-way between 2 numbers, pick even. 543 if ((significand & last_significant_bit) != 0) { 544 // 1 for our last bit, round up. 545 round_away_from_zero = true; 546 break; 547 } 548 break; 549 } 550 551 if (round_away_from_zero) { 552 return static_cast<other_uint_type>( 553 negatable_right_shift<num_throwaway_bits>::val(incrementSignificand( 554 significand, last_significant_bit, carry_bit))); 555 } else { 556 return static_cast<other_uint_type>( 557 negatable_right_shift<num_throwaway_bits>::val(significand)); 558 } 559 } 560 561 // Casts this value to another HexFloat. If the cast is widening, 562 // then round_dir is ignored. If the cast is narrowing, then 563 // the result is rounded in the direction specified. 564 // This number will retain Nan and Inf values. 565 // It will also saturate to Inf if the number overflows, and 566 // underflow to (0 or min depending on rounding) if the number underflows. 
567 template <typename other_T> 568 void castTo(other_T& other, round_direction round_dir) { 569 other = other_T(static_cast<typename other_T::native_type>(0)); 570 bool negate = isNegative(); 571 if (getUnsignedBits() == 0) { 572 if (negate) { 573 other.set_value(-other.value()); 574 } 575 return; 576 } 577 uint_type significand = getSignificandBits(); 578 bool carried = false; 579 typename other_T::uint_type rounded_significand = 580 getRoundedNormalizedSignificand<other_T>(round_dir, &carried); 581 582 int_type exponent = getUnbiasedExponent(); 583 if (exponent == min_exponent) { 584 // If we are denormal, normalize the exponent, so that we can encode 585 // easily. 586 exponent = static_cast<int_type>(exponent + 1); 587 for (uint_type check_bit = first_exponent_bit >> 1; check_bit != 0; 588 check_bit = static_cast<uint_type>(check_bit >> 1)) { 589 exponent = static_cast<int_type>(exponent - 1); 590 if (check_bit & significand) break; 591 } 592 } 593 594 bool is_nan = 595 (getBits() & exponent_mask) == exponent_mask && significand != 0; 596 bool is_inf = 597 !is_nan && 598 ((exponent + carried) > static_cast<int_type>(other_T::exponent_bias) || 599 (significand == 0 && (getBits() & exponent_mask) == exponent_mask)); 600 601 // If we are Nan or Inf we should pass that through. 602 if (is_inf) { 603 other.set_value(BitwiseCast<typename other_T::underlying_type>( 604 static_cast<typename other_T::uint_type>( 605 (negate ? other_T::sign_mask : 0) | other_T::exponent_mask))); 606 return; 607 } 608 if (is_nan) { 609 typename other_T::uint_type shifted_significand; 610 shifted_significand = static_cast<typename other_T::uint_type>( 611 negatable_left_shift< 612 static_cast<int_type>(other_T::num_fraction_bits) - 613 static_cast<int_type>(num_fraction_bits)>::val(significand)); 614 615 // We are some sort of Nan. We try to keep the bit-pattern of the Nan 616 // as close as possible. If we had to shift off bits so we are 0, then we 617 // just set the last bit. 
618 other.set_value(BitwiseCast<typename other_T::underlying_type>( 619 static_cast<typename other_T::uint_type>( 620 (negate ? other_T::sign_mask : 0) | other_T::exponent_mask | 621 (shifted_significand == 0 ? 0x1 : shifted_significand)))); 622 return; 623 } 624 625 bool round_underflow_up = 626 isNegative() ? round_dir == round_direction::kToNegativeInfinity 627 : round_dir == round_direction::kToPositiveInfinity; 628 using other_int_type = typename other_T::int_type; 629 // setFromSignUnbiasedExponentAndNormalizedSignificand will 630 // zero out any underflowing value (but retain the sign). 631 other.setFromSignUnbiasedExponentAndNormalizedSignificand( 632 negate, static_cast<other_int_type>(exponent), rounded_significand, 633 round_underflow_up); 634 return; 635 } 636 637 private: 638 T value_; 639 640 static_assert(num_used_bits == 641 Traits::num_exponent_bits + Traits::num_fraction_bits + 1, 642 "The number of bits do not fit"); 643 static_assert(sizeof(T) == sizeof(uint_type), "The type sizes do not match"); 644 }; 645 646 // Returns 4 bits represented by the hex character. 647 inline uint8_t get_nibble_from_character(int character) { 648 const char* dec = "0123456789"; 649 const char* lower = "abcdef"; 650 const char* upper = "ABCDEF"; 651 const char* p = nullptr; 652 if ((p = strchr(dec, character))) { 653 return static_cast<uint8_t>(p - dec); 654 } else if ((p = strchr(lower, character))) { 655 return static_cast<uint8_t>(p - lower + 0xa); 656 } else if ((p = strchr(upper, character))) { 657 return static_cast<uint8_t>(p - upper + 0xa); 658 } 659 660 assert(false && "This was called with a non-hex character"); 661 return 0; 662 } 663 664 // Outputs the given HexFloat to the stream. 
665 template <typename T, typename Traits> 666 std::ostream& operator<<(std::ostream& os, const HexFloat<T, Traits>& value) { 667 using HF = HexFloat<T, Traits>; 668 using uint_type = typename HF::uint_type; 669 using int_type = typename HF::int_type; 670 671 static_assert(HF::num_used_bits != 0, 672 "num_used_bits must be non-zero for a valid float"); 673 static_assert(HF::num_exponent_bits != 0, 674 "num_exponent_bits must be non-zero for a valid float"); 675 static_assert(HF::num_fraction_bits != 0, 676 "num_fractin_bits must be non-zero for a valid float"); 677 678 const uint_type bits = spvutils::BitwiseCast<uint_type>(value.value()); 679 const char* const sign = (bits & HF::sign_mask) ? "-" : ""; 680 const uint_type exponent = static_cast<uint_type>( 681 (bits & HF::exponent_mask) >> HF::num_fraction_bits); 682 683 uint_type fraction = static_cast<uint_type>((bits & HF::fraction_encode_mask) 684 << HF::num_overflow_bits); 685 686 const bool is_zero = exponent == 0 && fraction == 0; 687 const bool is_denorm = exponent == 0 && !is_zero; 688 689 // exponent contains the biased exponent we have to convert it back into 690 // the normal range. 691 int_type int_exponent = static_cast<int_type>(exponent - HF::exponent_bias); 692 // If the number is all zeros, then we actually have to NOT shift the 693 // exponent. 694 int_exponent = is_zero ? 0 : int_exponent; 695 696 // If we are denorm, then start shifting, and decreasing the exponent until 697 // our leading bit is 1. 698 699 if (is_denorm) { 700 while ((fraction & HF::fraction_top_bit) == 0) { 701 fraction = static_cast<uint_type>(fraction << 1); 702 int_exponent = static_cast<int_type>(int_exponent - 1); 703 } 704 // Since this is denormalized, we have to consume the leading 1 since it 705 // will end up being implicit. 
706 fraction = static_cast<uint_type>(fraction << 1); // eat the leading 1 707 fraction &= HF::fraction_represent_mask; 708 } 709 710 uint_type fraction_nibbles = HF::fraction_nibbles; 711 // We do not have to display any trailing 0s, since this represents the 712 // fractional part. 713 while (fraction_nibbles > 0 && (fraction & 0xF) == 0) { 714 // Shift off any trailing values; 715 fraction = static_cast<uint_type>(fraction >> 4); 716 --fraction_nibbles; 717 } 718 719 const auto saved_flags = os.flags(); 720 const auto saved_fill = os.fill(); 721 722 os << sign << "0x" << (is_zero ? '0' : '1'); 723 if (fraction_nibbles) { 724 // Make sure to keep the leading 0s in place, since this is the fractional 725 // part. 726 os << "." << std::setw(static_cast<int>(fraction_nibbles)) 727 << std::setfill('0') << std::hex << fraction; 728 } 729 os << "p" << std::dec << (int_exponent >= 0 ? "+" : "") << int_exponent; 730 731 os.flags(saved_flags); 732 os.fill(saved_fill); 733 734 return os; 735 } 736 737 // Returns true if negate_value is true and the next character on the 738 // input stream is a plus or minus sign. In that case we also set the fail bit 739 // on the stream and set the value to the zero value for its type. 740 template <typename T, typename Traits> 741 inline bool RejectParseDueToLeadingSign(std::istream& is, bool negate_value, 742 HexFloat<T, Traits>& value) { 743 if (negate_value) { 744 auto next_char = is.peek(); 745 if (next_char == '-' || next_char == '+') { 746 // Fail the parse. Emulate standard behaviour by setting the value to 747 // the zero value, and set the fail bit on the stream. 748 value = HexFloat<T, Traits>(typename HexFloat<T, Traits>::uint_type{0}); 749 is.setstate(std::ios_base::failbit); 750 return true; 751 } 752 } 753 return false; 754 } 755 756 // Parses a floating point number from the given stream and stores it into the 757 // value parameter. 
758 // If negate_value is true then the number may not have a leading minus or 759 // plus, and if it successfully parses, then the number is negated before 760 // being stored into the value parameter. 761 // If the value cannot be correctly parsed or overflows the target floating 762 // point type, then set the fail bit on the stream. 763 // TODO(dneto): Promise C++11 standard behavior in how the value is set in 764 // the error case, but only after all target platforms implement it correctly. 765 // In particular, the Microsoft C++ runtime appears to be out of spec. 766 template <typename T, typename Traits> 767 inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value, 768 HexFloat<T, Traits>& value) { 769 if (RejectParseDueToLeadingSign(is, negate_value, value)) { 770 return is; 771 } 772 T val; 773 is >> val; 774 if (negate_value) { 775 val = -val; 776 } 777 value.set_value(val); 778 // In the failure case, map -0.0 to 0.0. 779 if (is.fail() && value.getUnsignedBits() == 0u) { 780 value = HexFloat<T, Traits>(typename HexFloat<T, Traits>::uint_type{0}); 781 } 782 if (val.isInfinity()) { 783 // Fail the parse. Emulate standard behaviour by setting the value to 784 // the closest normal value, and set the fail bit on the stream. 785 value.set_value((value.isNegative() | negate_value) ? T::lowest() 786 : T::max()); 787 is.setstate(std::ios_base::failbit); 788 } 789 return is; 790 } 791 792 // Specialization of ParseNormalFloat for FloatProxy<Float16> values. 793 // This will parse the float as it were a 32-bit floating point number, 794 // and then round it down to fit into a Float16 value. 795 // The number is rounded towards zero. 796 // If negate_value is true then the number may not have a leading minus or 797 // plus, and if it successfully parses, then the number is negated before 798 // being stored into the value parameter. 
799 // If the value cannot be correctly parsed or overflows the target floating 800 // point type, then set the fail bit on the stream. 801 // TODO(dneto): Promise C++11 standard behavior in how the value is set in 802 // the error case, but only after all target platforms implement it correctly. 803 // In particular, the Microsoft C++ runtime appears to be out of spec. 804 template <> 805 inline std::istream& 806 ParseNormalFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>( 807 std::istream& is, bool negate_value, 808 HexFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>& value) { 809 // First parse as a 32-bit float. 810 HexFloat<FloatProxy<float>> float_val(0.0f); 811 ParseNormalFloat(is, negate_value, float_val); 812 813 // Then convert to 16-bit float, saturating at infinities, and 814 // rounding toward zero. 815 float_val.castTo(value, round_direction::kToZero); 816 817 // Overflow on 16-bit behaves the same as for 32- and 64-bit: set the 818 // fail bit and set the lowest or highest value. 819 if (Float16::isInfinity(value.value().getAsFloat())) { 820 value.set_value(value.isNegative() ? Float16::lowest() : Float16::max()); 821 is.setstate(std::ios_base::failbit); 822 } 823 return is; 824 } 825 826 // Reads a HexFloat from the given stream. 827 // If the float is not encoded as a hex-float then it will be parsed 828 // as a regular float. 829 // This may fail if your stream does not support at least one unget. 830 // Nan values can be encoded with "0x1.<not zero>p+exponent_bias". 831 // This would normally overflow a float and round to 832 // infinity but this special pattern is the exact representation for a NaN, 833 // and therefore is actually encoded as the correct NaN. To encode inf, 834 // either 0x0p+exponent_bias can be specified or any exponent greater than 835 // exponent_bias. 836 // Examples using IEEE 32-bit float encoding. 
837 // 0x1.0p+128 (+inf) 838 // -0x1.0p-128 (-inf) 839 // 840 // 0x1.1p+128 (+Nan) 841 // -0x1.1p+128 (-Nan) 842 // 843 // 0x1p+129 (+inf) 844 // -0x1p+129 (-inf) 845 template <typename T, typename Traits> 846 std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) { 847 using HF = HexFloat<T, Traits>; 848 using uint_type = typename HF::uint_type; 849 using int_type = typename HF::int_type; 850 851 value.set_value(static_cast<typename HF::native_type>(0.f)); 852 853 if (is.flags() & std::ios::skipws) { 854 // If the user wants to skip whitespace , then we should obey that. 855 while (std::isspace(is.peek())) { 856 is.get(); 857 } 858 } 859 860 auto next_char = is.peek(); 861 bool negate_value = false; 862 863 if (next_char != '-' && next_char != '0') { 864 return ParseNormalFloat(is, negate_value, value); 865 } 866 867 if (next_char == '-') { 868 negate_value = true; 869 is.get(); 870 next_char = is.peek(); 871 } 872 873 if (next_char == '0') { 874 is.get(); // We may have to unget this. 875 auto maybe_hex_start = is.peek(); 876 if (maybe_hex_start != 'x' && maybe_hex_start != 'X') { 877 is.unget(); 878 return ParseNormalFloat(is, negate_value, value); 879 } else { 880 is.get(); // Throw away the 'x'; 881 } 882 } else { 883 return ParseNormalFloat(is, negate_value, value); 884 } 885 886 // This "looks" like a hex-float so treat it as one. 887 bool seen_p = false; 888 bool seen_dot = false; 889 uint_type fraction_index = 0; 890 891 uint_type fraction = 0; 892 int_type exponent = HF::exponent_bias; 893 894 // Strip off leading zeros so we don't have to special-case them later. 895 while ((next_char = is.peek()) == '0') { 896 is.get(); 897 } 898 899 bool is_denorm = 900 true; // Assume denorm "representation" until we hear otherwise. 901 // NB: This does not mean the value is actually denorm, 902 // it just means that it was written 0. 903 bool bits_written = false; // Stays false until we write a bit. 
904 while (!seen_p && !seen_dot) { 905 // Handle characters that are left of the fractional part. 906 if (next_char == '.') { 907 seen_dot = true; 908 } else if (next_char == 'p') { 909 seen_p = true; 910 } else if (::isxdigit(next_char)) { 911 // We know this is not denormalized since we have stripped all leading 912 // zeroes and we are not a ".". 913 is_denorm = false; 914 int number = get_nibble_from_character(next_char); 915 for (int i = 0; i < 4; ++i, number <<= 1) { 916 uint_type write_bit = (number & 0x8) ? 0x1 : 0x0; 917 if (bits_written) { 918 // If we are here the bits represented belong in the fractional 919 // part of the float, and we have to adjust the exponent accordingly. 920 fraction = static_cast<uint_type>( 921 fraction | 922 static_cast<uint_type>( 923 write_bit << (HF::top_bit_left_shift - fraction_index++))); 924 exponent = static_cast<int_type>(exponent + 1); 925 } 926 bits_written |= write_bit != 0; 927 } 928 } else { 929 // We have not found our exponent yet, so we have to fail. 930 is.setstate(std::ios::failbit); 931 return is; 932 } 933 is.get(); 934 next_char = is.peek(); 935 } 936 bits_written = false; 937 while (seen_dot && !seen_p) { 938 // Handle only fractional parts now. 939 if (next_char == 'p') { 940 seen_p = true; 941 } else if (::isxdigit(next_char)) { 942 int number = get_nibble_from_character(next_char); 943 for (int i = 0; i < 4; ++i, number <<= 1) { 944 uint_type write_bit = (number & 0x8) ? 0x01 : 0x00; 945 bits_written |= write_bit != 0; 946 if (is_denorm && !bits_written) { 947 // Handle modifying the exponent here this way we can handle 948 // an arbitrary number of hex values without overflowing our 949 // integer. 
950 exponent = static_cast<int_type>(exponent - 1); 951 } else { 952 fraction = static_cast<uint_type>( 953 fraction | 954 static_cast<uint_type>( 955 write_bit << (HF::top_bit_left_shift - fraction_index++))); 956 } 957 } 958 } else { 959 // We still have not found our 'p' exponent yet, so this is not a valid 960 // hex-float. 961 is.setstate(std::ios::failbit); 962 return is; 963 } 964 is.get(); 965 next_char = is.peek(); 966 } 967 968 bool seen_sign = false; 969 int8_t exponent_sign = 1; 970 int_type written_exponent = 0; 971 while (true) { 972 if ((next_char == '-' || next_char == '+')) { 973 if (seen_sign) { 974 is.setstate(std::ios::failbit); 975 return is; 976 } 977 seen_sign = true; 978 exponent_sign = (next_char == '-') ? -1 : 1; 979 } else if (::isdigit(next_char)) { 980 // Hex-floats express their exponent as decimal. 981 written_exponent = static_cast<int_type>(written_exponent * 10); 982 written_exponent = 983 static_cast<int_type>(written_exponent + (next_char - '0')); 984 } else { 985 break; 986 } 987 is.get(); 988 next_char = is.peek(); 989 } 990 991 written_exponent = static_cast<int_type>(written_exponent * exponent_sign); 992 exponent = static_cast<int_type>(exponent + written_exponent); 993 994 bool is_zero = is_denorm && (fraction == 0); 995 if (is_denorm && !is_zero) { 996 fraction = static_cast<uint_type>(fraction << 1); 997 exponent = static_cast<int_type>(exponent - 1); 998 } else if (is_zero) { 999 exponent = 0; 1000 } 1001 1002 if (exponent <= 0 && !is_zero) { 1003 fraction = static_cast<uint_type>(fraction >> 1); 1004 fraction |= static_cast<uint_type>(1) << HF::top_bit_left_shift; 1005 } 1006 1007 fraction = (fraction >> HF::fraction_right_shift) & HF::fraction_encode_mask; 1008 1009 const int_type max_exponent = 1010 SetBits<uint_type, 0, HF::num_exponent_bits>::get; 1011 1012 // Handle actual denorm numbers 1013 while (exponent < 0 && !is_zero) { 1014 fraction = static_cast<uint_type>(fraction >> 1); 1015 exponent = 
static_cast<int_type>(exponent + 1); 1016 1017 fraction &= HF::fraction_encode_mask; 1018 if (fraction == 0) { 1019 // We have underflowed our fraction. We should clamp to zero. 1020 is_zero = true; 1021 exponent = 0; 1022 } 1023 } 1024 1025 // We have overflowed so we should be inf/-inf. 1026 if (exponent > max_exponent) { 1027 exponent = max_exponent; 1028 fraction = 0; 1029 } 1030 1031 uint_type output_bits = static_cast<uint_type>( 1032 static_cast<uint_type>(negate_value ? 1 : 0) << HF::top_bit_left_shift); 1033 output_bits |= fraction; 1034 1035 uint_type shifted_exponent = static_cast<uint_type>( 1036 static_cast<uint_type>(exponent << HF::exponent_left_shift) & 1037 HF::exponent_mask); 1038 output_bits |= shifted_exponent; 1039 1040 T output_float = spvutils::BitwiseCast<T>(output_bits); 1041 value.set_value(output_float); 1042 1043 return is; 1044 } 1045 1046 // Writes a FloatProxy value to a stream. 1047 // Zero and normal numbers are printed in the usual notation, but with 1048 // enough digits to fully reproduce the value. Other values (subnormal, 1049 // NaN, and infinity) are printed as a hex float. 1050 template <typename T> 1051 std::ostream& operator<<(std::ostream& os, const FloatProxy<T>& value) { 1052 auto float_val = value.getAsFloat(); 1053 switch (std::fpclassify(float_val)) { 1054 case FP_ZERO: 1055 case FP_NORMAL: { 1056 auto saved_precision = os.precision(); 1057 os.precision(std::numeric_limits<T>::digits10); 1058 os << float_val; 1059 os.precision(saved_precision); 1060 } break; 1061 default: 1062 os << HexFloat<FloatProxy<T>>(value); 1063 break; 1064 } 1065 return os; 1066 } 1067 1068 template <> 1069 inline std::ostream& operator<<<Float16>(std::ostream& os, 1070 const FloatProxy<Float16>& value) { 1071 os << HexFloat<FloatProxy<Float16>>(value); 1072 return os; 1073 } 1074 } 1075 1076 #endif // LIBSPIRV_UTIL_HEX_FLOAT_H_ 1077