1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // fixedpoint.h: fixed-point arithmetic, with basic operations and 16 // a few math functions such as tanh. 17 18 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 19 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 20 21 #include <algorithm> 22 #include <cassert> 23 #include <cmath> 24 #include <cstdint> 25 #include <limits> 26 27 #include "../internal/detect_platform.h" 28 29 namespace gemmlowp { 30 31 // Part 1: Low-level integer-arithmetic primitives. 32 // The implementations here are generic implementations valid for 33 // scalar types (e.g. std::int32_t). Architecture-specific SIMD types 34 // (e.g. NEON int32x4_t) may be supported by providing 35 // specializations for them in separate files. 36 // 37 // The purpose of these primitives is two-fold: 38 // - They will be used to implement higher-level fixed-point 39 // abstractions, namely the FixedPoint class and its arithmetic 40 // operators. 41 // - They will be directly used to implement some more involved 42 // fixed-point computations, e.g. the fixed-point implementation 43 // of math functions such as tanh. 44 45 // Some compile-time traits around raw types to handle SIMD aspects: 46 // number of lanes, underlying scalar type. 
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  typedef std::int32_t ScalarRawType;
  static constexpr int kLanes = 1;
};

template <>
struct FixedPointRawTypeTraits<std::int16_t> {
  typedef std::int16_t ScalarRawType;
  static constexpr int kLanes = 1;
};

// Returns a SIMD value duplicating a scalar value across all lanes.
// For the plain scalar raw types above this is the identity.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
  return x;
}

// Plain bit-wise AND
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  return a & b;
}

// Plain bit-wise OR
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  return a | b;
}

// Plain bit-wise XOR
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  return a ^ b;
}

// Plain bit-wise NOT
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  return ~a;
}

// Integer addition. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  return a + b;
}

// Integer multiplication. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  return a * b;
}

// Integer subtraction. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  return a - b;
}

// Integer unary negative. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  return -a;
}

// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
// Negative values are OK. In case of overflow, no Undefined
// Behavior, but the results are implementation-defined (in practice,
// they currently are saturated, but we make no commitment to that). The idea
// is that the caller will want to implement the overflowing cases with
// saturation with compare-and-mask, so we don't care about the results
// in the overflow case, we just want to avoid undefined behavior.
//
// tIntegerType may be int32 or any narrower signed type. offset is expected
// to be in [0, 31], which keeps the 64-bit intermediate product in range.
template <typename tIntegerType, typename OffsetType>
tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
  const std::int64_t wide_a = static_cast<std::int64_t>(a);
  // The power of two must be formed in 64 bits. `1 << offset` is evaluated in
  // int: for offset == 31 it yields INT_MIN (UB before C++14), which flips the
  // sign of the product (ShiftLeft(1, 31) used to return INT32_MIN).
  const std::int64_t wide_shifted = wide_a * (std::int64_t{1} << offset);
  const auto min = std::numeric_limits<tIntegerType>::min();
  const auto max = std::numeric_limits<tIntegerType>::max();
  return wide_shifted < min
             ? min
             : wide_shifted > max ? max
                                  : static_cast<tIntegerType>(wide_shifted);
}

// Integer arithmetic right-shift. Not rounding.
// Relying on implementation-defined, but in-practice-consistent,
// C++ compiler behavior.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  return a >> offset;
}

// Each bit of the result is set to the corresponding bit of either then_val or
// else_val depending on whether the corresponding bit of if_mask is set.
// Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  // The two AND terms have disjoint bits, so XOR combines them like OR.
  return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val));
}

// For each input scalar, the corresponding bits of the result are set if the
// input scalar is non-zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  static constexpr tIntegerType kNoBits = 0;
  // A zero input maps to the empty mask; anything else to the all-ones mask.
  if (!a) {
    return kNoBits;
  }
  return BitNot(kNoBits);
}

