// Code-browser navigation residue (not part of the source): Home | Line# | Scopes# | Navigate# | Raw | Download
// Copyright (c) 2015-2016 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #ifndef SOURCE_UTIL_HEX_FLOAT_H_
16 #define SOURCE_UTIL_HEX_FLOAT_H_
17 
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <limits>
#include <sstream>
#include <type_traits>
#include <vector>

#include "source/util/bitutils.h"
28 
// GCC_VERSION encodes the GCC release as
// major * 10000 + minor * 100 + patchlevel (for example 4.8.1 is 40801).
// It is 0 when __GNUC__ is not defined.
#ifndef __GNUC__
#define GCC_VERSION 0
#else
#define GCC_VERSION \
  (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__)
#endif
35 
36 namespace spvtools {
37 namespace utils {
38 
// Holds the raw 16-bit pattern of a half-precision value: 1 sign bit,
// 5 exponent bits (mask 0x7C00) and 10 fraction bits (mask 0x3FF).
class Float16 {
 public:
  // Intentionally non-explicit so raw 16-bit patterns convert implicitly.
  Float16(uint16_t v) : val(v) {}
  // Leaves the stored bits uninitialized.
  Float16() = default;
  // Defaulted so that the type stays trivially copyable and the implicit
  // copy assignment operator is not deprecated.
  Float16(const Float16& other) = default;

  // Returns true if the given value is any kind of NaN: all exponent bits
  // set and a non-zero fraction.
  static bool isNan(const Float16& val) {
    return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) != 0);
  }
  // Returns true if the given value is any kind of infinity: all exponent
  // bits set and a zero fraction.
  static bool isInfinity(const Float16& val) {
    return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) == 0);
  }
  // Returns the stored bit pattern.
  uint16_t get_value() const { return val; }

  // Returns the maximum finite value (bit pattern 0x7bff).
  static Float16 max() { return Float16(0x7bff); }
  // Returns the lowest (most negative) finite value (bit pattern 0xfbff).
  static Float16 lowest() { return Float16(0xfbff); }

 private:
  uint16_t val;
};
61 
// Primary template; it is the fallback for types without a specialization
// and only defines uint_type as void.
// To specialize this type, you must override uint_type to define
// an unsigned integer that can fit your floating point type.
// You must also add an isNan function that returns true if
// a value is NaN.
template <typename T>
struct FloatProxyTraits {
  using uint_type = void;
};
70 
71 template <>
72 struct FloatProxyTraits<float> {
73   using uint_type = uint32_t;
74   static bool isNan(float f) { return std::isnan(f); }
75   // Returns true if the given value is any kind of infinity.
76   static bool isInfinity(float f) { return std::isinf(f); }
77   // Returns the maximum normal value.
78   static float max() { return std::numeric_limits<float>::max(); }
79   // Returns the lowest normal value.
80   static float lowest() { return std::numeric_limits<float>::lowest(); }
81   // Returns the value as the native floating point format.
82   static float getAsFloat(const uint_type& t) { return BitwiseCast<float>(t); }
83   // Returns the bits from the given floating pointer number.
84   static uint_type getBitsFromFloat(const float& t) {
85     return BitwiseCast<uint_type>(t);
86   }
87   // Returns the bitwidth.
88   static uint32_t width() { return 32u; }
89 };
90 
91 template <>
92 struct FloatProxyTraits<double> {
93   using uint_type = uint64_t;
94   static bool isNan(double f) { return std::isnan(f); }
95   // Returns true if the given value is any kind of infinity.
96   static bool isInfinity(double f) { return std::isinf(f); }
97   // Returns the maximum normal value.
98   static double max() { return std::numeric_limits<double>::max(); }
99   // Returns the lowest normal value.
100   static double lowest() { return std::numeric_limits<double>::lowest(); }
101   // Returns the value as the native floating point format.
102   static double getAsFloat(const uint_type& t) {
103     return BitwiseCast<double>(t);
104   }
105   // Returns the bits from the given floating pointer number.
106   static uint_type getBitsFromFloat(const double& t) {
107     return BitwiseCast<uint_type>(t);
108   }
109   // Returns the bitwidth.
110   static uint32_t width() { return 64u; }
111 };
112 
113 template <>
114 struct FloatProxyTraits<Float16> {
115   using uint_type = uint16_t;
116   static bool isNan(Float16 f) { return Float16::isNan(f); }
117   // Returns true if the given value is any kind of infinity.
118   static bool isInfinity(Float16 f) { return Float16::isInfinity(f); }
119   // Returns the maximum normal value.
120   static Float16 max() { return Float16::max(); }
121   // Returns the lowest normal value.
122   static Float16 lowest() { return Float16::lowest(); }
123   // Returns the value as the native floating point format.
124   static Float16 getAsFloat(const uint_type& t) { return Float16(t); }
125   // Returns the bits from the given floating pointer number.
126   static uint_type getBitsFromFloat(const Float16& t) { return t.get_value(); }
127   // Returns the bitwidth.
128   static uint32_t width() { return 16u; }
129 };
130 
131 // Since copying a floating point number (especially if it is NaN)
132 // does not guarantee that bits are preserved, this class lets us
133 // store the type and use it as a float when necessary.
134 template <typename T>
135 class FloatProxy {
136  public:
137   using uint_type = typename FloatProxyTraits<T>::uint_type;
138 
139   // Since this is to act similar to the normal floats,
140   // do not initialize the data by default.
141   FloatProxy() = default;
142 
143   // Intentionally non-explicit. This is a proxy type so
144   // implicit conversions allow us to use it more transparently.
145   FloatProxy(T val) { data_ = FloatProxyTraits<T>::getBitsFromFloat(val); }
146 
147   // Intentionally non-explicit. This is a proxy type so
148   // implicit conversions allow us to use it more transparently.
149   FloatProxy(uint_type val) { data_ = val; }
150 
151   // This is helpful to have and is guaranteed not to stomp bits.
152   FloatProxy<T> operator-() const {
153     return static_cast<uint_type>(data_ ^
154                                   (uint_type(0x1) << (sizeof(T) * 8 - 1)));
155   }
156 
157   // Returns the data as a floating point value.
158   T getAsFloat() const { return FloatProxyTraits<T>::getAsFloat(data_); }
159 
160   // Returns the raw data.
161   uint_type data() const { return data_; }
162 
163   // Returns a vector of words suitable for use in an Operand.
164   std::vector<uint32_t> GetWords() const {
165     std::vector<uint32_t> words;
166     if (FloatProxyTraits<T>::width() == 64) {
167       FloatProxyTraits<double>::uint_type d = data();
168       words.push_back(static_cast<uint32_t>(d));
169       words.push_back(static_cast<uint32_t>(d >> 32));
170     } else {
171       words.push_back(static_cast<uint32_t>(data()));
172     }
173     return words;
174   }
175 
176   // Returns true if the value represents any type of NaN.
177   bool isNan() { return FloatProxyTraits<T>::isNan(getAsFloat()); }
178   // Returns true if the value represents any type of infinity.
179   bool isInfinity() { return FloatProxyTraits<T>::isInfinity(getAsFloat()); }
180 
181   // Returns the maximum normal value.
182   static FloatProxy<T> max() {
183     return FloatProxy<T>(FloatProxyTraits<T>::max());
184   }
185   // Returns the lowest normal value.
186   static FloatProxy<T> lowest() {
187     return FloatProxy<T>(FloatProxyTraits<T>::lowest());
188   }
189 
190  private:
191   uint_type data_;
192 };
193 
194 template <typename T>
195 bool operator==(const FloatProxy<T>& first, const FloatProxy<T>& second) {
196   return first.data() == second.data();
197 }
198 
199 // Reads a FloatProxy value as a normal float from a stream.
200 template <typename T>
201 std::istream& operator>>(std::istream& is, FloatProxy<T>& value) {
202   T float_val = static_cast<T>(0.0);
203   is >> float_val;
204   value = FloatProxy<T>(float_val);
205   return is;
206 }
207 
// This is an example traits. It is not meant to be used in practice, but will
// be the default for any non-specialized type.
template <typename T>
struct HexFloatTraits {
  // Integer type that can store the bit representation of this hex-float.
  using uint_type = void;
  // Signed integer type that can store the bit representation of this
  // hex-float.
  using int_type = void;
  // The numerical type that this HexFloat represents.
  using underlying_type = void;
  // The type needed to construct the underlying type.
  using native_type = void;
  // The number of bits that are actually relevant in the uint_type.
  // This allows us to deal with, for example, 24-bit values in a 32-bit
  // integer.
  static const uint32_t num_used_bits = 0;
  // Number of bits that represent the exponent.
  static const uint32_t num_exponent_bits = 0;
  // Number of bits that represent the fractional part.
  static const uint32_t num_fraction_bits = 0;
  // The bias of the exponent. (How much we need to subtract from the stored
  // value to get the correct value.)
  static const uint32_t exponent_bias = 0;
};
233 
234 // Traits for IEEE float.
235 // 1 sign bit, 8 exponent bits, 23 fractional bits.
236 template <>
237 struct HexFloatTraits<FloatProxy<float>> {
238   using uint_type = uint32_t;
239   using int_type = int32_t;
240   using underlying_type = FloatProxy<float>;
241   using native_type = float;
242   static const uint_type num_used_bits = 32;
243   static const uint_type num_exponent_bits = 8;
244   static const uint_type num_fraction_bits = 23;
245   static const uint_type exponent_bias = 127;
246 };
247 
248 // Traits for IEEE double.
249 // 1 sign bit, 11 exponent bits, 52 fractional bits.
250 template <>
251 struct HexFloatTraits<FloatProxy<double>> {
252   using uint_type = uint64_t;
253   using int_type = int64_t;
254   using underlying_type = FloatProxy<double>;
255   using native_type = double;
256   static const uint_type num_used_bits = 64;
257   static const uint_type num_exponent_bits = 11;
258   static const uint_type num_fraction_bits = 52;
259   static const uint_type exponent_bias = 1023;
260 };
261 
262 // Traits for IEEE half.
263 // 1 sign bit, 5 exponent bits, 10 fractional bits.
264 template <>
265 struct HexFloatTraits<FloatProxy<Float16>> {
266   using uint_type = uint16_t;
267   using int_type = int16_t;
268   using underlying_type = uint16_t;
269   using native_type = uint16_t;
270   static const uint_type num_used_bits = 16;
271   static const uint_type num_exponent_bits = 5;
272   static const uint_type num_fraction_bits = 10;
273   static const uint_type exponent_bias = 15;
274 };
275 
// Rounding direction to apply when a significand has to be narrowed.
enum class round_direction {
  // Discard the extra bits.
  kToZero,
  // Round to the nearest representable value; ties go to the even one.
  kToNearestEven,
  // Round towards positive infinity.
  kToPositiveInfinity,
  // Round towards negative infinity.
  kToNegativeInfinity,
  // Alias for the last enumerator.
  max = kToNegativeInfinity
};
283 
284 // Template class that houses a floating pointer number.
285 // It exposes a number of constants based on the provided traits to
286 // assist in interpreting the bits of the value.
287 template <typename T, typename Traits = HexFloatTraits<T>>
288 class HexFloat {
289  public:
290   using uint_type = typename Traits::uint_type;
291   using int_type = typename Traits::int_type;
292   using underlying_type = typename Traits::underlying_type;
293   using native_type = typename Traits::native_type;
294 
295   explicit HexFloat(T f) : value_(f) {}
296 
297   T value() const { return value_; }
298   void set_value(T f) { value_ = f; }
299 
300   // These are all written like this because it is convenient to have
301   // compile-time constants for all of these values.
302 
303   // Pass-through values to save typing.
304   static const uint32_t num_used_bits = Traits::num_used_bits;
305   static const uint32_t exponent_bias = Traits::exponent_bias;
306   static const uint32_t num_exponent_bits = Traits::num_exponent_bits;
307   static const uint32_t num_fraction_bits = Traits::num_fraction_bits;
308 
309   // Number of bits to shift left to set the highest relevant bit.
310   static const uint32_t top_bit_left_shift = num_used_bits - 1;
311   // How many nibbles (hex characters) the fractional part takes up.
312   static const uint32_t fraction_nibbles = (num_fraction_bits + 3) / 4;
313   // If the fractional part does not fit evenly into a hex character (4-bits)
314   // then we have to left-shift to get rid of leading 0s. This is the amount
315   // we have to shift (might be 0).
316   static const uint32_t num_overflow_bits =
317       fraction_nibbles * 4 - num_fraction_bits;
318 
319   // The representation of the fraction, not the actual bits. This
320   // includes the leading bit that is usually implicit.
321   static const uint_type fraction_represent_mask =
322       SetBits<uint_type, 0, num_fraction_bits + num_overflow_bits>::get;
323 
324   // The topmost bit in the nibble-aligned fraction.
325   static const uint_type fraction_top_bit =
326       uint_type(1) << (num_fraction_bits + num_overflow_bits - 1);
327 
328   // The least significant bit in the exponent, which is also the bit
329   // immediately to the left of the significand.
330   static const uint_type first_exponent_bit = uint_type(1)
331                                               << (num_fraction_bits);
332 
333   // The mask for the encoded fraction. It does not include the
334   // implicit bit.
335   static const uint_type fraction_encode_mask =
336       SetBits<uint_type, 0, num_fraction_bits>::get;
337 
338   // The bit that is used as a sign.
339   static const uint_type sign_mask = uint_type(1) << top_bit_left_shift;
340 
341   // The bits that represent the exponent.
342   static const uint_type exponent_mask =
343       SetBits<uint_type, num_fraction_bits, num_exponent_bits>::get;
344 
345   // How far left the exponent is shifted.
346   static const uint32_t exponent_left_shift = num_fraction_bits;
347 
348   // How far from the right edge the fraction is shifted.
349   static const uint32_t fraction_right_shift =
350       static_cast<uint32_t>(sizeof(uint_type) * 8) - num_fraction_bits;
351 
352   // The maximum representable unbiased exponent.
353   static const int_type max_exponent =
354       (exponent_mask >> num_fraction_bits) - exponent_bias;
355   // The minimum representable exponent for normalized numbers.
356   static const int_type min_exponent = -static_cast<int_type>(exponent_bias);
357 
358   // Returns the bits associated with the value.
359   uint_type getBits() const { return value_.data(); }
360 
361   // Returns the bits associated with the value, without the leading sign bit.
362   uint_type getUnsignedBits() const {
363     return static_cast<uint_type>(value_.data() & ~sign_mask);
364   }
365 
366   // Returns the bits associated with the exponent, shifted to start at the
367   // lsb of the type.
368   const uint_type getExponentBits() const {
369     return static_cast<uint_type>((getBits() & exponent_mask) >>
370                                   num_fraction_bits);
371   }
372 
373   // Returns the exponent in unbiased form. This is the exponent in the
374   // human-friendly form.
375   const int_type getUnbiasedExponent() const {
376     return static_cast<int_type>(getExponentBits() - exponent_bias);
377   }
378 
379   // Returns just the significand bits from the value.
380   const uint_type getSignificandBits() const {
381     return getBits() & fraction_encode_mask;
382   }
383 
384   // If the number was normalized, returns the unbiased exponent.
385   // If the number was denormal, normalize the exponent first.
386   const int_type getUnbiasedNormalizedExponent() const {
387     if ((getBits() & ~sign_mask) == 0) {  // special case if everything is 0
388       return 0;
389     }
390     int_type exp = getUnbiasedExponent();
391     if (exp == min_exponent) {  // We are in denorm land.
392       uint_type significand_bits = getSignificandBits();
393       while ((significand_bits & (first_exponent_bit >> 1)) == 0) {
394         significand_bits = static_cast<uint_type>(significand_bits << 1);
395         exp = static_cast<int_type>(exp - 1);
396       }
397       significand_bits &= fraction_encode_mask;
398     }
399     return exp;
400   }
401 
402   // Returns the signficand after it has been normalized.
403   const uint_type getNormalizedSignificand() const {
404     int_type unbiased_exponent = getUnbiasedNormalizedExponent();
405     uint_type significand = getSignificandBits();
406     for (int_type i = unbiased_exponent; i <= min_exponent; ++i) {
407       significand = static_cast<uint_type>(significand << 1);
408     }
409     significand &= fraction_encode_mask;
410     return significand;
411   }
412 
413   // Returns true if this number represents a negative value.
414   bool isNegative() const { return (getBits() & sign_mask) != 0; }
415 
416   // Sets this HexFloat from the individual components.
417   // Note this assumes EVERY significand is normalized, and has an implicit
418   // leading one. This means that the only way that this method will set 0,
419   // is if you set a number so denormalized that it underflows.
420   // Do not use this method with raw bits extracted from a subnormal number,
421   // since subnormals do not have an implicit leading 1 in the significand.
422   // The significand is also expected to be in the
423   // lowest-most num_fraction_bits of the uint_type.
424   // The exponent is expected to be unbiased, meaning an exponent of
425   // 0 actually means 0.
426   // If underflow_round_up is set, then on underflow, if a number is non-0
427   // and would underflow, we round up to the smallest denorm.
428   void setFromSignUnbiasedExponentAndNormalizedSignificand(
429       bool negative, int_type exponent, uint_type significand,
430       bool round_denorm_up) {
431     bool significand_is_zero = significand == 0;
432 
433     if (exponent <= min_exponent) {
434       // If this was denormalized, then we have to shift the bit on, meaning
435       // the significand is not zero.
436       significand_is_zero = false;
437       significand |= first_exponent_bit;
438       significand = static_cast<uint_type>(significand >> 1);
439     }
440 
441     while (exponent < min_exponent) {
442       significand = static_cast<uint_type>(significand >> 1);
443       ++exponent;
444     }
445 
446     if (exponent == min_exponent) {
447       if (significand == 0 && !significand_is_zero && round_denorm_up) {
448         significand = static_cast<uint_type>(0x1);
449       }
450     }
451 
452     uint_type new_value = 0;
453     if (negative) {
454       new_value = static_cast<uint_type>(new_value | sign_mask);
455     }
456     exponent = static_cast<int_type>(exponent + exponent_bias);
457     assert(exponent >= 0);
458 
459     // put it all together
460     exponent = static_cast<uint_type>((exponent << exponent_left_shift) &
461                                       exponent_mask);
462     significand = static_cast<uint_type>(significand & fraction_encode_mask);
463     new_value = static_cast<uint_type>(new_value | (exponent | significand));
464     value_ = T(new_value);
465   }
466 
467   // Increments the significand of this number by the given amount.
468   // If this would spill the significand into the implicit bit,
469   // carry is set to true and the significand is shifted to fit into
470   // the correct location, otherwise carry is set to false.
471   // All significands and to_increment are assumed to be within the bounds
472   // for a valid significand.
473   static uint_type incrementSignificand(uint_type significand,
474                                         uint_type to_increment, bool* carry) {
475     significand = static_cast<uint_type>(significand + to_increment);
476     *carry = false;
477     if (significand & first_exponent_bit) {
478       *carry = true;
479       // The implicit 1-bit will have carried, so we should zero-out the
480       // top bit and shift back.
481       significand = static_cast<uint_type>(significand & ~first_exponent_bit);
482       significand = static_cast<uint_type>(significand >> 1);
483     }
484     return significand;
485   }
486 
487 #if GCC_VERSION == 40801
488   // These exist because MSVC throws warnings on negative right-shifts
489   // even if they are not going to be executed. Eg:
490   // constant_number < 0? 0: constant_number
491   // These convert the negative left-shifts into right shifts.
492   template <int_type N>
493   struct negatable_left_shift {
494     static uint_type val(uint_type val) {
495       if (N > 0) {
496         return static_cast<uint_type>(val << N);
497       } else {
498         return static_cast<uint_type>(val >> N);
499       }
500     }
501   };
502 
503   template <int_type N>
504   struct negatable_right_shift {
505     static uint_type val(uint_type val) {
506       if (N > 0) {
507         return static_cast<uint_type>(val >> N);
508       } else {
509         return static_cast<uint_type>(val << N);
510       }
511     }
512   };
513 
514 #else
515   // These exist because MSVC throws warnings on negative right-shifts
516   // even if they are not going to be executed. Eg:
517   // constant_number < 0? 0: constant_number
518   // These convert the negative left-shifts into right shifts.
519   template <int_type N, typename enable = void>
520   struct negatable_left_shift {
521     static uint_type val(uint_type val) {
522       return static_cast<uint_type>(val >> -N);
523     }
524   };
525 
526   template <int_type N>
527   struct negatable_left_shift<N, typename std::enable_if<N >= 0>::type> {
528     static uint_type val(uint_type val) {
529       return static_cast<uint_type>(val << N);
530     }
531   };
532 
533   template <int_type N, typename enable = void>
534   struct negatable_right_shift {
535     static uint_type val(uint_type val) {
536       return static_cast<uint_type>(val << -N);
537     }
538   };
539 
540   template <int_type N>
541   struct negatable_right_shift<N, typename std::enable_if<N >= 0>::type> {
542     static uint_type val(uint_type val) {
543       return static_cast<uint_type>(val >> N);
544     }
545   };
546 #endif
547 
548   // Returns the significand, rounded to fit in a significand in
549   // other_T. This is shifted so that the most significant
550   // bit of the rounded number lines up with the most significant bit
551   // of the returned significand.
552   template <typename other_T>
553   typename other_T::uint_type getRoundedNormalizedSignificand(
554       round_direction dir, bool* carry_bit) {
555     using other_uint_type = typename other_T::uint_type;
556     static const int_type num_throwaway_bits =
557         static_cast<int_type>(num_fraction_bits) -
558         static_cast<int_type>(other_T::num_fraction_bits);
559 
560     static const uint_type last_significant_bit =
561         (num_throwaway_bits < 0)
562             ? 0
563             : negatable_left_shift<num_throwaway_bits>::val(1u);
564     static const uint_type first_rounded_bit =
565         (num_throwaway_bits < 1)
566             ? 0
567             : negatable_left_shift<num_throwaway_bits - 1>::val(1u);
568 
569     static const uint_type throwaway_mask_bits =
570         num_throwaway_bits > 0 ? num_throwaway_bits : 0;
571     static const uint_type throwaway_mask =
572         SetBits<uint_type, 0, throwaway_mask_bits>::get;
573 
574     *carry_bit = false;
575     other_uint_type out_val = 0;
576     uint_type significand = getNormalizedSignificand();
577     // If we are up-casting, then we just have to shift to the right location.
578     if (num_throwaway_bits <= 0) {
579       out_val = static_cast<other_uint_type>(significand);
580       uint_type shift_amount = static_cast<uint_type>(-num_throwaway_bits);
581       out_val = static_cast<other_uint_type>(out_val << shift_amount);
582       return out_val;
583     }
584 
585     // If every non-representable bit is 0, then we don't have any casting to
586     // do.
587     if ((significand & throwaway_mask) == 0) {
588       return static_cast<other_uint_type>(
589           negatable_right_shift<num_throwaway_bits>::val(significand));
590     }
591 
592     bool round_away_from_zero = false;
593     // We actually have to narrow the significand here, so we have to follow the
594     // rounding rules.
595     switch (dir) {
596       case round_direction::kToZero:
597         break;
598       case round_direction::kToPositiveInfinity:
599         round_away_from_zero = !isNegative();
600         break;
601       case round_direction::kToNegativeInfinity:
602         round_away_from_zero = isNegative();
603         break;
604       case round_direction::kToNearestEven:
605         // Have to round down, round bit is 0
606         if ((first_rounded_bit & significand) == 0) {
607           break;
608         }
609         if (((significand & throwaway_mask) & ~first_rounded_bit) != 0) {
610           // If any subsequent bit of the rounded portion is non-0 then we round
611           // up.
612           round_away_from_zero = true;
613           break;
614         }
615         // We are exactly half-way between 2 numbers, pick even.
616         if ((significand & last_significant_bit) != 0) {
617           // 1 for our last bit, round up.
618           round_away_from_zero = true;
619           break;
620         }
621         break;
622     }
623 
624     if (round_away_from_zero) {
625       return static_cast<other_uint_type>(
626           negatable_right_shift<num_throwaway_bits>::val(incrementSignificand(
627               significand, last_significant_bit, carry_bit)));
628     } else {
629       return static_cast<other_uint_type>(
630           negatable_right_shift<num_throwaway_bits>::val(significand));
631     }
632   }
633 
634   // Casts this value to another HexFloat. If the cast is widening,
635   // then round_dir is ignored. If the cast is narrowing, then
636   // the result is rounded in the direction specified.
637   // This number will retain Nan and Inf values.
638   // It will also saturate to Inf if the number overflows, and
639   // underflow to (0 or min depending on rounding) if the number underflows.
640   template <typename other_T>
641   void castTo(other_T& other, round_direction round_dir) {
642     other = other_T(static_cast<typename other_T::native_type>(0));
643     bool negate = isNegative();
644     if (getUnsignedBits() == 0) {
645       if (negate) {
646         other.set_value(-other.value());
647       }
648       return;
649     }
650     uint_type significand = getSignificandBits();
651     bool carried = false;
652     typename other_T::uint_type rounded_significand =
653         getRoundedNormalizedSignificand<other_T>(round_dir, &carried);
654 
655     int_type exponent = getUnbiasedExponent();
656     if (exponent == min_exponent) {
657       // If we are denormal, normalize the exponent, so that we can encode
658       // easily.
659       exponent = static_cast<int_type>(exponent + 1);
660       for (uint_type check_bit = first_exponent_bit >> 1; check_bit != 0;
661            check_bit = static_cast<uint_type>(check_bit >> 1)) {
662         exponent = static_cast<int_type>(exponent - 1);
663         if (check_bit & significand) break;
664       }
665     }
666 
667     bool is_nan =
668         (getBits() & exponent_mask) == exponent_mask && significand != 0;
669     bool is_inf =
670         !is_nan &&
671         ((exponent + carried) > static_cast<int_type>(other_T::exponent_bias) ||
672          (significand == 0 && (getBits() & exponent_mask) == exponent_mask));
673 
674     // If we are Nan or Inf we should pass that through.
675     if (is_inf) {
676       other.set_value(typename other_T::underlying_type(
677           static_cast<typename other_T::uint_type>(
678               (negate ? other_T::sign_mask : 0) | other_T::exponent_mask)));
679       return;
680     }
681     if (is_nan) {
682       typename other_T::uint_type shifted_significand;
683       shifted_significand = static_cast<typename other_T::uint_type>(
684           negatable_left_shift<
685               static_cast<int_type>(other_T::num_fraction_bits) -
686               static_cast<int_type>(num_fraction_bits)>::val(significand));
687 
688       // We are some sort of Nan. We try to keep the bit-pattern of the Nan
689       // as close as possible. If we had to shift off bits so we are 0, then we
690       // just set the last bit.
691       other.set_value(typename other_T::underlying_type(
692           static_cast<typename other_T::uint_type>(
693               (negate ? other_T::sign_mask : 0) | other_T::exponent_mask |
694               (shifted_significand == 0 ? 0x1 : shifted_significand))));
695       return;
696     }
697 
698     bool round_underflow_up =
699         isNegative() ? round_dir == round_direction::kToNegativeInfinity
700                      : round_dir == round_direction::kToPositiveInfinity;
701     using other_int_type = typename other_T::int_type;
702     // setFromSignUnbiasedExponentAndNormalizedSignificand will
703     // zero out any underflowing value (but retain the sign).
704     other.setFromSignUnbiasedExponentAndNormalizedSignificand(
705         negate, static_cast<other_int_type>(exponent), rounded_significand,
706         round_underflow_up);
707     return;
708   }
709 
710  private:
711   T value_;
712 
713   static_assert(num_used_bits ==
714                     Traits::num_exponent_bits + Traits::num_fraction_bits + 1,
715                 "The number of bits do not fit");
716   static_assert(sizeof(T) == sizeof(uint_type), "The type sizes do not match");
717 };
718 
// Returns 4 bits represented by the hex character.
// Accepts '0'-'9', 'a'-'f' and 'A'-'F'. Any other input, including '\0',
// trips the assertion; with assertions disabled it returns 0.
inline uint8_t get_nibble_from_character(int character) {
  if (character >= '0' && character <= '9') {
    return static_cast<uint8_t>(character - '0');
  }
  if (character >= 'a' && character <= 'f') {
    return static_cast<uint8_t>(character - 'a' + 0xa);
  }
  if (character >= 'A' && character <= 'F') {
    return static_cast<uint8_t>(character - 'A' + 0xa);
  }

  assert(false && "This was called with a non-hex character");
  return 0;
}
736 
737 // Outputs the given HexFloat to the stream.
738 template <typename T, typename Traits>
739 std::ostream& operator<<(std::ostream& os, const HexFloat<T, Traits>& value) {
740   using HF = HexFloat<T, Traits>;
741   using uint_type = typename HF::uint_type;
742   using int_type = typename HF::int_type;
743 
744   static_assert(HF::num_used_bits != 0,
745                 "num_used_bits must be non-zero for a valid float");
746   static_assert(HF::num_exponent_bits != 0,
747                 "num_exponent_bits must be non-zero for a valid float");
748   static_assert(HF::num_fraction_bits != 0,
749                 "num_fractin_bits must be non-zero for a valid float");
750 
751   const uint_type bits = value.value().data();
752   const char* const sign = (bits & HF::sign_mask) ? "-" : "";
753   const uint_type exponent = static_cast<uint_type>(
754       (bits & HF::exponent_mask) >> HF::num_fraction_bits);
755 
756   uint_type fraction = static_cast<uint_type>((bits & HF::fraction_encode_mask)
757                                               << HF::num_overflow_bits);
758 
759   const bool is_zero = exponent == 0 && fraction == 0;
760   const bool is_denorm = exponent == 0 && !is_zero;
761 
762   // exponent contains the biased exponent we have to convert it back into
763   // the normal range.
764   int_type int_exponent = static_cast<int_type>(exponent - HF::exponent_bias);
765   // If the number is all zeros, then we actually have to NOT shift the
766   // exponent.
767   int_exponent = is_zero ? 0 : int_exponent;
768 
769   // If we are denorm, then start shifting, and decreasing the exponent until
770   // our leading bit is 1.
771 
772   if (is_denorm) {
773     while ((fraction & HF::fraction_top_bit) == 0) {
774       fraction = static_cast<uint_type>(fraction << 1);
775       int_exponent = static_cast<int_type>(int_exponent - 1);
776     }
777     // Since this is denormalized, we have to consume the leading 1 since it
778     // will end up being implicit.
779     fraction = static_cast<uint_type>(fraction << 1);  // eat the leading 1
780     fraction &= HF::fraction_represent_mask;
781   }
782 
783   uint_type fraction_nibbles = HF::fraction_nibbles;
784   // We do not have to display any trailing 0s, since this represents the
785   // fractional part.
786   while (fraction_nibbles > 0 && (fraction & 0xF) == 0) {
787     // Shift off any trailing values;
788     fraction = static_cast<uint_type>(fraction >> 4);
789     --fraction_nibbles;
790   }
791 
792   const auto saved_flags = os.flags();
793   const auto saved_fill = os.fill();
794 
795   os << sign << "0x" << (is_zero ? '0' : '1');
796   if (fraction_nibbles) {
797     // Make sure to keep the leading 0s in place, since this is the fractional
798     // part.
799     os << "." << std::setw(static_cast<int>(fraction_nibbles))
800        << std::setfill('0') << std::hex << fraction;
801   }
802   os << "p" << std::dec << (int_exponent >= 0 ? "+" : "") << int_exponent;
803 
804   os.flags(saved_flags);
805   os.fill(saved_fill);
806 
807   return os;
808 }
809 
810 // Returns true if negate_value is true and the next character on the
811 // input stream is a plus or minus sign.  In that case we also set the fail bit
812 // on the stream and set the value to the zero value for its type.
813 template <typename T, typename Traits>
814 inline bool RejectParseDueToLeadingSign(std::istream& is, bool negate_value,
815                                         HexFloat<T, Traits>& value) {
816   if (negate_value) {
817     auto next_char = is.peek();
818     if (next_char == '-' || next_char == '+') {
819       // Fail the parse.  Emulate standard behaviour by setting the value to
820       // the zero value, and set the fail bit on the stream.
821       value = HexFloat<T, Traits>(typename HexFloat<T, Traits>::uint_type{0});
822       is.setstate(std::ios_base::failbit);
823       return true;
824     }
825   }
826   return false;
827 }
828 
829 // Parses a floating point number from the given stream and stores it into the
830 // value parameter.
831 // If negate_value is true then the number may not have a leading minus or
832 // plus, and if it successfully parses, then the number is negated before
833 // being stored into the value parameter.
834 // If the value cannot be correctly parsed or overflows the target floating
835 // point type, then set the fail bit on the stream.
836 // TODO(dneto): Promise C++11 standard behavior in how the value is set in
837 // the error case, but only after all target platforms implement it correctly.
838 // In particular, the Microsoft C++ runtime appears to be out of spec.
839 template <typename T, typename Traits>
840 inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value,
841                                       HexFloat<T, Traits>& value) {
842   if (RejectParseDueToLeadingSign(is, negate_value, value)) {
843     return is;
844   }
845   T val;
846   is >> val;
847   if (negate_value) {
848     val = -val;
849   }
850   value.set_value(val);
851   // In the failure case, map -0.0 to 0.0.
852   if (is.fail() && value.getUnsignedBits() == 0u) {
853     value = HexFloat<T, Traits>(typename HexFloat<T, Traits>::uint_type{0});
854   }
855   if (val.isInfinity()) {
856     // Fail the parse.  Emulate standard behaviour by setting the value to
857     // the closest normal value, and set the fail bit on the stream.
858     value.set_value((value.isNegative() | negate_value) ? T::lowest()
859                                                         : T::max());
860     is.setstate(std::ios_base::failbit);
861   }
862   return is;
863 }
864 
865 // Specialization of ParseNormalFloat for FloatProxy<Float16> values.
866 // This will parse the float as it were a 32-bit floating point number,
867 // and then round it down to fit into a Float16 value.
868 // The number is rounded towards zero.
869 // If negate_value is true then the number may not have a leading minus or
870 // plus, and if it successfully parses, then the number is negated before
871 // being stored into the value parameter.
872 // If the value cannot be correctly parsed or overflows the target floating
873 // point type, then set the fail bit on the stream.
874 // TODO(dneto): Promise C++11 standard behavior in how the value is set in
875 // the error case, but only after all target platforms implement it correctly.
876 // In particular, the Microsoft C++ runtime appears to be out of spec.
877 template <>
878 inline std::istream&
879 ParseNormalFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>(
880     std::istream& is, bool negate_value,
881     HexFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>& value) {
882   // First parse as a 32-bit float.
883   HexFloat<FloatProxy<float>> float_val(0.0f);
884   ParseNormalFloat(is, negate_value, float_val);
885 
886   // Then convert to 16-bit float, saturating at infinities, and
887   // rounding toward zero.
888   float_val.castTo(value, round_direction::kToZero);
889 
890   // Overflow on 16-bit behaves the same as for 32- and 64-bit: set the
891   // fail bit and set the lowest or highest value.
892   if (Float16::isInfinity(value.value().getAsFloat())) {
893     value.set_value(value.isNegative() ? Float16::lowest() : Float16::max());
894     is.setstate(std::ios_base::failbit);
895   }
896   return is;
897 }
898 
namespace detail {

// Returns 'value' with 'bit' OR-ed in at the 'n'th most significant position,
// where position 0 is the most significant bit of UINT_TYPE.
// If 'bit' is zero, or 'n' is beyond the last bit position of the integer
// type, 'value' is returned unchanged.
template <typename UINT_TYPE>
UINT_TYPE set_nth_most_significant_bit(UINT_TYPE value, UINT_TYPE bit,
                                       UINT_TYPE n) {
  constexpr UINT_TYPE max_position = std::numeric_limits<UINT_TYPE>::digits - 1;
  const bool nothing_to_set = (bit == 0) || (n > max_position);
  if (nothing_to_set) {
    return value;
  }
  const auto positioned_bit = bit << (max_position - n);
  return static_cast<UINT_TYPE>(value | positioned_bit);
}

// Attempts to increment the argument.
// Returns true and increments 'value' when that does not overflow.
// Returns false and leaves 'value' unchanged when it already holds the
// maximum of INT_TYPE.
template <typename INT_TYPE>
bool saturated_inc(INT_TYPE& value) {
  const bool has_headroom = value != std::numeric_limits<INT_TYPE>::max();
  if (has_headroom) {
    value++;
  }
  return has_headroom;
}

// Attempts to decrement the argument.
// Returns true and decrements 'value' when that does not underflow.
// Returns false and leaves 'value' unchanged when it already holds the
// minimum of INT_TYPE.
template <typename INT_TYPE>
bool saturated_dec(INT_TYPE& value) {
  const bool has_headroom = value != std::numeric_limits<INT_TYPE>::min();
  if (has_headroom) {
    value--;
  }
  return has_headroom;
}
}  // namespace detail
939 
940 // Reads a HexFloat from the given stream.
941 // If the float is not encoded as a hex-float then it will be parsed
942 // as a regular float.
943 // This may fail if your stream does not support at least one unget.
944 // Nan values can be encoded with "0x1.<not zero>p+exponent_bias".
945 // This would normally overflow a float and round to
946 // infinity but this special pattern is the exact representation for a NaN,
947 // and therefore is actually encoded as the correct NaN. To encode inf,
948 // either 0x0p+exponent_bias can be specified or any exponent greater than
949 // exponent_bias.
950 // Examples using IEEE 32-bit float encoding.
951 //    0x1.0p+128 (+inf)
952 //    -0x1.0p-128 (-inf)
953 //
954 //    0x1.1p+128 (+Nan)
955 //    -0x1.1p+128 (-Nan)
956 //
957 //    0x1p+129 (+inf)
958 //    -0x1p+129 (-inf)
959 template <typename T, typename Traits>
960 std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
961   using HF = HexFloat<T, Traits>;
962   using uint_type = typename HF::uint_type;
963   using int_type = typename HF::int_type;
964 
965   value.set_value(static_cast<typename HF::native_type>(0.f));
966 
967   if (is.flags() & std::ios::skipws) {
968     // If the user wants to skip whitespace , then we should obey that.
969     while (std::isspace(is.peek())) {
970       is.get();
971     }
972   }
973 
974   auto next_char = is.peek();
975   bool negate_value = false;
976 
977   if (next_char != '-' && next_char != '0') {
978     return ParseNormalFloat(is, negate_value, value);
979   }
980 
981   if (next_char == '-') {
982     negate_value = true;
983     is.get();
984     next_char = is.peek();
985   }
986 
987   if (next_char == '0') {
988     is.get();  // We may have to unget this.
989     auto maybe_hex_start = is.peek();
990     if (maybe_hex_start != 'x' && maybe_hex_start != 'X') {
991       is.unget();
992       return ParseNormalFloat(is, negate_value, value);
993     } else {
994       is.get();  // Throw away the 'x';
995     }
996   } else {
997     return ParseNormalFloat(is, negate_value, value);
998   }
999 
1000   // This "looks" like a hex-float so treat it as one.
1001   bool seen_p = false;
1002   bool seen_dot = false;
1003 
1004   // The mantissa bits, without the most significant 1 bit, and with the
1005   // the most recently read bits in the least significant positions.
1006   uint_type fraction = 0;
1007   // The number of mantissa bits that have been read, including the leading 1
1008   // bit that is not written into 'fraction'.
1009   uint_type fraction_index = 0;
1010 
1011   // TODO(dneto): handle overflow and underflow
1012   int_type exponent = HF::exponent_bias;
1013 
1014   // Strip off leading zeros so we don't have to special-case them later.
1015   while ((next_char = is.peek()) == '0') {
1016     is.get();
1017   }
1018 
1019   // Does the mantissa, as written, have non-zero digits to the left of
1020   // the decimal point.  Assume no until proven otherwise.
1021   bool has_integer_part = false;
1022   bool bits_written = false;  // Stays false until we write a bit.
1023 
1024   // Scan the mantissa hex digits until we see a '.' or the 'p' that
1025   // starts the exponent.
1026   while (!seen_p && !seen_dot) {
1027     // Handle characters that are left of the fractional part.
1028     if (next_char == '.') {
1029       seen_dot = true;
1030     } else if (next_char == 'p') {
1031       seen_p = true;
1032     } else if (::isxdigit(next_char)) {
1033       // We have stripped all leading zeroes and we have not yet seen a ".".
1034       has_integer_part = true;
1035       int number = get_nibble_from_character(next_char);
1036       for (int i = 0; i < 4; ++i, number <<= 1) {
1037         uint_type write_bit = (number & 0x8) ? 0x1 : 0x0;
1038         if (bits_written) {
1039           // If we are here the bits represented belong in the fractional
1040           // part of the float, and we have to adjust the exponent accordingly.
1041           fraction = detail::set_nth_most_significant_bit(fraction, write_bit,
1042                                                           fraction_index);
1043           // Increment the fraction index. If the input has bizarrely many
1044           // significant digits, then silently drop them.
1045           detail::saturated_inc(fraction_index);
1046           if (!detail::saturated_inc(exponent)) {
1047             // Overflow failure
1048             is.setstate(std::ios::failbit);
1049             return is;
1050           }
1051         }
1052         // Since this updated after setting fraction bits, this effectively
1053         // drops the leading 1 bit.
1054         bits_written |= write_bit != 0;
1055       }
1056     } else {
1057       // We have not found our exponent yet, so we have to fail.
1058       is.setstate(std::ios::failbit);
1059       return is;
1060     }
1061     is.get();
1062     next_char = is.peek();
1063   }
1064 
1065   // Finished reading the part preceding any '.' or 'p'.
1066 
1067   bits_written = false;
1068   while (seen_dot && !seen_p) {
1069     // Handle only fractional parts now.
1070     if (next_char == 'p') {
1071       seen_p = true;
1072     } else if (::isxdigit(next_char)) {
1073       int number = get_nibble_from_character(next_char);
1074       for (int i = 0; i < 4; ++i, number <<= 1) {
1075         uint_type write_bit = (number & 0x8) ? 0x01 : 0x00;
1076         bits_written |= write_bit != 0;
1077         if ((!has_integer_part) && !bits_written) {
1078           // Handle modifying the exponent here this way we can handle
1079           // an arbitrary number of hex values without overflowing our
1080           // integer.
1081           if (!detail::saturated_dec(exponent)) {
1082             // Overflow failure
1083             is.setstate(std::ios::failbit);
1084             return is;
1085           }
1086         } else {
1087           fraction = detail::set_nth_most_significant_bit(fraction, write_bit,
1088                                                           fraction_index);
1089           // Increment the fraction index. If the input has bizarrely many
1090           // significant digits, then silently drop them.
1091           detail::saturated_inc(fraction_index);
1092         }
1093       }
1094     } else {
1095       // We still have not found our 'p' exponent yet, so this is not a valid
1096       // hex-float.
1097       is.setstate(std::ios::failbit);
1098       return is;
1099     }
1100     is.get();
1101     next_char = is.peek();
1102   }
1103 
1104   // Finished reading the part preceding 'p'.
1105   // In hex floats syntax, the binary exponent is required.
1106 
1107   bool seen_exponent_sign = false;
1108   int8_t exponent_sign = 1;
1109   bool seen_written_exponent_digits = false;
1110   // The magnitude of the exponent, as written, or the sentinel value to signal
1111   // overflow.
1112   int_type written_exponent = 0;
1113   // A sentinel value signalling overflow of the magnitude of the written
1114   // exponent.  We'll assume that -written_exponent_overflow is valid for the
1115   // type. Later we may add 1 or subtract 1 from the adjusted exponent, so leave
1116   // room for an extra 1.
1117   const int_type written_exponent_overflow =
1118       std::numeric_limits<int_type>::max() - 1;
1119   while (true) {
1120     if (!seen_written_exponent_digits &&
1121         (next_char == '-' || next_char == '+')) {
1122       if (seen_exponent_sign) {
1123         is.setstate(std::ios::failbit);
1124         return is;
1125       }
1126       seen_exponent_sign = true;
1127       exponent_sign = (next_char == '-') ? -1 : 1;
1128     } else if (::isdigit(next_char)) {
1129       seen_written_exponent_digits = true;
1130       // Hex-floats express their exponent as decimal.
1131       int_type digit =
1132           static_cast<int_type>(static_cast<int_type>(next_char) - '0');
1133       if (written_exponent >= (written_exponent_overflow - digit) / 10) {
1134         // The exponent is very big. Saturate rather than overflow the exponent.
1135         // signed integer, which would be undefined behaviour.
1136         written_exponent = written_exponent_overflow;
1137       } else {
1138         written_exponent = static_cast<int_type>(
1139             static_cast<int_type>(written_exponent * 10) + digit);
1140       }
1141     } else {
1142       break;
1143     }
1144     is.get();
1145     next_char = is.peek();
1146   }
1147   if (!seen_written_exponent_digits) {
1148     // Binary exponent had no digits.
1149     is.setstate(std::ios::failbit);
1150     return is;
1151   }
1152 
1153   written_exponent = static_cast<int_type>(written_exponent * exponent_sign);
1154   // Now fold in the exponent bias into the written exponent, updating exponent.
1155   // But avoid undefined behaviour that would result from overflowing int_type.
1156   if (written_exponent >= 0 && exponent >= 0) {
1157     // Saturate up to written_exponent_overflow.
1158     if (written_exponent_overflow - exponent > written_exponent) {
1159       exponent = static_cast<int_type>(written_exponent + exponent);
1160     } else {
1161       exponent = written_exponent_overflow;
1162     }
1163   } else if (written_exponent < 0 && exponent < 0) {
1164     // Saturate down to -written_exponent_overflow.
1165     if (written_exponent_overflow + exponent > -written_exponent) {
1166       exponent = static_cast<int_type>(written_exponent + exponent);
1167     } else {
1168       exponent = static_cast<int_type>(-written_exponent_overflow);
1169     }
1170   } else {
1171     // They're of opposing sign, so it's safe to add.
1172     exponent = static_cast<int_type>(written_exponent + exponent);
1173   }
1174 
1175   bool is_zero = (!has_integer_part) && (fraction == 0);
1176   if ((!has_integer_part) && !is_zero) {
1177     fraction = static_cast<uint_type>(fraction << 1);
1178     exponent = static_cast<int_type>(exponent - 1);
1179   } else if (is_zero) {
1180     exponent = 0;
1181   }
1182 
1183   if (exponent <= 0 && !is_zero) {
1184     fraction = static_cast<uint_type>(fraction >> 1);
1185     fraction |= static_cast<uint_type>(1) << HF::top_bit_left_shift;
1186   }
1187 
1188   fraction = (fraction >> HF::fraction_right_shift) & HF::fraction_encode_mask;
1189 
1190   const int_type max_exponent =
1191       SetBits<uint_type, 0, HF::num_exponent_bits>::get;
1192 
1193   // Handle denorm numbers
1194   while (exponent < 0 && !is_zero) {
1195     fraction = static_cast<uint_type>(fraction >> 1);
1196     exponent = static_cast<int_type>(exponent + 1);
1197 
1198     fraction &= HF::fraction_encode_mask;
1199     if (fraction == 0) {
1200       // We have underflowed our fraction. We should clamp to zero.
1201       is_zero = true;
1202       exponent = 0;
1203     }
1204   }
1205 
1206   // We have overflowed so we should be inf/-inf.
1207   if (exponent > max_exponent) {
1208     exponent = max_exponent;
1209     fraction = 0;
1210   }
1211 
1212   uint_type output_bits = static_cast<uint_type>(
1213       static_cast<uint_type>(negate_value ? 1 : 0) << HF::top_bit_left_shift);
1214   output_bits |= fraction;
1215 
1216   uint_type shifted_exponent = static_cast<uint_type>(
1217       static_cast<uint_type>(exponent << HF::exponent_left_shift) &
1218       HF::exponent_mask);
1219   output_bits |= shifted_exponent;
1220 
1221   T output_float(output_bits);
1222   value.set_value(output_float);
1223 
1224   return is;
1225 }
1226 
1227 // Writes a FloatProxy value to a stream.
1228 // Zero and normal numbers are printed in the usual notation, but with
1229 // enough digits to fully reproduce the value.  Other values (subnormal,
1230 // NaN, and infinity) are printed as a hex float.
1231 template <typename T>
1232 std::ostream& operator<<(std::ostream& os, const FloatProxy<T>& value) {
1233   auto float_val = value.getAsFloat();
1234   switch (std::fpclassify(float_val)) {
1235     case FP_ZERO:
1236     case FP_NORMAL: {
1237       auto saved_precision = os.precision();
1238       os.precision(std::numeric_limits<T>::max_digits10);
1239       os << float_val;
1240       os.precision(saved_precision);
1241     } break;
1242     default:
1243       os << HexFloat<FloatProxy<T>>(value);
1244       break;
1245   }
1246   return os;
1247 }
1248 
1249 template <>
1250 inline std::ostream& operator<<<Float16>(std::ostream& os,
1251                                          const FloatProxy<Float16>& value) {
1252   os << HexFloat<FloatProxy<Float16>>(value);
1253   return os;
1254 }
1255 
1256 }  // namespace utils
1257 }  // namespace spvtools
1258 
1259 #endif  // SOURCE_UTIL_HEX_FLOAT_H_
1260