• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2015-2016 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SOURCE_UTIL_HEX_FLOAT_H_
16 #define SOURCE_UTIL_HEX_FLOAT_H_
17 
18 #include <cassert>
19 #include <cctype>
20 #include <cmath>
21 #include <cstdint>
22 #include <iomanip>
23 #include <limits>
24 #include <sstream>
25 #include <vector>
26 
27 #include "source/util/bitutils.h"
28 
29 #ifndef __GNUC__
30 #define GCC_VERSION 0
31 #else
32 #define GCC_VERSION \
33   (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__)
34 #endif
35 
36 namespace spvtools {
37 namespace utils {
38 
// A 16-bit floating point value stored purely as its bit pattern.
// The helpers below treat bits 0x7C00 as the exponent field and bits 0x3FF
// as the fraction field.
class Float16 {
 public:
  // Intentionally non-explicit: wraps the raw bit pattern |v|.
  Float16(uint16_t v) : val(v) {}
  // Default construction does not set the stored bits.
  Float16() = default;
  // Copying is a plain copy of the stored bits, so the compiler-generated
  // operations are used. This keeps the type trivially copyable.
  Float16(const Float16& other) = default;
  Float16& operator=(const Float16& other) = default;

  // Returns true if the given value is any kind of NaN: all exponent bits
  // are set and at least one fraction bit is set.
  static bool isNan(const Float16& val) {
    return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) != 0);
  }
  // Returns true if the given value is any kind of infinity: all exponent
  // bits are set and every fraction bit is clear.
  static bool isInfinity(const Float16& val) {
    return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) == 0);
  }

  // Returns the stored bit pattern.
  uint16_t get_value() const { return val; }

  // Returns the maximum normal value.
  static Float16 max() { return Float16(0x7bff); }
  // Returns the lowest normal value.
  static Float16 lowest() { return Float16(0xfbff); }

 private:
  uint16_t val;
};
61 
// Primary template describing how a floating point type T is stored as bits.
// It is only a placeholder. To support a type, specialize this template and
// provide:
//   - uint_type: an unsigned integer type that can hold the bits of T, and
//   - isNan(): a function that returns true when a value is a NaN.
template <typename T>
struct FloatProxyTraits {
  typedef void uint_type;
};
70 
71 template <>
72 struct FloatProxyTraits<float> {
73   using uint_type = uint32_t;
74   static bool isNan(float f) { return std::isnan(f); }
75   // Returns true if the given value is any kind of infinity.
76   static bool isInfinity(float f) { return std::isinf(f); }
77   // Returns the maximum normal value.
78   static float max() { return std::numeric_limits<float>::max(); }
79   // Returns the lowest normal value.
80   static float lowest() { return std::numeric_limits<float>::lowest(); }
81   // Returns the value as the native floating point format.
82   static float getAsFloat(const uint_type& t) { return BitwiseCast<float>(t); }
83   // Returns the bits from the given floating pointer number.
84   static uint_type getBitsFromFloat(const float& t) {
85     return BitwiseCast<uint_type>(t);
86   }
87   // Returns the bitwidth.
88   static uint32_t width() { return 32u; }
89 };
90 
91 template <>
92 struct FloatProxyTraits<double> {
93   using uint_type = uint64_t;
94   static bool isNan(double f) { return std::isnan(f); }
95   // Returns true if the given value is any kind of infinity.
96   static bool isInfinity(double f) { return std::isinf(f); }
97   // Returns the maximum normal value.
98   static double max() { return std::numeric_limits<double>::max(); }
99   // Returns the lowest normal value.
100   static double lowest() { return std::numeric_limits<double>::lowest(); }
101   // Returns the value as the native floating point format.
102   static double getAsFloat(const uint_type& t) {
103     return BitwiseCast<double>(t);
104   }
105   // Returns the bits from the given floating pointer number.
106   static uint_type getBitsFromFloat(const double& t) {
107     return BitwiseCast<uint_type>(t);
108   }
109   // Returns the bitwidth.
110   static uint32_t width() { return 64u; }
111 };
112 
113 template <>
114 struct FloatProxyTraits<Float16> {
115   using uint_type = uint16_t;
116   static bool isNan(Float16 f) { return Float16::isNan(f); }
117   // Returns true if the given value is any kind of infinity.
118   static bool isInfinity(Float16 f) { return Float16::isInfinity(f); }
119   // Returns the maximum normal value.
120   static Float16 max() { return Float16::max(); }
121   // Returns the lowest normal value.
122   static Float16 lowest() { return Float16::lowest(); }
123   // Returns the value as the native floating point format.
124   static Float16 getAsFloat(const uint_type& t) { return Float16(t); }
125   // Returns the bits from the given floating pointer number.
126   static uint_type getBitsFromFloat(const Float16& t) { return t.get_value(); }
127   // Returns the bitwidth.
128   static uint32_t width() { return 16u; }
129 };
130 
131 // Since copying a floating point number (especially if it is NaN)
132 // does not guarantee that bits are preserved, this class lets us
133 // store the type and use it as a float when necessary.
134 template <typename T>
135 class FloatProxy {
136  public:
137   using uint_type = typename FloatProxyTraits<T>::uint_type;
138 
139   // Since this is to act similar to the normal floats,
140   // do not initialize the data by default.
141   FloatProxy() = default;
142 
143   // Intentionally non-explicit. This is a proxy type so
144   // implicit conversions allow us to use it more transparently.
145   FloatProxy(T val) { data_ = FloatProxyTraits<T>::getBitsFromFloat(val); }
146 
147   // Intentionally non-explicit. This is a proxy type so
148   // implicit conversions allow us to use it more transparently.
149   FloatProxy(uint_type val) { data_ = val; }
150 
151   // This is helpful to have and is guaranteed not to stomp bits.
152   FloatProxy<T> operator-() const {
153     return static_cast<uint_type>(data_ ^
154                                   (uint_type(0x1) << (sizeof(T) * 8 - 1)));
155   }
156 
157   // Returns the data as a floating point value.
158   T getAsFloat() const { return FloatProxyTraits<T>::getAsFloat(data_); }
159 
160   // Returns the raw data.
161   uint_type data() const { return data_; }
162 
163   // Returns a vector of words suitable for use in an Operand.
164   std::vector<uint32_t> GetWords() const {
165     std::vector<uint32_t> words;
166     if (FloatProxyTraits<T>::width() == 64) {
167       FloatProxyTraits<double>::uint_type d = data();
168       words.push_back(static_cast<uint32_t>(d));
169       words.push_back(static_cast<uint32_t>(d >> 32));
170     } else {
171       words.push_back(static_cast<uint32_t>(data()));
172     }
173     return words;
174   }
175 
176   // Returns true if the value represents any type of NaN.
177   bool isNan() { return FloatProxyTraits<T>::isNan(getAsFloat()); }
178   // Returns true if the value represents any type of infinity.
179   bool isInfinity() { return FloatProxyTraits<T>::isInfinity(getAsFloat()); }
180 
181   // Returns the maximum normal value.
182   static FloatProxy<T> max() {
183     return FloatProxy<T>(FloatProxyTraits<T>::max());
184   }
185   // Returns the lowest normal value.
186   static FloatProxy<T> lowest() {
187     return FloatProxy<T>(FloatProxyTraits<T>::lowest());
188   }
189 
190  private:
191   uint_type data_;
192 };
193 
194 template <typename T>
195 bool operator==(const FloatProxy<T>& first, const FloatProxy<T>& second) {
196   return first.data() == second.data();
197 }
198 
199 // Reads a FloatProxy value as a normal float from a stream.
200 template <typename T>
201 std::istream& operator>>(std::istream& is, FloatProxy<T>& value) {
202   T float_val;
203   is >> float_val;
204   value = FloatProxy<T>(float_val);
205   return is;
206 }
207 
// Example traits. This primary template is not meant to be used in practice;
// it is what any type without a specialization gets. A usable specialization
// replaces the void types and the zero constants below with real values.
template <typename T>
struct HexFloatTraits {
  // Unsigned integer type able to store this hex-float.
  typedef void uint_type;
  // Signed integer type able to store this hex-float.
  typedef void int_type;
  // The numerical type that this HexFloat represents.
  typedef void underlying_type;
  // The type needed to construct the underlying type.
  typedef void native_type;
  // How many bits of uint_type are actually meaningful. This makes it
  // possible to describe, for example, a 24-bit value held in a 32-bit
  // integer.
  static const uint32_t num_used_bits = 0;
  // How many bits encode the exponent.
  static const uint32_t num_exponent_bits = 0;
  // How many bits encode the fractional part.
  static const uint32_t num_fraction_bits = 0;
  // The exponent bias: the amount to subtract from the stored exponent to
  // obtain the real one.
  static const uint32_t exponent_bias = 0;
};
232 
233 // Traits for IEEE float.
234 // 1 sign bit, 8 exponent bits, 23 fractional bits.
235 template <>
236 struct HexFloatTraits<FloatProxy<float>> {
237   using uint_type = uint32_t;
238   using int_type = int32_t;
239   using underlying_type = FloatProxy<float>;
240   using native_type = float;
241   static const uint_type num_used_bits = 32;
242   static const uint_type num_exponent_bits = 8;
243   static const uint_type num_fraction_bits = 23;
244   static const uint_type exponent_bias = 127;
245 };
246 
247 // Traits for IEEE double.
248 // 1 sign bit, 11 exponent bits, 52 fractional bits.
249 template <>
250 struct HexFloatTraits<FloatProxy<double>> {
251   using uint_type = uint64_t;
252   using int_type = int64_t;
253   using underlying_type = FloatProxy<double>;
254   using native_type = double;
255   static const uint_type num_used_bits = 64;
256   static const uint_type num_exponent_bits = 11;
257   static const uint_type num_fraction_bits = 52;
258   static const uint_type exponent_bias = 1023;
259 };
260 
261 // Traits for IEEE half.
262 // 1 sign bit, 5 exponent bits, 10 fractional bits.
263 template <>
264 struct HexFloatTraits<FloatProxy<Float16>> {
265   using uint_type = uint16_t;
266   using int_type = int16_t;
267   using underlying_type = uint16_t;
268   using native_type = uint16_t;
269   static const uint_type num_used_bits = 16;
270   static const uint_type num_exponent_bits = 5;
271   static const uint_type num_fraction_bits = 10;
272   static const uint_type exponent_bias = 15;
273 };
274 
// Rounding modes that can be requested when a conversion has to discard
// significand bits.
enum class round_direction {
  kToZero = 0,
  kToNearestEven = 1,
  kToPositiveInfinity = 2,
  kToNegativeInfinity = 3,
  // Alias for the last enumerator.
  max = kToNegativeInfinity
};
282 
283 // Template class that houses a floating pointer number.
284 // It exposes a number of constants based on the provided traits to
285 // assist in interpreting the bits of the value.
286 template <typename T, typename Traits = HexFloatTraits<T>>
287 class HexFloat {
288  public:
289   using uint_type = typename Traits::uint_type;
290   using int_type = typename Traits::int_type;
291   using underlying_type = typename Traits::underlying_type;
292   using native_type = typename Traits::native_type;
293 
294   explicit HexFloat(T f) : value_(f) {}
295 
296   T value() const { return value_; }
297   void set_value(T f) { value_ = f; }
298 
299   // These are all written like this because it is convenient to have
300   // compile-time constants for all of these values.
301 
302   // Pass-through values to save typing.
303   static const uint32_t num_used_bits = Traits::num_used_bits;
304   static const uint32_t exponent_bias = Traits::exponent_bias;
305   static const uint32_t num_exponent_bits = Traits::num_exponent_bits;
306   static const uint32_t num_fraction_bits = Traits::num_fraction_bits;
307 
308   // Number of bits to shift left to set the highest relevant bit.
309   static const uint32_t top_bit_left_shift = num_used_bits - 1;
310   // How many nibbles (hex characters) the fractional part takes up.
311   static const uint32_t fraction_nibbles = (num_fraction_bits + 3) / 4;
312   // If the fractional part does not fit evenly into a hex character (4-bits)
313   // then we have to left-shift to get rid of leading 0s. This is the amount
314   // we have to shift (might be 0).
315   static const uint32_t num_overflow_bits =
316       fraction_nibbles * 4 - num_fraction_bits;
317 
318   // The representation of the fraction, not the actual bits. This
319   // includes the leading bit that is usually implicit.
320   static const uint_type fraction_represent_mask =
321       SetBits<uint_type, 0, num_fraction_bits + num_overflow_bits>::get;
322 
323   // The topmost bit in the nibble-aligned fraction.
324   static const uint_type fraction_top_bit =
325       uint_type(1) << (num_fraction_bits + num_overflow_bits - 1);
326 
327   // The least significant bit in the exponent, which is also the bit
328   // immediately to the left of the significand.
329   static const uint_type first_exponent_bit = uint_type(1)
330                                               << (num_fraction_bits);
331 
332   // The mask for the encoded fraction. It does not include the
333   // implicit bit.
334   static const uint_type fraction_encode_mask =
335       SetBits<uint_type, 0, num_fraction_bits>::get;
336 
337   // The bit that is used as a sign.
338   static const uint_type sign_mask = uint_type(1) << top_bit_left_shift;
339 
340   // The bits that represent the exponent.
341   static const uint_type exponent_mask =
342       SetBits<uint_type, num_fraction_bits, num_exponent_bits>::get;
343 
344   // How far left the exponent is shifted.
345   static const uint32_t exponent_left_shift = num_fraction_bits;
346 
347   // How far from the right edge the fraction is shifted.
348   static const uint32_t fraction_right_shift =
349       static_cast<uint32_t>(sizeof(uint_type) * 8) - num_fraction_bits;
350 
351   // The maximum representable unbiased exponent.
352   static const int_type max_exponent =
353       (exponent_mask >> num_fraction_bits) - exponent_bias;
354   // The minimum representable exponent for normalized numbers.
355   static const int_type min_exponent = -static_cast<int_type>(exponent_bias);
356 
357   // Returns the bits associated with the value.
358   uint_type getBits() const { return value_.data(); }
359 
360   // Returns the bits associated with the value, without the leading sign bit.
361   uint_type getUnsignedBits() const {
362     return static_cast<uint_type>(value_.data() & ~sign_mask);
363   }
364 
365   // Returns the bits associated with the exponent, shifted to start at the
366   // lsb of the type.
367   const uint_type getExponentBits() const {
368     return static_cast<uint_type>((getBits() & exponent_mask) >>
369                                   num_fraction_bits);
370   }
371 
372   // Returns the exponent in unbiased form. This is the exponent in the
373   // human-friendly form.
374   const int_type getUnbiasedExponent() const {
375     return static_cast<int_type>(getExponentBits() - exponent_bias);
376   }
377 
378   // Returns just the significand bits from the value.
379   const uint_type getSignificandBits() const {
380     return getBits() & fraction_encode_mask;
381   }
382 
383   // If the number was normalized, returns the unbiased exponent.
384   // If the number was denormal, normalize the exponent first.
385   const int_type getUnbiasedNormalizedExponent() const {
386     if ((getBits() & ~sign_mask) == 0) {  // special case if everything is 0
387       return 0;
388     }
389     int_type exp = getUnbiasedExponent();
390     if (exp == min_exponent) {  // We are in denorm land.
391       uint_type significand_bits = getSignificandBits();
392       while ((significand_bits & (first_exponent_bit >> 1)) == 0) {
393         significand_bits = static_cast<uint_type>(significand_bits << 1);
394         exp = static_cast<int_type>(exp - 1);
395       }
396       significand_bits &= fraction_encode_mask;
397     }
398     return exp;
399   }
400 
401   // Returns the signficand after it has been normalized.
402   const uint_type getNormalizedSignificand() const {
403     int_type unbiased_exponent = getUnbiasedNormalizedExponent();
404     uint_type significand = getSignificandBits();
405     for (int_type i = unbiased_exponent; i <= min_exponent; ++i) {
406       significand = static_cast<uint_type>(significand << 1);
407     }
408     significand &= fraction_encode_mask;
409     return significand;
410   }
411 
412   // Returns true if this number represents a negative value.
413   bool isNegative() const { return (getBits() & sign_mask) != 0; }
414 
415   // Sets this HexFloat from the individual components.
416   // Note this assumes EVERY significand is normalized, and has an implicit
417   // leading one. This means that the only way that this method will set 0,
418   // is if you set a number so denormalized that it underflows.
419   // Do not use this method with raw bits extracted from a subnormal number,
420   // since subnormals do not have an implicit leading 1 in the significand.
421   // The significand is also expected to be in the
422   // lowest-most num_fraction_bits of the uint_type.
423   // The exponent is expected to be unbiased, meaning an exponent of
424   // 0 actually means 0.
425   // If underflow_round_up is set, then on underflow, if a number is non-0
426   // and would underflow, we round up to the smallest denorm.
427   void setFromSignUnbiasedExponentAndNormalizedSignificand(
428       bool negative, int_type exponent, uint_type significand,
429       bool round_denorm_up) {
430     bool significand_is_zero = significand == 0;
431 
432     if (exponent <= min_exponent) {
433       // If this was denormalized, then we have to shift the bit on, meaning
434       // the significand is not zero.
435       significand_is_zero = false;
436       significand |= first_exponent_bit;
437       significand = static_cast<uint_type>(significand >> 1);
438     }
439 
440     while (exponent < min_exponent) {
441       significand = static_cast<uint_type>(significand >> 1);
442       ++exponent;
443     }
444 
445     if (exponent == min_exponent) {
446       if (significand == 0 && !significand_is_zero && round_denorm_up) {
447         significand = static_cast<uint_type>(0x1);
448       }
449     }
450 
451     uint_type new_value = 0;
452     if (negative) {
453       new_value = static_cast<uint_type>(new_value | sign_mask);
454     }
455     exponent = static_cast<int_type>(exponent + exponent_bias);
456     assert(exponent >= 0);
457 
458     // put it all together
459     exponent = static_cast<uint_type>((exponent << exponent_left_shift) &
460                                       exponent_mask);
461     significand = static_cast<uint_type>(significand & fraction_encode_mask);
462     new_value = static_cast<uint_type>(new_value | (exponent | significand));
463     value_ = T(new_value);
464   }
465 
466   // Increments the significand of this number by the given amount.
467   // If this would spill the significand into the implicit bit,
468   // carry is set to true and the significand is shifted to fit into
469   // the correct location, otherwise carry is set to false.
470   // All significands and to_increment are assumed to be within the bounds
471   // for a valid significand.
472   static uint_type incrementSignificand(uint_type significand,
473                                         uint_type to_increment, bool* carry) {
474     significand = static_cast<uint_type>(significand + to_increment);
475     *carry = false;
476     if (significand & first_exponent_bit) {
477       *carry = true;
478       // The implicit 1-bit will have carried, so we should zero-out the
479       // top bit and shift back.
480       significand = static_cast<uint_type>(significand & ~first_exponent_bit);
481       significand = static_cast<uint_type>(significand >> 1);
482     }
483     return significand;
484   }
485 
486 #if GCC_VERSION == 40801
487   // These exist because MSVC throws warnings on negative right-shifts
488   // even if they are not going to be executed. Eg:
489   // constant_number < 0? 0: constant_number
490   // These convert the negative left-shifts into right shifts.
491   template <int_type N>
492   struct negatable_left_shift {
493     static uint_type val(uint_type val) {
494       if (N > 0) {
495         return static_cast<uint_type>(val << N);
496       } else {
497         return static_cast<uint_type>(val >> N);
498       }
499     }
500   };
501 
502   template <int_type N>
503   struct negatable_right_shift {
504     static uint_type val(uint_type val) {
505       if (N > 0) {
506         return static_cast<uint_type>(val >> N);
507       } else {
508         return static_cast<uint_type>(val << N);
509       }
510     }
511   };
512 
513 #else
514   // These exist because MSVC throws warnings on negative right-shifts
515   // even if they are not going to be executed. Eg:
516   // constant_number < 0? 0: constant_number
517   // These convert the negative left-shifts into right shifts.
518   template <int_type N, typename enable = void>
519   struct negatable_left_shift {
520     static uint_type val(uint_type val) {
521       return static_cast<uint_type>(val >> -N);
522     }
523   };
524 
525   template <int_type N>
526   struct negatable_left_shift<N, typename std::enable_if<N >= 0>::type> {
527     static uint_type val(uint_type val) {
528       return static_cast<uint_type>(val << N);
529     }
530   };
531 
532   template <int_type N, typename enable = void>
533   struct negatable_right_shift {
534     static uint_type val(uint_type val) {
535       return static_cast<uint_type>(val << -N);
536     }
537   };
538 
539   template <int_type N>
540   struct negatable_right_shift<N, typename std::enable_if<N >= 0>::type> {
541     static uint_type val(uint_type val) {
542       return static_cast<uint_type>(val >> N);
543     }
544   };
545 #endif
546 
547   // Returns the significand, rounded to fit in a significand in
548   // other_T. This is shifted so that the most significant
549   // bit of the rounded number lines up with the most significant bit
550   // of the returned significand.
551   template <typename other_T>
552   typename other_T::uint_type getRoundedNormalizedSignificand(
553       round_direction dir, bool* carry_bit) {
554     using other_uint_type = typename other_T::uint_type;
555     static const int_type num_throwaway_bits =
556         static_cast<int_type>(num_fraction_bits) -
557         static_cast<int_type>(other_T::num_fraction_bits);
558 
559     static const uint_type last_significant_bit =
560         (num_throwaway_bits < 0)
561             ? 0
562             : negatable_left_shift<num_throwaway_bits>::val(1u);
563     static const uint_type first_rounded_bit =
564         (num_throwaway_bits < 1)
565             ? 0
566             : negatable_left_shift<num_throwaway_bits - 1>::val(1u);
567 
568     static const uint_type throwaway_mask_bits =
569         num_throwaway_bits > 0 ? num_throwaway_bits : 0;
570     static const uint_type throwaway_mask =
571         SetBits<uint_type, 0, throwaway_mask_bits>::get;
572 
573     *carry_bit = false;
574     other_uint_type out_val = 0;
575     uint_type significand = getNormalizedSignificand();
576     // If we are up-casting, then we just have to shift to the right location.
577     if (num_throwaway_bits <= 0) {
578       out_val = static_cast<other_uint_type>(significand);
579       uint_type shift_amount = static_cast<uint_type>(-num_throwaway_bits);
580       out_val = static_cast<other_uint_type>(out_val << shift_amount);
581       return out_val;
582     }
583 
584     // If every non-representable bit is 0, then we don't have any casting to
585     // do.
586     if ((significand & throwaway_mask) == 0) {
587       return static_cast<other_uint_type>(
588           negatable_right_shift<num_throwaway_bits>::val(significand));
589     }
590 
591     bool round_away_from_zero = false;
592     // We actually have to narrow the significand here, so we have to follow the
593     // rounding rules.
594     switch (dir) {
595       case round_direction::kToZero:
596         break;
597       case round_direction::kToPositiveInfinity:
598         round_away_from_zero = !isNegative();
599         break;
600       case round_direction::kToNegativeInfinity:
601         round_away_from_zero = isNegative();
602         break;
603       case round_direction::kToNearestEven:
604         // Have to round down, round bit is 0
605         if ((first_rounded_bit & significand) == 0) {
606           break;
607         }
608         if (((significand & throwaway_mask) & ~first_rounded_bit) != 0) {
609           // If any subsequent bit of the rounded portion is non-0 then we round
610           // up.
611           round_away_from_zero = true;
612           break;
613         }
614         // We are exactly half-way between 2 numbers, pick even.
615         if ((significand & last_significant_bit) != 0) {
616           // 1 for our last bit, round up.
617           round_away_from_zero = true;
618           break;
619         }
620         break;
621     }
622 
623     if (round_away_from_zero) {
624       return static_cast<other_uint_type>(
625           negatable_right_shift<num_throwaway_bits>::val(incrementSignificand(
626               significand, last_significant_bit, carry_bit)));
627     } else {
628       return static_cast<other_uint_type>(
629           negatable_right_shift<num_throwaway_bits>::val(significand));
630     }
631   }
632 
633   // Casts this value to another HexFloat. If the cast is widening,
634   // then round_dir is ignored. If the cast is narrowing, then
635   // the result is rounded in the direction specified.
636   // This number will retain Nan and Inf values.
637   // It will also saturate to Inf if the number overflows, and
638   // underflow to (0 or min depending on rounding) if the number underflows.
639   template <typename other_T>
640   void castTo(other_T& other, round_direction round_dir) {
641     other = other_T(static_cast<typename other_T::native_type>(0));
642     bool negate = isNegative();
643     if (getUnsignedBits() == 0) {
644       if (negate) {
645         other.set_value(-other.value());
646       }
647       return;
648     }
649     uint_type significand = getSignificandBits();
650     bool carried = false;
651     typename other_T::uint_type rounded_significand =
652         getRoundedNormalizedSignificand<other_T>(round_dir, &carried);
653 
654     int_type exponent = getUnbiasedExponent();
655     if (exponent == min_exponent) {
656       // If we are denormal, normalize the exponent, so that we can encode
657       // easily.
658       exponent = static_cast<int_type>(exponent + 1);
659       for (uint_type check_bit = first_exponent_bit >> 1; check_bit != 0;
660            check_bit = static_cast<uint_type>(check_bit >> 1)) {
661         exponent = static_cast<int_type>(exponent - 1);
662         if (check_bit & significand) break;
663       }
664     }
665 
666     bool is_nan =
667         (getBits() & exponent_mask) == exponent_mask && significand != 0;
668     bool is_inf =
669         !is_nan &&
670         ((exponent + carried) > static_cast<int_type>(other_T::exponent_bias) ||
671          (significand == 0 && (getBits() & exponent_mask) == exponent_mask));
672 
673     // If we are Nan or Inf we should pass that through.
674     if (is_inf) {
675       other.set_value(typename other_T::underlying_type(
676           static_cast<typename other_T::uint_type>(
677               (negate ? other_T::sign_mask : 0) | other_T::exponent_mask)));
678       return;
679     }
680     if (is_nan) {
681       typename other_T::uint_type shifted_significand;
682       shifted_significand = static_cast<typename other_T::uint_type>(
683           negatable_left_shift<
684               static_cast<int_type>(other_T::num_fraction_bits) -
685               static_cast<int_type>(num_fraction_bits)>::val(significand));
686 
687       // We are some sort of Nan. We try to keep the bit-pattern of the Nan
688       // as close as possible. If we had to shift off bits so we are 0, then we
689       // just set the last bit.
690       other.set_value(typename other_T::underlying_type(
691           static_cast<typename other_T::uint_type>(
692               (negate ? other_T::sign_mask : 0) | other_T::exponent_mask |
693               (shifted_significand == 0 ? 0x1 : shifted_significand))));
694       return;
695     }
696 
697     bool round_underflow_up =
698         isNegative() ? round_dir == round_direction::kToNegativeInfinity
699                      : round_dir == round_direction::kToPositiveInfinity;
700     using other_int_type = typename other_T::int_type;
701     // setFromSignUnbiasedExponentAndNormalizedSignificand will
702     // zero out any underflowing value (but retain the sign).
703     other.setFromSignUnbiasedExponentAndNormalizedSignificand(
704         negate, static_cast<other_int_type>(exponent), rounded_significand,
705         round_underflow_up);
706     return;
707   }
708 
709  private:
710   T value_;
711 
712   static_assert(num_used_bits ==
713                     Traits::num_exponent_bits + Traits::num_fraction_bits + 1,
714                 "The number of bits do not fit");
715   static_assert(sizeof(T) == sizeof(uint_type), "The type sizes do not match");
716 };
717 
// Returns 4 bits represented by the hex character.
// |character| must be one of '0'-'9', 'a'-'f' or 'A'-'F'. Any other value,
// including '\0', trips the assertion; when assertions are compiled out the
// function returns 0 for such input.
inline uint8_t get_nibble_from_character(int character) {
  if (character >= '0' && character <= '9') {
    return static_cast<uint8_t>(character - '0');
  }
  if (character >= 'a' && character <= 'f') {
    return static_cast<uint8_t>(character - 'a' + 0xa);
  }
  if (character >= 'A' && character <= 'F') {
    return static_cast<uint8_t>(character - 'A' + 0xa);
  }

  assert(false && "This was called with a non-hex character");
  return 0;
}
735 
736 // Outputs the given HexFloat to the stream.
737 template <typename T, typename Traits>
738 std::ostream& operator<<(std::ostream& os, const HexFloat<T, Traits>& value) {
739   using HF = HexFloat<T, Traits>;
740   using uint_type = typename HF::uint_type;
741   using int_type = typename HF::int_type;
742 
743   static_assert(HF::num_used_bits != 0,
744                 "num_used_bits must be non-zero for a valid float");
745   static_assert(HF::num_exponent_bits != 0,
746                 "num_exponent_bits must be non-zero for a valid float");
747   static_assert(HF::num_fraction_bits != 0,
748                 "num_fractin_bits must be non-zero for a valid float");
749 
750   const uint_type bits = value.value().data();
751   const char* const sign = (bits & HF::sign_mask) ? "-" : "";
752   const uint_type exponent = static_cast<uint_type>(
753       (bits & HF::exponent_mask) >> HF::num_fraction_bits);
754 
755   uint_type fraction = static_cast<uint_type>((bits & HF::fraction_encode_mask)
756                                               << HF::num_overflow_bits);
757 
758   const bool is_zero = exponent == 0 && fraction == 0;
759   const bool is_denorm = exponent == 0 && !is_zero;
760 
761   // exponent contains the biased exponent we have to convert it back into
762   // the normal range.
763   int_type int_exponent = static_cast<int_type>(exponent - HF::exponent_bias);
764   // If the number is all zeros, then we actually have to NOT shift the
765   // exponent.
766   int_exponent = is_zero ? 0 : int_exponent;
767 
768   // If we are denorm, then start shifting, and decreasing the exponent until
769   // our leading bit is 1.
770 
771   if (is_denorm) {
772     while ((fraction & HF::fraction_top_bit) == 0) {
773       fraction = static_cast<uint_type>(fraction << 1);
774       int_exponent = static_cast<int_type>(int_exponent - 1);
775     }
776     // Since this is denormalized, we have to consume the leading 1 since it
777     // will end up being implicit.
778     fraction = static_cast<uint_type>(fraction << 1);  // eat the leading 1
779     fraction &= HF::fraction_represent_mask;
780   }
781 
782   uint_type fraction_nibbles = HF::fraction_nibbles;
783   // We do not have to display any trailing 0s, since this represents the
784   // fractional part.
785   while (fraction_nibbles > 0 && (fraction & 0xF) == 0) {
786     // Shift off any trailing values;
787     fraction = static_cast<uint_type>(fraction >> 4);
788     --fraction_nibbles;
789   }
790 
791   const auto saved_flags = os.flags();
792   const auto saved_fill = os.fill();
793 
794   os << sign << "0x" << (is_zero ? '0' : '1');
795   if (fraction_nibbles) {
796     // Make sure to keep the leading 0s in place, since this is the fractional
797     // part.
798     os << "." << std::setw(static_cast<int>(fraction_nibbles))
799        << std::setfill('0') << std::hex << fraction;
800   }
801   os << "p" << std::dec << (int_exponent >= 0 ? "+" : "") << int_exponent;
802 
803   os.flags(saved_flags);
804   os.fill(saved_fill);
805 
806   return os;
807 }
808 
809 // Returns true if negate_value is true and the next character on the
810 // input stream is a plus or minus sign.  In that case we also set the fail bit
811 // on the stream and set the value to the zero value for its type.
812 template <typename T, typename Traits>
813 inline bool RejectParseDueToLeadingSign(std::istream& is, bool negate_value,
814                                         HexFloat<T, Traits>& value) {
815   if (negate_value) {
816     auto next_char = is.peek();
817     if (next_char == '-' || next_char == '+') {
818       // Fail the parse.  Emulate standard behaviour by setting the value to
819       // the zero value, and set the fail bit on the stream.
820       value = HexFloat<T, Traits>(typename HexFloat<T, Traits>::uint_type{0});
821       is.setstate(std::ios_base::failbit);
822       return true;
823     }
824   }
825   return false;
826 }
827 
828 // Parses a floating point number from the given stream and stores it into the
829 // value parameter.
830 // If negate_value is true then the number may not have a leading minus or
831 // plus, and if it successfully parses, then the number is negated before
832 // being stored into the value parameter.
833 // If the value cannot be correctly parsed or overflows the target floating
834 // point type, then set the fail bit on the stream.
835 // TODO(dneto): Promise C++11 standard behavior in how the value is set in
836 // the error case, but only after all target platforms implement it correctly.
837 // In particular, the Microsoft C++ runtime appears to be out of spec.
838 template <typename T, typename Traits>
839 inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value,
840                                       HexFloat<T, Traits>& value) {
841   if (RejectParseDueToLeadingSign(is, negate_value, value)) {
842     return is;
843   }
844   T val;
845   is >> val;
846   if (negate_value) {
847     val = -val;
848   }
849   value.set_value(val);
850   // In the failure case, map -0.0 to 0.0.
851   if (is.fail() && value.getUnsignedBits() == 0u) {
852     value = HexFloat<T, Traits>(typename HexFloat<T, Traits>::uint_type{0});
853   }
854   if (val.isInfinity()) {
855     // Fail the parse.  Emulate standard behaviour by setting the value to
856     // the closest normal value, and set the fail bit on the stream.
857     value.set_value((value.isNegative() | negate_value) ? T::lowest()
858                                                         : T::max());
859     is.setstate(std::ios_base::failbit);
860   }
861   return is;
862 }
863 
864 // Specialization of ParseNormalFloat for FloatProxy<Float16> values.
865 // This will parse the float as it were a 32-bit floating point number,
866 // and then round it down to fit into a Float16 value.
867 // The number is rounded towards zero.
868 // If negate_value is true then the number may not have a leading minus or
869 // plus, and if it successfully parses, then the number is negated before
870 // being stored into the value parameter.
871 // If the value cannot be correctly parsed or overflows the target floating
872 // point type, then set the fail bit on the stream.
873 // TODO(dneto): Promise C++11 standard behavior in how the value is set in
874 // the error case, but only after all target platforms implement it correctly.
875 // In particular, the Microsoft C++ runtime appears to be out of spec.
876 template <>
877 inline std::istream&
878 ParseNormalFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>(
879     std::istream& is, bool negate_value,
880     HexFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>& value) {
881   // First parse as a 32-bit float.
882   HexFloat<FloatProxy<float>> float_val(0.0f);
883   ParseNormalFloat(is, negate_value, float_val);
884 
885   // Then convert to 16-bit float, saturating at infinities, and
886   // rounding toward zero.
887   float_val.castTo(value, round_direction::kToZero);
888 
889   // Overflow on 16-bit behaves the same as for 32- and 64-bit: set the
890   // fail bit and set the lowest or highest value.
891   if (Float16::isInfinity(value.value().getAsFloat())) {
892     value.set_value(value.isNegative() ? Float16::lowest() : Float16::max());
893     is.setstate(std::ios_base::failbit);
894   }
895   return is;
896 }
897 
898 // Reads a HexFloat from the given stream.
899 // If the float is not encoded as a hex-float then it will be parsed
900 // as a regular float.
901 // This may fail if your stream does not support at least one unget.
902 // Nan values can be encoded with "0x1.<not zero>p+exponent_bias".
903 // This would normally overflow a float and round to
904 // infinity but this special pattern is the exact representation for a NaN,
905 // and therefore is actually encoded as the correct NaN. To encode inf,
906 // either 0x0p+exponent_bias can be specified or any exponent greater than
907 // exponent_bias.
908 // Examples using IEEE 32-bit float encoding.
909 //    0x1.0p+128 (+inf)
910 //    -0x1.0p-128 (-inf)
911 //
912 //    0x1.1p+128 (+Nan)
913 //    -0x1.1p+128 (-Nan)
914 //
915 //    0x1p+129 (+inf)
916 //    -0x1p+129 (-inf)
917 template <typename T, typename Traits>
918 std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
919   using HF = HexFloat<T, Traits>;
920   using uint_type = typename HF::uint_type;
921   using int_type = typename HF::int_type;
922 
923   value.set_value(static_cast<typename HF::native_type>(0.f));
924 
925   if (is.flags() & std::ios::skipws) {
926     // If the user wants to skip whitespace , then we should obey that.
927     while (std::isspace(is.peek())) {
928       is.get();
929     }
930   }
931 
932   auto next_char = is.peek();
933   bool negate_value = false;
934 
935   if (next_char != '-' && next_char != '0') {
936     return ParseNormalFloat(is, negate_value, value);
937   }
938 
939   if (next_char == '-') {
940     negate_value = true;
941     is.get();
942     next_char = is.peek();
943   }
944 
945   if (next_char == '0') {
946     is.get();  // We may have to unget this.
947     auto maybe_hex_start = is.peek();
948     if (maybe_hex_start != 'x' && maybe_hex_start != 'X') {
949       is.unget();
950       return ParseNormalFloat(is, negate_value, value);
951     } else {
952       is.get();  // Throw away the 'x';
953     }
954   } else {
955     return ParseNormalFloat(is, negate_value, value);
956   }
957 
958   // This "looks" like a hex-float so treat it as one.
959   bool seen_p = false;
960   bool seen_dot = false;
961   uint_type fraction_index = 0;
962 
963   uint_type fraction = 0;
964   int_type exponent = HF::exponent_bias;
965 
966   // Strip off leading zeros so we don't have to special-case them later.
967   while ((next_char = is.peek()) == '0') {
968     is.get();
969   }
970 
971   bool is_denorm =
972       true;  // Assume denorm "representation" until we hear otherwise.
973              // NB: This does not mean the value is actually denorm,
974              // it just means that it was written 0.
975   bool bits_written = false;  // Stays false until we write a bit.
976   while (!seen_p && !seen_dot) {
977     // Handle characters that are left of the fractional part.
978     if (next_char == '.') {
979       seen_dot = true;
980     } else if (next_char == 'p') {
981       seen_p = true;
982     } else if (::isxdigit(next_char)) {
983       // We know this is not denormalized since we have stripped all leading
984       // zeroes and we are not a ".".
985       is_denorm = false;
986       int number = get_nibble_from_character(next_char);
987       for (int i = 0; i < 4; ++i, number <<= 1) {
988         uint_type write_bit = (number & 0x8) ? 0x1 : 0x0;
989         if (bits_written) {
990           // If we are here the bits represented belong in the fractional
991           // part of the float, and we have to adjust the exponent accordingly.
992           fraction = static_cast<uint_type>(
993               fraction |
994               static_cast<uint_type>(
995                   write_bit << (HF::top_bit_left_shift - fraction_index++)));
996           exponent = static_cast<int_type>(exponent + 1);
997         }
998         bits_written |= write_bit != 0;
999       }
1000     } else {
1001       // We have not found our exponent yet, so we have to fail.
1002       is.setstate(std::ios::failbit);
1003       return is;
1004     }
1005     is.get();
1006     next_char = is.peek();
1007   }
1008   bits_written = false;
1009   while (seen_dot && !seen_p) {
1010     // Handle only fractional parts now.
1011     if (next_char == 'p') {
1012       seen_p = true;
1013     } else if (::isxdigit(next_char)) {
1014       int number = get_nibble_from_character(next_char);
1015       for (int i = 0; i < 4; ++i, number <<= 1) {
1016         uint_type write_bit = (number & 0x8) ? 0x01 : 0x00;
1017         bits_written |= write_bit != 0;
1018         if (is_denorm && !bits_written) {
1019           // Handle modifying the exponent here this way we can handle
1020           // an arbitrary number of hex values without overflowing our
1021           // integer.
1022           exponent = static_cast<int_type>(exponent - 1);
1023         } else {
1024           fraction = static_cast<uint_type>(
1025               fraction |
1026               static_cast<uint_type>(
1027                   write_bit << (HF::top_bit_left_shift - fraction_index++)));
1028         }
1029       }
1030     } else {
1031       // We still have not found our 'p' exponent yet, so this is not a valid
1032       // hex-float.
1033       is.setstate(std::ios::failbit);
1034       return is;
1035     }
1036     is.get();
1037     next_char = is.peek();
1038   }
1039 
1040   bool seen_sign = false;
1041   int8_t exponent_sign = 1;
1042   int_type written_exponent = 0;
1043   while (true) {
1044     if ((next_char == '-' || next_char == '+')) {
1045       if (seen_sign) {
1046         is.setstate(std::ios::failbit);
1047         return is;
1048       }
1049       seen_sign = true;
1050       exponent_sign = (next_char == '-') ? -1 : 1;
1051     } else if (::isdigit(next_char)) {
1052       // Hex-floats express their exponent as decimal.
1053       written_exponent = static_cast<int_type>(written_exponent * 10);
1054       written_exponent =
1055           static_cast<int_type>(written_exponent + (next_char - '0'));
1056     } else {
1057       break;
1058     }
1059     is.get();
1060     next_char = is.peek();
1061   }
1062 
1063   written_exponent = static_cast<int_type>(written_exponent * exponent_sign);
1064   exponent = static_cast<int_type>(exponent + written_exponent);
1065 
1066   bool is_zero = is_denorm && (fraction == 0);
1067   if (is_denorm && !is_zero) {
1068     fraction = static_cast<uint_type>(fraction << 1);
1069     exponent = static_cast<int_type>(exponent - 1);
1070   } else if (is_zero) {
1071     exponent = 0;
1072   }
1073 
1074   if (exponent <= 0 && !is_zero) {
1075     fraction = static_cast<uint_type>(fraction >> 1);
1076     fraction |= static_cast<uint_type>(1) << HF::top_bit_left_shift;
1077   }
1078 
1079   fraction = (fraction >> HF::fraction_right_shift) & HF::fraction_encode_mask;
1080 
1081   const int_type max_exponent =
1082       SetBits<uint_type, 0, HF::num_exponent_bits>::get;
1083 
1084   // Handle actual denorm numbers
1085   while (exponent < 0 && !is_zero) {
1086     fraction = static_cast<uint_type>(fraction >> 1);
1087     exponent = static_cast<int_type>(exponent + 1);
1088 
1089     fraction &= HF::fraction_encode_mask;
1090     if (fraction == 0) {
1091       // We have underflowed our fraction. We should clamp to zero.
1092       is_zero = true;
1093       exponent = 0;
1094     }
1095   }
1096 
1097   // We have overflowed so we should be inf/-inf.
1098   if (exponent > max_exponent) {
1099     exponent = max_exponent;
1100     fraction = 0;
1101   }
1102 
1103   uint_type output_bits = static_cast<uint_type>(
1104       static_cast<uint_type>(negate_value ? 1 : 0) << HF::top_bit_left_shift);
1105   output_bits |= fraction;
1106 
1107   uint_type shifted_exponent = static_cast<uint_type>(
1108       static_cast<uint_type>(exponent << HF::exponent_left_shift) &
1109       HF::exponent_mask);
1110   output_bits |= shifted_exponent;
1111 
1112   T output_float(output_bits);
1113   value.set_value(output_float);
1114 
1115   return is;
1116 }
1117 
1118 // Writes a FloatProxy value to a stream.
1119 // Zero and normal numbers are printed in the usual notation, but with
1120 // enough digits to fully reproduce the value.  Other values (subnormal,
1121 // NaN, and infinity) are printed as a hex float.
1122 template <typename T>
1123 std::ostream& operator<<(std::ostream& os, const FloatProxy<T>& value) {
1124   auto float_val = value.getAsFloat();
1125   switch (std::fpclassify(float_val)) {
1126     case FP_ZERO:
1127     case FP_NORMAL: {
1128       auto saved_precision = os.precision();
1129       os.precision(std::numeric_limits<T>::max_digits10);
1130       os << float_val;
1131       os.precision(saved_precision);
1132     } break;
1133     default:
1134       os << HexFloat<FloatProxy<T>>(value);
1135       break;
1136   }
1137   return os;
1138 }
1139 
1140 template <>
1141 inline std::ostream& operator<<<Float16>(std::ostream& os,
1142                                          const FloatProxy<Float16>& value) {
1143   os << HexFloat<FloatProxy<Float16>>(value);
1144   return os;
1145 }
1146 
1147 }  // namespace utils
1148 }  // namespace spvtools
1149 
1150 #endif  // SOURCE_UTIL_HEX_FLOAT_H_
1151