// For each input scalar, the corresponding bits of the result are set if the
// input scalar is zero.
template <typename tIntegerType>
tIntegerType MaskIfZero(tIntegerType a) {
  return MaskIfNonZero<tIntegerType>(a == 0);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are equal.
template <typename tIntegerType>
tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
  const bool same = (a == b);
  return MaskIfNonZero<tIntegerType>(same);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are not equal.
template <typename tIntegerType>
tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
  const bool same = (a == b);
  return MaskIfNonZero<tIntegerType>(!same);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a > b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
  return MaskIfNonZero<tIntegerType>(b < a);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a >= b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
  return MaskIfNonZero<tIntegerType>(!(a < b));
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a < b.
199 template <typename tIntegerType> 200 tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) { 201 return MaskIfNonZero<tIntegerType>(a < b); 202 } 203 204 // For each pair of input scalars, the corresponding bits of the result are 205 // set if the input scalars a, b satisfy a <= b. 206 template <typename tIntegerType> 207 tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) { 208 return MaskIfNonZero<tIntegerType>(a <= b); 209 } 210 211 // Returns true if all of the input scalars are nonzero. 212 // This function may currently assume that each of the input scalars has either 213 // all or none of its bits set. Otherwise, its behavior is currently undefined. 214 template <typename tIntegerType> 215 bool All(tIntegerType a) { 216 return a; 217 } 218 219 // Returns true if any of the input scalars are nonzero. 220 // This function may currently assume that each of the input scalars has either 221 // all or none of its bits set. Otherwise, its behavior is currently undefined. 222 template <typename tIntegerType> 223 bool Any(tIntegerType a) { 224 return a; 225 } 226 227 // Returns (a+b)/2, rounded to the nearest integer. 228 // Equivalent to VRHADD in the ARM NEON instruction set. 229 template <typename IntegerType> 230 IntegerType RoundingHalfSum(IntegerType a, IntegerType b) { 231 static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); 232 (void)b; 233 return a; 234 } 235 236 template <> 237 inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) { 238 std::int64_t a64 = a; 239 std::int64_t b64 = b; 240 std::int64_t sum = a64 + b64; 241 std::int64_t sign = sum >= 0 ? 1 : -1; 242 return static_cast<std::int32_t>((sum + sign) / 2); 243 } 244 245 template <> 246 inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) { 247 std::int32_t a32 = a; 248 std::int32_t b32 = b; 249 std::int32_t sum = a32 + b32; 250 std::int32_t sign = sum >= 0 ? 
1 : -1; 251 return static_cast<std::int16_t>((sum + sign) / 2); 252 } 253 254 template <typename IntegerType> 255 IntegerType SaturatingAdd(IntegerType a, IntegerType b) { 256 static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); 257 (void)b; 258 return a; 259 } 260 261 // So far this is only needed for int16. 262 template <> 263 inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) { 264 std::int32_t a32 = a; 265 std::int32_t b32 = b; 266 std::int32_t sum = a32 + b32; 267 return static_cast<std::int16_t>( 268 std::min(static_cast<std::int32_t>(32767), 269 std::max(static_cast<std::int32_t>(-32768), sum))); 270 } 271 272 template <> 273 inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) { 274 std::int16_t a16 = a; 275 std::int16_t b16 = b; 276 std::int16_t sum = a16 + b16; 277 return static_cast<std::int8_t>(std::min( 278 static_cast<int16_t>(std::numeric_limits<int8_t>::max()), 279 std::max(static_cast<int16_t>(std::numeric_limits<int8_t>::min()), sum))); 280 } 281 282 // Returns a+b, saturating if the integers are 16bit or narrower, 283 // otherwise just a plain addition. 
284 template <typename IntegerType, bool Is16Bit> 285 struct AddSaturatingIf16BitImpl { 286 static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); } 287 }; 288 template <typename IntegerType> 289 struct AddSaturatingIf16BitImpl<IntegerType, true> { 290 static IntegerType Run(IntegerType a, IntegerType b) { 291 return SaturatingAdd(a, b); 292 } 293 }; 294 template <typename IntegerType> 295 IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) { 296 using ScalarType = 297 typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; 298 return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a, 299 b); 300 } 301 302 // Returns the integer that represents the product of two fixed-point 303 // numbers, interpreting all integers as fixed-point values in the 304 // interval [-1, 1), rounding to the nearest value, and saturating 305 // -1 * -1 to the maximum value (since 1 is not in the half-open 306 // interval [-1, 1)). 307 // 308 // [The explanation below specializes to std::int32_t for example purpose.] 309 // 310 // The mapping between IntegerType and the interval [-1, 1) is unique and 311 // implied by IntegerType, which is assumed to be signed. For example, 312 // for IntegerType==std::int32_t, the mapping is 313 // real_value = integer_value / 2^31. 314 // So in this case, and leaving aside rounding and saturating, this 315 // function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to 316 // (a * b) / 2^31. 317 // 318 // The 'doubling' part in the name of this function comes from the fact that 319 // this operation is very close to a "multiply-high" operation, keeping only 320 // the top half bits, except that that would be effectively computing 321 // (a * b) / 2^32, 322 // so here we are computing 2x that, since 323 // 1/2^31 = 2 * 1/2^32. 324 // The idea is to use all of the available 32 bits in the destination int32 325 // value. 326 // 327 // [End of the explanation specializing to int32.] 
328 // 329 // This is equivalent to the VQRDMULH instruction in ARM NEON. 330 template <typename IntegerType> 331 IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) { 332 static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); 333 (void)b; 334 return a; 335 } 336 337 // This function implements the same computation as the ARMv7 NEON VQRDMULH 338 // instruction. 339 template <> 340 inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, 341 std::int32_t b) { 342 bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min(); 343 std::int64_t a_64(a); 344 std::int64_t b_64(b); 345 std::int64_t ab_64 = a_64 * b_64; 346 std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); 347 std::int32_t ab_x2_high32 = 348 static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31)); 349 return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32; 350 } 351 352 template <> 353 inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a, 354 std::int16_t b) { 355 bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min(); 356 std::int32_t a_32(a); 357 std::int32_t b_32(b); 358 std::int32_t ab_32 = a_32 * b_32; 359 std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14)); 360 std::int16_t ab_x2_high16 = 361 static_cast<std::int16_t>((ab_32 + nudge) / (1 << 15)); 362 return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16; 363 } 364 365 // Correctly-rounded-to-nearest division by a power-of-two. 366 // Also known as a rounding arithmetic right shift. 
367 template <typename IntegerType, typename ExponentType> 368 inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) { 369 assert(exponent >= 0); 370 assert(exponent <= 31); 371 const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1); 372 const IntegerType zero = Dup<IntegerType>(0); 373 const IntegerType one = Dup<IntegerType>(1); 374 const IntegerType remainder = BitAnd(x, mask); 375 const IntegerType threshold = 376 Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one)); 377 return Add(ShiftRight(x, exponent), 378 BitAnd(MaskIfGreaterThan(remainder, threshold), one)); 379 } 380 381 // Returns the product of a run-time integer value by a compile-time power 382 // of two, with either a positive exponent (equivalent to an arithmetic 383 // left shift, saturating) or a negative exponent (equivalent to an arithmetic 384 // right shift, rounding to nearest). 385 template <int Exponent, typename IntegerType, 386 int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? 
-1 : 0)> 387 struct ImplSaturatingRoundingMultiplyByPOT {}; 388 389 template <int Exponent, typename IntegerType> 390 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> { 391 static IntegerType eval(IntegerType x) { return x; } 392 }; 393 394 template <int Exponent, typename IntegerType> 395 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> { 396 static IntegerType eval(IntegerType x) { 397 using ScalarIntegerType = 398 typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; 399 const IntegerType min = 400 Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min()); 401 const IntegerType max = 402 Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max()); 403 const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); 404 405 const std::int32_t threshold = 406 ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1); 407 const IntegerType positive_mask = 408 MaskIfGreaterThan(x, Dup<IntegerType>(threshold)); 409 const IntegerType negative_mask = 410 MaskIfLessThan(x, Dup<IntegerType>(-threshold)); 411 412 IntegerType result = ShiftLeft(x, Exponent); 413 result = SelectUsingMask(positive_mask, max, result); 414 result = SelectUsingMask(negative_mask, min, result); 415 return result; 416 } 417 }; 418 419 template <int Exponent, typename IntegerType> 420 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> { 421 static IntegerType eval(IntegerType x) { 422 return RoundingDivideByPOT<IntegerType>(x, -Exponent); 423 } 424 }; 425 426 template <int Exponent, typename IntegerType> 427 IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { 428 return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x); 429 } 430 431 // Part 2: the FixedPoint class. 432 433 // A FixedPoint object represents a fixed-point value stored in the underlying 434 // integer type tRawType, if tRawType is a plain scalar integer type. 435 // Alternatively, tRawType may be a SIMD type (e.g. 
NEON int32x4_t) in which 436 // case a FixedPoint object represents a corresponding SIMD vector of fixed 437 // point values. 438 // 439 // tIntegerBits describes the range of the fixed-point format: if 440 // tIntegerBits == m then the range of representable values is the half-open 441 // interval [-2^m; 2^m) where the open boundary on the right side means that 442 // 2^m is not representable (how close the maximum representable value is to 443 // it, depends on bit-depth of tRawType). 444 // 445 // In "Q format notation", 446 // https://en.wikipedia.org/wiki/Q_(number_format) 447 // we are describing the format 448 // Qm.n 449 // where 450 // m = tIntegerBits 451 // and 452 // n = NumberOfBits(tRawType) - (m + 1) 453 // Note that the (m + 1) in the above line is because we adopt the convention 454 // that we count the integer bits exclusively of the sign bit; so (m + 1) is 455 // the total number of integer bits inclusive of the sign bit. 456 // 457 // Accordingly, the number of integral representable values in our range 458 // [-2^m ; 2^m) 459 // is equal to 2^(m+1). 
460 template <typename tRawType, int tIntegerBits> 461 class FixedPoint { 462 public: 463 typedef tRawType RawType; 464 465 typedef FixedPointRawTypeTraits<RawType> RawTypeTraits; 466 typedef typename RawTypeTraits::ScalarRawType ScalarRawType; 467 468 static constexpr int kTotalBits = 8 * sizeof(ScalarRawType); 469 static constexpr int kIntegerBits = tIntegerBits; 470 static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits; 471 static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, 472 "bad IntegerBits"); 473 474 typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType; 475 476 static const ScalarRawType ScalarRawMin() { 477 return std::numeric_limits<ScalarRawType>::min(); 478 } 479 480 static const ScalarRawType ScalarRawMax() { 481 return std::numeric_limits<ScalarRawType>::max(); 482 } 483 484 static const ScalarRawType RawMin() { 485 return VectorFromScalar(ScalarRawMin()); 486 } 487 488 static const ScalarRawType RawMax() { 489 return VectorFromScalar(ScalarRawMax()); 490 } 491 492 static FixedPoint FromRaw(RawType x) { 493 FixedPoint retval; 494 retval.raw() = x; 495 return retval; 496 } 497 498 static FixedPoint FromScalarRaw(ScalarRawType x) { 499 FixedPoint retval; 500 retval.raw() = Dup<RawType>(x); 501 return retval; 502 } 503 504 static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { 505 return FromScalarRaw(x.raw()); 506 } 507 508 template <int Exponent> 509 static FixedPoint ConstantPOT() { 510 static constexpr int kOffset = kFractionalBits + Exponent; 511 static_assert( 512 kOffset < 31, 513 "Constant not exactly representable in this fixed-point format"); 514 return FromScalarRaw(ScalarRawType(1) << kOffset); 515 } 516 517 static FixedPoint Zero() { return FromScalarRaw(0); } 518 519 static FixedPoint One() { 520 return FromScalarRaw( 521 kIntegerBits == 0 522 ? ScalarRawMax() 523 : (ScalarRawType(1) << (kIntegerBits == 0 ? 
0 : kFractionalBits))); 524 } 525 526 static FixedPoint FromDouble(double x) { 527 const double min_bound = static_cast<double>(ScalarRawMin()); 528 const double max_bound = static_cast<double>(ScalarRawMax()); 529 return FromScalarRaw(static_cast<ScalarRawType>(std::min( 530 std::max(round(x * static_cast<double>(1ll << kFractionalBits)), 531 min_bound), 532 max_bound))); 533 } 534 535 RawType raw() const { return i_; } 536 RawType& raw() { return i_; } 537 538 private: 539 RawType i_; 540 }; 541 542 // Part 3: implementation of arithmetic operators for the 543 // FixedPoint class, and a few related functions. 544 545 // A FixedPoint multiplication is just a 546 // SaturatingRoundingDoublingHighMul operation on the underlying 547 // raw integer values. The IntegerBits simply add up, as is obvious 548 // from the fact that the range is [-2^IntegerBits, 2^IntegerBits). 549 template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b> 550 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*( 551 FixedPoint<tRawType, tIntegerBits_a> a, 552 FixedPoint<tRawType, tIntegerBits_b> b) { 553 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c; 554 c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); 555 return c; 556 } 557 558 // Tweaking IntegerBits gives exact multiplication by a power of two. 559 template <int tExponent, typename tRawType, int tIntegerBits> 560 FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot( 561 FixedPoint<tRawType, tIntegerBits> a) { 562 FixedPoint<tRawType, tExponent + tIntegerBits> c; 563 c.raw() = a.raw(); 564 return c; 565 } 566 567 // If we want to leave IntegerBits fixed, then multiplication 568 // by a power of two has to be saturating/rounding, not exact anymore. 
569 template <int tExponent, typename tRawType, int tIntegerBits> 570 FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT( 571 FixedPoint<tRawType, tIntegerBits> a) { 572 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 573 SaturatingRoundingMultiplyByPOT<tExponent>(a.raw())); 574 } 575 576 // Generic arithmetic operators. 577 578 #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \ 579 template <typename tRawType, int tIntegerBits> \ 580 FixedPoint<tRawType, tIntegerBits> FuncName( \ 581 FixedPoint<tRawType, tIntegerBits> a) { \ 582 return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \ 583 } 584 585 #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \ 586 template <typename tRawType, int tIntegerBits> \ 587 FixedPoint<tRawType, tIntegerBits> FuncName( \ 588 FixedPoint<tRawType, tIntegerBits> a, \ 589 FixedPoint<tRawType, tIntegerBits> b) { \ 590 return FixedPoint<tRawType, tIntegerBits>::FromRaw( \ 591 ImplFuncName(a.raw(), b.raw())); \ 592 } 593 594 MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg) 595 MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot) 596 MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add) 597 MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub) 598 MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd) 599 MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor) 600 MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr) 601 MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum) 602 603 #undef MAKE_FIXEDPOINT_UNARY_FUNC 604 #undef MAKE_FIXEDPOINT_BINARY_FUNC 605 606 #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \ 607 template <typename tRawType, int tIntegerBits> \ 608 tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \ 609 return FuncName(a.raw()); \ 610 } 611 612 #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \ 613 template <typename tRawType, int tIntegerBits> \ 614 tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, \ 615 FixedPoint<tRawType, tIntegerBits> b) { \ 616 return FuncName(a.raw(), 
b.raw()); \ 617 } 618 619 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) 620 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) 621 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) 622 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) 623 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) 624 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) 625 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) 626 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) 627 628 #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW 629 #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW 630 631 template <typename tRawType, int tIntegerBits> 632 FixedPoint<tRawType, tIntegerBits> SelectUsingMask( 633 tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val, 634 FixedPoint<tRawType, tIntegerBits> else_val) { 635 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 636 SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); 637 } 638 639 template <typename tRawType, int tIntegerBits> 640 bool operator==(FixedPoint<tRawType, tIntegerBits> a, 641 FixedPoint<tRawType, tIntegerBits> b) { 642 return All(MaskIfEqual(a.raw(), b.raw())); 643 } 644 645 template <typename tRawType, int tIntegerBits> 646 bool operator!=(FixedPoint<tRawType, tIntegerBits> a, 647 FixedPoint<tRawType, tIntegerBits> b) { 648 return !(a == b); 649 } 650 651 template <typename tRawType, int tIntegerBits> 652 FixedPoint<tRawType, tIntegerBits> SaturatingAdd( 653 FixedPoint<tRawType, tIntegerBits> a, 654 FixedPoint<tRawType, tIntegerBits> b) { 655 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 656 SaturatingAdd(a.raw(), b.raw())); 657 } 658 659 template <typename tRawType, int tIntegerBits> 660 FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit( 661 FixedPoint<tRawType, tIntegerBits> a, 662 FixedPoint<tRawType, tIntegerBits> b) { 663 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 664 AddSaturatingIf16Bit(a.raw(), b.raw())); 665 } 666 667 
// Conversion to floating-point. 668 template <typename tRawType, int tIntegerBits> 669 double ToDouble(FixedPoint<tRawType, tIntegerBits> x) { 670 static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1, 671 "not applicable to SIMD types"); 672 typedef FixedPoint<tRawType, tIntegerBits> F; 673 return x.raw() / static_cast<double>(1ll << F::kFractionalBits); 674 } 675 676 // Rescale changes the number of IntegerBits and updates the underlying 677 // raw integer value accordingly. 678 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc> 679 FixedPoint<tRawType, tIntegerBitsDst> Rescale( 680 FixedPoint<tRawType, tIntegerBitsSrc> x) { 681 static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst; 682 FixedPoint<tRawType, tIntegerBitsDst> result; 683 result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw()); 684 return result; 685 } 686 687 // CheckedFixedPointConstant allows to specify fixed-point constants 688 // initialized as real numbers, in a way that does not compile floating-point 689 // arithmetic in production code, yet still checks agreement with the 690 // floating-point expressions when asserts are enabled. 691 // 692 // The raw integer value provided is always a int32, encoding a 32-bit 693 // fixed-point value, regardless of the actual Scalar type. This allows 694 // writing generic code that applies just as well to the 32-bit and 16-bit 695 // cases. In the 16-bit case, the raw integer value is internally 696 // rounding-shifted by 16 bits to the right. 
697 template <typename FixedPointType> 698 inline typename FixedPointType::ScalarRawType RescaleConstantInitializer( 699 std::int32_t int32_value) { 700 typedef typename FixedPointType::ScalarRawType ScalarRawType; 701 static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType); 702 return static_cast<ScalarRawType>( 703 RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits)); 704 } 705 #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS 706 template <typename FixedPointType> 707 FixedPointType CheckedFixedPointConstant(std::int32_t raw_value, 708 double double_value) { 709 const FixedPointType result = FixedPointType::FromScalarRaw(raw_value); 710 assert(result == FixedPointType::FromDouble(double_value)); 711 return result; 712 } 713 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \ 714 ScalarRawInt32Value, DoubleValue) \ 715 (gemmlowp::CheckedFixedPointConstant<FixedPointType>( \ 716 gemmlowp::RescaleConstantInitializer<FixedPointType>( \ 717 ScalarRawInt32Value), \ 718 DoubleValue)) 719 720 #else 721 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \ 722 ScalarRawInt32Value, DoubleValue) \ 723 (FixedPointType::FromScalarRaw( \ 724 gemmlowp::RescaleConstantInitializer<FixedPointType>( \ 725 ScalarRawInt32Value))) 726 #endif 727 728 // Implementation of exponential function. 729 730 // Returns exp(x) for x in [-1/4, 0). 731 template <typename tRawType> 732 FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl( 733 FixedPoint<tRawType, 0> a) { 734 typedef FixedPoint<tRawType, 0> F; 735 const F constant_term = 736 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0)); 737 const F constant_1_over_3 = 738 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0); 739 // We're evaluating a Taylor expansion around -1/8, so we do the change of 740 // variable: x = a + 1/8. 741 // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. 
F x = a + F::template ConstantPOT<-3>();
  F x2 = x * x;
  F x3 = x2 * x;
  F x4 = x2 * x2;
  F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
  F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
      SaturatingRoundingMultiplyByPOT<-1>(
          ((x4_over_4 + x3) * constant_1_over_3) + x2);
  return AddSaturatingIf16Bit(
      constant_term,
      constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
}

