//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <cmath>
#include <cstring>
#include <ios>
#include <ostream>
#include <stdint.h>

namespace armnn
{

/// 16-bit "brain floating point" value stored as its raw bit pattern.
///
/// Layout (the upper 16 bits of an IEEE-754 binary32):
///   S EEEEEEEE MMMMMMM   (1 sign bit, 8 exponent bits, 7 mantissa bits)
class BFloat16
{
public:
    /// Constructs +0.0 (bit pattern 0x0000).
    BFloat16()
        : m_Value(0)
    {}

    BFloat16(const BFloat16& v) = default;

    /// Constructs from a raw bit pattern (no numeric conversion is performed).
    explicit BFloat16(uint16_t v)
        : m_Value(v)
    {}

    /// Constructs by converting a float with round-to-nearest-even.
    explicit BFloat16(float v)
    {
        m_Value = Float32ToBFloat16(v).Val();
    }

    /// Implicit widening to float; exact, since every bfloat16 value is representable as a float.
    operator float() const
    {
        return ToFloat32();
    }

    BFloat16& operator=(const BFloat16& other) = default;

    /// Assigns a float, converting it with round-to-nearest-even.
    BFloat16& operator=(float v)
    {
        m_Value = Float32ToBFloat16(v).Val();
        return *this;
    }

    /// Bitwise equality: +0 and -0 compare unequal, and identical NaN bit patterns compare equal.
    bool operator==(const BFloat16& r) const
    {
        return m_Value == r.Val();
    }

    /// Bitwise inequality, the exact negation of operator==.
    /// Without this overload `a != b` would go through the implicit float conversion and
    /// disagree with operator== for signed zeros and NaNs.
    bool operator!=(const BFloat16& r) const
    {
        return !(*this == r);
    }

    /// Converts a float to bfloat16 using round-to-nearest, ties-to-even.
    /// Any NaN input maps to the canonical quiet NaN 0x7FC0 (sign and payload are not kept).
    /// Finite values that round past the largest finite bfloat16 become infinity.
    static BFloat16 Float32ToBFloat16(const float v)
    {
        if (std::isnan(v))
        {
            return Nan();
        }

        // Float32:  S EEEEEEEE MMMMMMLRMMMMMMMMMMMMMMM
        // BFloat16: S EEEEEEEE MMMMMML
        // L: least significant bit of the bfloat16 mantissa
        // R: rounding bit (most significant discarded bit)
        //   R = 0                                -> round down (truncate)
        //   R = 1, some lower discarded bit set  -> round up
        //   R = 1, all lower discarded bits zero -> tie: round up only if L = 1 (result becomes even)
        static_assert(sizeof(float) == sizeof(uint32_t), "BFloat16 requires a 32-bit float");
        uint32_t bits;
        std::memcpy(&bits, &v, sizeof(bits)); // well-defined type punning (reinterpret_cast is UB)

        uint16_t u16 = static_cast<uint16_t>(bits >> 16u);
        const uint16_t lsb = static_cast<uint16_t>(u16 & 0x0001u);
        // The 16 low bits of the float that truncation would discard.
        const uint16_t error = static_cast<uint16_t>(bits & 0x0000FFFFu);
        if (error > 0x8000u || (error == 0x8000u && lsb == 1u))
        {
            // Cannot wrap: 0xFFFF would be a NaN pattern, which is handled above.
            u16++;
        }
        return BFloat16(u16);
    }

    /// Widens to float by placing the 16 stored bits in the upper half of a binary32.
    float ToFloat32() const
    {
        // Shift in unsigned 32-bit arithmetic rather than on the promoted int.
        const uint32_t bits = static_cast<uint32_t>(m_Value) << 16u;
        float result;
        std::memcpy(&result, &bits, sizeof(result));
        return result;
    }

    /// Returns the raw 16-bit pattern.
    uint16_t Val() const
    {
        return m_Value;
    }

    /// Largest finite value (exponent 0xFE, all mantissa bits set).
    static BFloat16 Max()
    {
        uint16_t max = 0x7F7F;
        return BFloat16(max);
    }

    /// Canonical positive quiet NaN (0x7FC0).
    static BFloat16 Nan()
    {
        uint16_t nan = 0x7FC0;
        return BFloat16(nan);
    }

    /// Positive infinity (0x7F80).
    static BFloat16 Inf()
    {
        uint16_t infVal = 0x7F80;
        return BFloat16(infVal);
    }

private:
    uint16_t m_Value;
};

/// Writes "<float value>(0x<hex bits>)". The stream's format flags are restored afterwards so
/// that the std::hex used for the bit pattern does not leak into later output.
inline std::ostream& operator<<(std::ostream& os, const BFloat16& b)
{
    const std::ios_base::fmtflags savedFlags = os.flags();
    os << b.ToFloat32() << "(0x" << std::hex << b.Val() << ")";
    os.flags(savedFlags);
    return os;
}

} // namespace armnn