1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_ 17 #define TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_ 18 19 #include <cmath> 20 #include <complex> 21 22 #include "tensorflow/core/platform/byte_order.h" 23 24 #ifdef __CUDACC__ 25 // All functions callable from CUDA code must be qualified with __device__ 26 #define B16_DEVICE_FUNC __host__ __device__ 27 28 #else 29 #define B16_DEVICE_FUNC 30 31 #endif 32 33 namespace Eigen { 34 struct half; 35 } 36 37 namespace tensorflow { 38 39 // Single precision complex. 40 typedef std::complex<float> complex64; 41 // Double precision complex. 42 typedef std::complex<double> complex128; 43 44 // see framework/bfloat16.h for description. 45 struct bfloat16 { 46 // The default constructor must yield a zero value, not an uninitialized 47 // value; some TF kernels use T() as a zero value. bfloat16bfloat1648 B16_DEVICE_FUNC bfloat16() : value(ZERO_VALUE) {} 49 truncate_to_bfloat16bfloat1650 B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) { 51 bfloat16 output; 52 if (float_isnan(v)) { 53 output.value = NAN_VALUE; 54 return output; 55 } 56 const uint16_t* p = reinterpret_cast<const uint16_t*>(&v); 57 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ 58 output.value = p[0]; 59 #else 60 output.value = p[1]; 61 #endif 62 return output; 63 } 64 bfloat16bfloat1665 B16_DEVICE_FUNC explicit bfloat16(const float v) { 66 value = round_to_bfloat16(v).value; 67 } 68 bfloat16bfloat1669 B16_DEVICE_FUNC explicit bfloat16(const double val) 70 : bfloat16(static_cast<float>(val)) {} 71 // Following the convention of numpy, converting between complex and 72 // float will lead to loss of imag value. bfloat16bfloat1673 B16_DEVICE_FUNC explicit bfloat16(const complex64& val) 74 : bfloat16(val.real()) {} 75 bfloat16bfloat1676 B16_DEVICE_FUNC explicit bfloat16(const complex128& val) 77 : bfloat16(static_cast<float>(val.real())) {} 78 bfloat16bfloat1679 B16_DEVICE_FUNC explicit bfloat16(const unsigned short val) 80 : bfloat16(static_cast<float>(val)) {} 81 bfloat16bfloat1682 B16_DEVICE_FUNC explicit bfloat16(const unsigned int val) 83 : bfloat16(static_cast<float>(val)) {} 84 bfloat16bfloat1685 B16_DEVICE_FUNC explicit bfloat16(const int val) 86 : bfloat16(static_cast<float>(val)) {} 87 bfloat16bfloat1688 B16_DEVICE_FUNC explicit bfloat16(const long val) 89 : bfloat16(static_cast<float>(val)) {} 90 bfloat16bfloat1691 B16_DEVICE_FUNC explicit bfloat16(const long long val) 92 : bfloat16(static_cast<float>(val)) {} 93 94 template <class T> bfloat16bfloat1695 B16_DEVICE_FUNC explicit bfloat16(const T& val) 96 : bfloat16(static_cast<float>(val)) {} 97 98 B16_DEVICE_FUNC explicit operator float() const { 99 float result = 0; 100 101 uint16_t* q = reinterpret_cast<uint16_t*>(&result); 102 103 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ 104 q[0] = value; 105 #else 106 q[1] = value; 107 #endif 108 return result; 109 } 110 111 B16_DEVICE_FUNC explicit operator bool() const { 112 return static_cast<bool>(float(*this)); 113 } 114 115 B16_DEVICE_FUNC explicit operator Eigen::half() const; 116 117 B16_DEVICE_FUNC explicit operator short() const { 118 return static_cast<short>(float(*this)); 119 } 120 121 B16_DEVICE_FUNC explicit operator int() const { 122 return static_cast<int>(float(*this)); 123 } 124 125 B16_DEVICE_FUNC explicit operator long() const { 126 return static_cast<long>(float(*this)); 127 } 128 129 B16_DEVICE_FUNC explicit operator char() const { 130 return static_cast<char>(float(*this)); 131 } 132 133 B16_DEVICE_FUNC explicit operator signed char() const { 134 return static_cast<signed char>(float(*this)); 135 } 136 137 B16_DEVICE_FUNC explicit operator unsigned char() const { 138 return static_cast<unsigned char>(float(*this)); 139 } 140 141 B16_DEVICE_FUNC explicit operator unsigned short() const { 142 return static_cast<unsigned short>(float(*this)); 143 } 144 145 B16_DEVICE_FUNC explicit operator unsigned int() const { 146 return static_cast<unsigned int>(float(*this)); 147 } 148 149 B16_DEVICE_FUNC explicit operator unsigned long() const { 150 return static_cast<unsigned long>(float(*this)); 151 } 152 153 B16_DEVICE_FUNC explicit operator unsigned long long() const { 154 return static_cast<unsigned long long>(float(*this)); 155 } 156 157 B16_DEVICE_FUNC explicit operator long long() const { 158 return static_cast<long long>(float(*this)); 159 } 160 161 B16_DEVICE_FUNC explicit operator double() const { 162 return static_cast<double>(float(*this)); 163 } 164 complex64bfloat16165 B16_DEVICE_FUNC explicit operator complex64() const { 166 return complex64(float(*this), float(0.0)); 167 } 168 complex128bfloat16169 B16_DEVICE_FUNC explicit operator complex128() const { 170 return complex128(double(*this), double(0.0)); 171 } 172 173 union FP32 { 174 unsigned int u; 175 float f; 176 }; 177 178 // Converts a float point to bfloat16, with round-nearest-to-even as rounding 179 // method. 180 // TODO: There is a slightly faster implementation (8% faster on CPU) 181 // than this (documented in cl/175987786), that is exponentially harder to 182 // understand and document. Switch to the faster version when converting to 183 // BF16 becomes compute-bound. round_to_bfloat16bfloat16184 B16_DEVICE_FUNC static bfloat16 round_to_bfloat16(float v) { 185 uint32_t input; 186 FP32 f; 187 f.f = v; 188 input = f.u; 189 bfloat16 output; 190 191 if (float_isnan(v)) { 192 // If the value is a NaN, squash it to a qNaN with msb of fraction set, 193 // this makes sure after truncation we don't end up with an inf. 194 // 195 // qNaN magic: All exponent bits set + most significant bit of fraction 196 // set. 197 output.value = 0x7fc0; 198 } else { 199 // Fast rounding algorithm that rounds a half value to nearest even. This 200 // reduces expected error when we convert a large number of floats. Here 201 // is how it works: 202 // 203 // Definitions: 204 // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits 205 // with the following tags: 206 // 207 // Sign | Exp (8 bits) | Frac (23 bits) 208 // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT 209 // 210 // S: Sign bit. 211 // E: Exponent bits. 212 // F: First 6 bits of fraction. 213 // L: Least significant bit of resulting bfloat16 if we truncate away the 214 // rest of the float32. This is also the 7th bit of fraction 215 // R: Rounding bit, 8th bit of fraction. 216 // T: Sticky bits, rest of fraction, 15 bits. 217 // 218 // To round half to nearest even, there are 3 cases where we want to round 219 // down (simply truncate the result of the bits away, which consists of 220 // rounding bit and sticky bits) and two cases where we want to round up 221 // (truncate then add one to the result). 222 // 223 // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of 224 // 1s) as the rounding bias, adds the rounding bias to the input, then 225 // truncates the last 16 bits away. 226 // 227 // To understand how it works, we can analyze this algorithm case by case: 228 // 229 // 1. L = 0, R = 0: 230 // Expect: round down, this is less than half value. 231 // 232 // Algorithm: 233 // - Rounding bias: 0x7fff + 0 = 0x7fff 234 // - Adding rounding bias to input may create any carry, depending on 235 // whether there is any value set to 1 in T bits. 236 // - R may be set to 1 if there is a carry. 237 // - L remains 0. 238 // - Note that this case also handles Inf and -Inf, where all fraction 239 // bits, including L, R and Ts are all 0. The output remains Inf after 240 // this algorithm. 241 // 242 // 2. L = 1, R = 0: 243 // Expect: round down, this is less than half value. 244 // 245 // Algorithm: 246 // - Rounding bias: 0x7fff + 1 = 0x8000 247 // - Adding rounding bias to input doesn't change sticky bits but 248 // adds 1 to rounding bit. 249 // - L remains 1. 250 // 251 // 3. L = 0, R = 1, all of T are 0: 252 // Expect: round down, this is exactly at half, the result is already 253 // even (L=0). 254 // 255 // Algorithm: 256 // - Rounding bias: 0x7fff + 0 = 0x7fff 257 // - Adding rounding bias to input sets all sticky bits to 1, but 258 // doesn't create a carry. 259 // - R remains 1. 260 // - L remains 0. 261 // 262 // 4. L = 1, R = 1: 263 // Expect: round up, this is exactly at half, the result needs to be 264 // round to the next even number. 265 // 266 // Algorithm: 267 // - Rounding bias: 0x7fff + 1 = 0x8000 268 // - Adding rounding bias to input doesn't change sticky bits, but 269 // creates a carry from rounding bit. 270 // - The carry sets L to 0, creates another carry bit and propagate 271 // forward to F bits. 272 // - If all the F bits are 1, a carry then propagates to the exponent 273 // bits, which then creates the minimum value with the next exponent 274 // value. Note that we won't have the case where exponents are all 1, 275 // since that's either a NaN (handled in the other if condition) or inf 276 // (handled in case 1). 277 // 278 // 5. L = 0, R = 1, any of T is 1: 279 // Expect: round up, this is greater than half. 280 // 281 // Algorithm: 282 // - Rounding bias: 0x7fff + 0 = 0x7fff 283 // - Adding rounding bias to input creates a carry from sticky bits, 284 // sets rounding bit to 0, then create another carry. 285 // - The second carry sets L to 1. 286 // 287 // Examples: 288 // 289 // Exact half value that is already even: 290 // Input: 291 // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) 292 // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT 293 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000 294 // 295 // This falls into case 3. We truncate the rest of 16 bits and no 296 // carry is created into F and L: 297 // 298 // Output: 299 // Sign | Exp (8 bit) | Frac (first 7 bit) 300 // S E E E E E E E E F F F F F F L 301 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 302 // 303 // Exact half value, round to next even number: 304 // Input: 305 // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) 306 // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT 307 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000 308 // 309 // This falls into case 4. We create a carry from R and T, 310 // which then propagates into L and F: 311 // 312 // Output: 313 // Sign | Exp (8 bit) | Frac (first 7 bit) 314 // S E E E E E E E E F F F F F F L 315 // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 316 // 317 // 318 // Max denormal value round to min normal value: 319 // Input: 320 // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) 321 // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT 322 // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111 323 // 324 // This falls into case 4. We create a carry from R and T, 325 // propagate into L and F, which then propagates into exponent 326 // bits: 327 // 328 // Output: 329 // Sign | Exp (8 bit) | Frac (first 7 bit) 330 // S E E E E E E E E F F F F F F L 331 // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 332 // 333 // Max normal value round to Inf: 334 // Input: 335 // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) 336 // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT 337 // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111 338 // 339 // This falls into case 4. We create a carry from R and T, 340 // propagate into L and F, which then propagates into exponent 341 // bits: 342 // 343 // Sign | Exp (8 bit) | Frac (first 7 bit) 344 // S E E E E E E E E F F F F F F L 345 // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 346 // 347 // 348 // Least significant bit of resulting bfloat. 349 uint32_t lsb = (input >> 16) & 1; 350 uint32_t rounding_bias = 0x7fff + lsb; 351 input += rounding_bias; 352 output.value = static_cast<uint16_t>(input >> 16); 353 } 354 return output; 355 } 356 epsilonbfloat16357 static bfloat16 epsilon() { 358 bfloat16 x; 359 x.value = 0x3c00; // 0x1.0p-7 360 return x; 361 } 362 highestbfloat16363 static bfloat16 highest() { 364 bfloat16 x; 365 x.value = 0x7F7F; // 0x1.FEp127 366 return x; 367 } 368 lowestbfloat16369 static bfloat16 lowest() { 370 bfloat16 x; 371 x.value = 0xFF7F; // -0x1.FEp127 372 return x; 373 } 374 min_positive_normalbfloat16375 static bfloat16 min_positive_normal() { 376 bfloat16 x; 377 x.value = 0x0080; // 0x1p-126 378 return x; 379 } 380 IsZerobfloat16381 bool IsZero() const { return (value & 0x7FFF) == ZERO_VALUE; } 382 383 uint16_t value; 384 385 // A value that represents "not a number". 386 static const uint16_t NAN_VALUE = 0x7FC0; 387 388 private: 389 // A value that represents "zero". 390 static const uint16_t ZERO_VALUE = 0; 391 float_isnanbfloat16392 B16_DEVICE_FUNC static bool float_isnan(const float& x) { 393 #ifdef __CUDA_ARCH__ 394 return ::isnan(x); 395 #else 396 return std::isnan(x); 397 #endif 398 } 399 }; 400 401 B16_DEVICE_FUNC inline std::ostream& operator<<(std::ostream& os, 402 const bfloat16& dt) { 403 os << static_cast<float>(dt); 404 return os; 405 } 406 407 B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, bfloat16 b) { 408 return bfloat16(static_cast<float>(a) + static_cast<float>(b)); 409 } 410 B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, int b) { 411 return bfloat16(static_cast<float>(a) + static_cast<float>(b)); 412 } 413 B16_DEVICE_FUNC inline bfloat16 operator+(int a, bfloat16 b) { 414 return bfloat16(static_cast<float>(a) + static_cast<float>(b)); 415 } 416 B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a, bfloat16 b) { 417 return bfloat16(static_cast<float>(a) - static_cast<float>(b)); 418 } 419 B16_DEVICE_FUNC inline bfloat16 operator*(bfloat16 a, bfloat16 b) { 420 return bfloat16(static_cast<float>(a) * static_cast<float>(b)); 421 } 422 B16_DEVICE_FUNC inline bfloat16 operator/(bfloat16 a, bfloat16 b) { 423 return bfloat16(static_cast<float>(a) / static_cast<float>(b)); 424 } 425 B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a) { 426 a.value ^= 0x8000; 427 return a; 428 } 429 B16_DEVICE_FUNC inline bool operator<(bfloat16 a, bfloat16 b) { 430 return static_cast<float>(a) < static_cast<float>(b); 431 } 432 B16_DEVICE_FUNC inline bool operator<=(bfloat16 a, bfloat16 b) { 433 return static_cast<float>(a) <= static_cast<float>(b); 434 } 435 B16_DEVICE_FUNC inline bool operator==(bfloat16 a, bfloat16 b) { 436 return static_cast<float>(a) == static_cast<float>(b); 437 } 438 B16_DEVICE_FUNC inline bool operator!=(bfloat16 a, bfloat16 b) { 439 return static_cast<float>(a) != static_cast<float>(b); 440 } 441 B16_DEVICE_FUNC inline bool operator>(bfloat16 a, bfloat16 b) { 442 return static_cast<float>(a) > static_cast<float>(b); 443 } 444 B16_DEVICE_FUNC inline bool operator>=(bfloat16 a, bfloat16 b) { 445 return static_cast<float>(a) >= static_cast<float>(b); 446 } 447 B16_DEVICE_FUNC inline bfloat16& operator+=(bfloat16& a, bfloat16 b) { 448 a = a + b; 449 return a; 450 } 451 B16_DEVICE_FUNC inline bfloat16& operator-=(bfloat16& a, bfloat16 b) { 452 a = a - b; 453 return a; 454 } 455 B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a) { 456 a += bfloat16(1); 457 return a; 458 } 459 B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a) { 460 a -= bfloat16(1); 461 return a; 462 } 463 B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a, int) { 464 bfloat16 original_value = a; 465 ++a; 466 return original_value; 467 } 468 B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a, int) { 469 bfloat16 original_value = a; 470 --a; 471 return original_value; 472 } 473 B16_DEVICE_FUNC inline bfloat16& operator*=(bfloat16& a, bfloat16 b) { 474 a = a * b; 475 return a; 476 } 477 B16_DEVICE_FUNC inline bfloat16& operator/=(bfloat16& a, bfloat16 b) { 478 a = a / b; 479 return a; 480 } 481 } // end namespace tensorflow 482 483 namespace std { 484 template <> 485 struct hash<tensorflow::bfloat16> { 486 size_t operator()(const tensorflow::bfloat16& v) const { 487 return hash<float>()(static_cast<float>(v)); 488 } 489 }; 490 491 using tensorflow::bfloat16; 492 inline bool isinf(const bfloat16& a) { return std::isinf(float(a)); } 493 inline bool isnan(const bfloat16& a) { return std::isnan(float(a)); } 494 inline bool isfinite(const bfloat16& a) { return std::isfinite(float(a)); } 495 inline bfloat16 abs(const bfloat16& a) { return bfloat16(std::abs(float(a))); } 496 inline bfloat16 exp(const bfloat16& a) { return bfloat16(std::exp(float(a))); } 497 inline bfloat16 log(const bfloat16& a) { return bfloat16(std::log(float(a))); } 498 inline bfloat16 log10(const bfloat16& a) { 499 return bfloat16(std::log10(float(a))); 500 } 501 inline bfloat16 sqrt(const bfloat16& a) { 502 return bfloat16(std::sqrt(float(a))); 503 } 504 inline bfloat16 pow(const bfloat16& a, const bfloat16& b) { 505 return bfloat16(std::pow(float(a), float(b))); 506 } 507 inline bfloat16 sin(const bfloat16& a) { return bfloat16(std::sin(float(a))); } 508 inline bfloat16 cos(const bfloat16& a) { return bfloat16(std::cos(float(a))); } 509 inline bfloat16 tan(const bfloat16& a) { return bfloat16(std::tan(float(a))); } 510 inline bfloat16 tanh(const bfloat16& a) { 511 return bfloat16(std::tanh(float(a))); 512 } 513 inline bfloat16 floor(const bfloat16& a) { 514 return bfloat16(std::floor(float(a))); 515 } 516 inline bfloat16 ceil(const bfloat16& a) { 517 return bfloat16(std::ceil(float(a))); 518 } 519 } // namespace std 520 521 #endif // TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_ 522