// Returns exp(x) for x < 0.
//
// Range reduction: the input is split as
//   a = a_mod_quarter_minus_one_quarter - remainder,
// where a_mod_quarter_minus_one_quarter lies in [-1/4, 0) and remainder is a
// non-negative multiple of 1/4, so that
//   exp(a) = exp(a_mod_quarter_minus_one_quarter) * exp(-remainder).
// The first factor comes from the polynomial helper above. The second is
// assembled bit by bit from precomputed exp(-2^k) factors (the "barrel
// shifter" below). The result type has 0 integer bits because
// exp(a) <= 1 for a <= 0.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> exp_on_negative_values(
    FixedPoint<tRawType, tIntegerBits> a) {
  typedef FixedPoint<tRawType, tIntegerBits> InputF;
  typedef FixedPoint<tRawType, 0> ResultF;
  static constexpr int kFractionalBits = InputF::kFractionalBits;
  static constexpr int kIntegerBits = InputF::kIntegerBits;
  const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
  // mask keeps only the raw bits worth less than 1/4. (a & mask) is therefore
  // a modulo 1/4, in [0, 1/4), and subtracting 1/4 moves it into [-1/4, 0).
  InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
  InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
  ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
      Rescale<0>(a_mod_quarter_minus_one_quarter));
  // For a < 0, remainder >= 0. If its raw bit at position
  // kFractionalBits + k is set, the factor exp(-2^k) must be applied.
  tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();

// One barrel-shifter step for k = Exponent. If remainder contains 2^k, it
// multiplies result by exp(-2^k), whose raw Q0 value is FixedPointMultiplier
// (paired with the exact double std::exp(-2^k) in the checked-constant
// macro). Steps with k >= kIntegerBits are skipped because |a| < 2^kIntegerBits,
// so such bits cannot appear. In those skipped steps the ternary keeps
// kShiftAmount at 0, so the (compiled but dead) shift stays in range.
// No comments may go inside the macro body: a // comment would swallow the
// trailing backslash continuation.
#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
  if (kIntegerBits > Exponent) {                                            \
    const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
        ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
    static constexpr int kShiftAmount =                                     \
        kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
    result = SelectUsingMask(                                               \
        MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
        result * kMultiplier, result);                                      \
  }

  // Multipliers below are the raw Q0 values of exp(-2^k), i.e. approximately
  // exp(-2^k) * 2^31, for k = -2 .. +4:
  GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);  // exp(-1/4)
  GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);  // exp(-1/2)
  GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);   // exp(-1)
  GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);   // exp(-2)
  GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);    // exp(-4)
  GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);      // exp(-8)
  GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);         // exp(-16)

#undef GEMMLOWP_EXP_BARREL_SHIFTER

  // The shifter steps above cover remainders up to 16+8+4+2+1+1/2+1/4 < 32.
  // When the input format can represent values below -32 (kIntegerBits > 5),
  // those inputs are flushed to zero; exp(-32) ~ 1.3e-14 is far below the Q0
  // resolution anyway. -(1 << clampB) is -32 written as a 32-bit raw value:
  // 32 * 2^(31 - kIntegerBits) = 2^(36 - kIntegerBits). clampB is 0 when the
  // clamp is unused, so the shift always stays in range.
  static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
  if (kIntegerBits > 5) {
    const InputF clamp =
        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
    result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
  }

  // The range reduction above assumes a < 0 (at a == 0 the remainder would be
  // -1/4), so exp(0) = 1 is patched in explicitly.
  result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
  return result;
}

// Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).

// Returns (1 - x) / (1 + x) for x in (0, 1).
//
// Let d = (1 + a) / 2, which lies in (1/2, 1). Newton-Raphson approximates 1/d
// in Q2: two integer bits are enough for 1/d < 2 and for the 48/17 starting
// constant. The result is then 1/d - 1 = 2/(1 + a) - 1 = (1 - a)/(1 + a).
template <typename tRawType>
FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
    FixedPoint<tRawType, 0> a) {
  typedef FixedPoint<tRawType, 0> F0;
  typedef FixedPoint<tRawType, 2> F2;
  F0 half_denominator = RoundingHalfSum(a, F0::One());
  // Newton-Raphson division
  // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
  // Refer to that page for the logic behind the 48/17 and 32/17 constants.
  const F2 constant_48_over_17 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
  const F2 constant_neg_32_over_17 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
  // Linear initial estimate of 1/d: 48/17 - 32/17 * d.
  F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
  // Fixed number of refinement steps x <- x + x * (1 - d * x); each step
  // roughly squares the relative error.
  for (int i = 0; i < 3; i++) {
    F2 half_denominator_times_x = half_denominator * x;
    F2 one_minus_half_denominator_times_x =
        F2::One() - half_denominator_times_x;
    x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
  }
  // x ~= 1/d = 2/(1 + a); subtracting 1 gives (1 - a)/(1 + a).
  return Rescale<0>(x - F2::One());
}

// Returns -tanh(x) for x < 0.
//
// Uses -tanh(a) = tanh(-a) = (1 - exp(2a)) / (1 + exp(2a)). For a < 0,
// exp(2a) lies in (0, 1), the domain of the helper above.
// ExactMulByPot<1> supplies 2a.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
    FixedPoint<tRawType, tIntegerBits> a) {
  return one_minus_x_over_one_plus_x_for_x_in_0_1(
      exp_on_negative_values(ExactMulByPot<1>(a)));
}

// Returns tanh(x) for any x.
//
// tanh is odd. The input is folded onto n = -|a| (a if negative, -a
// otherwise), so t = -tanh(n) = tanh(|a|). The sign is then restored, and
// a == 0 maps to exactly 0; that input is outside the helper's a < 0 domain.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
  typedef FixedPoint<tRawType, tIntegerBits> InputF;
  typedef FixedPoint<tRawType, 0> ResultF;
  tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
  tRawType mask_if_zero = MaskIfZero(a);
  InputF n = SelectUsingMask(mask_if_negative, a, -a);
  ResultF t = neg_tanh_on_negative_values(n);
  return SelectUsingMask(mask_if_zero, ResultF::Zero(),
                         SelectUsingMask(mask_if_negative, -t, t));
}

// Implementation of logistic function.

// Returns 1 / (1 + x) for x in (0, 1).
853 template <typename tRawType> 854 FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1( 855 FixedPoint<tRawType, 0> a) { 856 typedef FixedPoint<tRawType, 0> F0; 857 typedef FixedPoint<tRawType, 2> F2; 858 F0 half_denominator = RoundingHalfSum(a, F0::One()); 859 // Newton-Raphson division 860 // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division 861 // Refer to that page for the logic behind the 48/17 and 32/17 constants. 862 const F2 constant_48_over_17 = 863 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); 864 const F2 constant_neg_32_over_17 = 865 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); 866 F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; 867 for (int i = 0; i < 3; i++) { 868 F2 half_denominator_times_x = half_denominator * x; 869 F2 one_minus_half_denominator_times_x = 870 F2::One() - half_denominator_times_x; 871 x = x + Rescale<2>(x * one_minus_half_denominator_times_x); 872 } 873 return Rescale<0>(ExactMulByPot<-1>(x)); 874 } 875 876 // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0. 877 template <typename tRawType, int tIntegerBits> 878 FixedPoint<tRawType, 0> logistic_on_positive_values( 879 FixedPoint<tRawType, tIntegerBits> a) { 880 return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a)); 881 } 882 883 // Returns logistic(x) = 1 / (1 + exp(-x)) for any x. 
884 template <typename tRawType, int tIntegerBits> 885 FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) { 886 typedef FixedPoint<tRawType, tIntegerBits> InputF; 887 typedef FixedPoint<tRawType, 0> ResultF; 888 tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero()); 889 tRawType mask_if_zero = MaskIfZero(a); 890 InputF abs_input = SelectUsingMask(mask_if_positive, a, -a); 891 ResultF result_if_positive = logistic_on_positive_values(abs_input); 892 ResultF result_if_negative = ResultF::One() - result_if_positive; 893 const ResultF one_half = 894 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5); 895 return SelectUsingMask(mask_if_zero, one_half, 896 SelectUsingMask(mask_if_positive, result_if_positive, 897 result_if_negative)); 898 } 899 900 } // end namespace gemmlowp 901 902 #ifdef GEMMLOWP_NEON 903 #include "./fixedpoint_neon.h" 904 #elif defined(GEMMLOWP_AVX2) 905 #include "./fixedpoint_avx.h" 906 #elif defined(GEMMLOWP_SSE4) 907 #include "./fixedpoint_sse.h" 908 #elif defined(GEMMLOWP_MSA) 909 #include "./fixedpoint_msa.h" 910 #elif defined(GEMMLOWP_WASMSIMD) 911 #include "./fixedpoint_wasmsimd.h" 912 #endif 913 914 #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 915