/****************************************************************************
 * Copyright (C) 2017 Intel Corporation.   All Rights Reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 ****************************************************************************/
#pragma once

#if !defined(__cplusplus)
#error C++ compilation required
#endif

#include <immintrin.h>
#include <inttypes.h>
#include <stdint.h>

// Instruction-set levels. The values are ordered so that they can be compared
// with >= in preprocessor conditionals.
#define SIMD_ARCH_AVX 0
#define SIMD_ARCH_AVX2 1
#define SIMD_ARCH_AVX512 2

// Fall back to the lowest level when the build did not select one.
#if !defined(SIMD_ARCH)
#define SIMD_ARCH SIMD_ARCH_AVX
#endif

// Compiler portability helpers:
//   SIMDCALL   - calling convention for SIMD functions (__vectorcall on MSVC,
//                empty elsewhere)
//   SIMDINLINE - inlining request (__forceinline on MSVC, inline elsewhere)
//   SIMDALIGN  - declares type_ with an alignment of align_ bytes
#if defined(_MSC_VER)
#define SIMDCALL __vectorcall
#define SIMDINLINE __forceinline
#define SIMDALIGN(type_, align_) __declspec(align(align_)) type_
#else
#define SIMDCALL
#define SIMDINLINE inline
#define SIMDALIGN(type_, align_) type_ __attribute__((aligned(align_)))
#endif

// For documentation, please see the following include...
// #include "simdlib_interface.hpp"

namespace SIMDImpl
{
    // Floating-point comparison predicates.
    enum class CompareType
    {
        EQ_OQ    = 0x00, // Equal (ordered, nonsignaling)
        LT_OS    = 0x01, // Less-than (ordered, signaling)
        LE_OS    = 0x02, // Less-than-or-equal (ordered, signaling)
        UNORD_Q  = 0x03, // Unordered (nonsignaling)
        NEQ_UQ   = 0x04, // Not-equal (unordered, nonsignaling)
        NLT_US   = 0x05, // Not-less-than (unordered, signaling)
        NLE_US   = 0x06, // Not-less-than-or-equal (unordered, signaling)
        ORD_Q    = 0x07, // Ordered (nonsignaling)
        EQ_UQ    = 0x08, // Equal (unordered, non-signaling)
        NGE_US   = 0x09, // Not-greater-than-or-equal (unordered, signaling)
        NGT_US   = 0x0A, // Not-greater-than (unordered, signaling)
        FALSE_OQ = 0x0B, // False (ordered, nonsignaling)
        NEQ_OQ   = 0x0C, // Not-equal (ordered, non-signaling)
        GE_OS    = 0x0D, // Greater-than-or-equal (ordered, signaling)
        GT_OS    = 0x0E, // Greater-than (ordered, signaling)
        TRUE_UQ  = 0x0F, // True (unordered, non-signaling)
        EQ_OS    = 0x10, // Equal (ordered, signaling)
        LT_OQ    = 0x11, // Less-than (ordered, nonsignaling)
        LE_OQ    = 0x12, // Less-than-or-equal (ordered, nonsignaling)
        UNORD_S  = 0x13, // Unordered (signaling)
        NEQ_US   = 0x14, // Not-equal (unordered, signaling)
        NLT_UQ   = 0x15, // Not-less-than (unordered, nonsignaling)
        NLE_UQ   = 0x16, // Not-less-than-or-equal (unordered, nonsignaling)
        ORD_S    = 0x17, // Ordered (signaling)
        EQ_US    = 0x18, // Equal (unordered, signaling)
        NGE_UQ   = 0x19, // Not-greater-than-or-equal (unordered, nonsignaling)
        NGT_UQ   = 0x1A, // Not-greater-than (unordered, nonsignaling)
        FALSE_OS = 0x1B, // False (ordered, signaling)
        NEQ_OS   = 0x1C, // Not-equal (ordered, signaling)
        GE_OQ    = 0x1D, // Greater-than-or-equal (ordered, nonsignaling)
        GT_OQ    = 0x1E, // Greater-than (ordered, nonsignaling)
        TRUE_US  = 0x1F, // True (unordered, signaling)
    };

#if SIMD_ARCH >= SIMD_ARCH_AVX512
    // Integer comparison predicates; only available when targeting AVX512
    // because the values come from the _MM_CMPINT_* constants.
    enum class CompareTypeInt
    {
        EQ = _MM_CMPINT_EQ, // Equal
        LT = _MM_CMPINT_LT, // Less than
        LE = _MM_CMPINT_LE, // Less than or Equal
        NE = _MM_CMPINT_NE, // Not Equal
        GE = _MM_CMPINT_GE, // Greater than or Equal
        GT = _MM_CMPINT_GT, // Greater than
    };
#endif // SIMD_ARCH >= SIMD_ARCH_AVX512

    // Byte scale applied to an offset.
    enum class ScaleFactor
    {
        SF_1 = 1, // No scaling
        SF_2 = 2, // Scale offset by 2
        SF_4 = 4, // Scale offset by 4
        SF_8 = 8, // Scale offset by 8
    };

    // Rounding control. The first group selects the rounding direction, the
    // second group the exception behavior, and the remaining enumerators are
    // the bitwise OR of one value from each group.
    enum class RoundMode
    {
        // Round to nearest integer (ties to even when used as the
        // _MM_FROUND_* immediate)
        TO_NEAREST_INT = 0x00,
        TO_NEG_INF     = 0x01, // Round to negative infinity
        TO_POS_INF     = 0x02, // Round to positive infinity
        TO_ZERO        = 0x03, // Round to 0 a.k.a. truncate
        CUR_DIRECTION  = 0x04, // Round in direction set in MXCSR register

        RAISE_EXC = 0x00, // Raise exception on overflow
        NO_EXC    = 0x08, // Suppress exceptions

        NINT        = static_cast<int>(TO_NEAREST_INT) | static_cast<int>(RAISE_EXC),
        NINT_NOEXC  = static_cast<int>(TO_NEAREST_INT) | static_cast<int>(NO_EXC),
        FLOOR       = static_cast<int>(TO_NEG_INF) | static_cast<int>(RAISE_EXC),
        FLOOR_NOEXC = static_cast<int>(TO_NEG_INF) | static_cast<int>(NO_EXC),
        CEIL        = static_cast<int>(TO_POS_INF) | static_cast<int>(RAISE_EXC),
        CEIL_NOEXC  = static_cast<int>(TO_POS_INF) | static_cast<int>(NO_EXC),
        TRUNC       = static_cast<int>(TO_ZERO) | static_cast<int>(RAISE_EXC),
        TRUNC_NOEXC = static_cast<int>(TO_ZERO) | static_cast<int>(NO_EXC),
        RINT        = static_cast<int>(CUR_DIRECTION) | static_cast<int>(RAISE_EXC),
        NEARBYINT   = static_cast<int>(CUR_DIRECTION) | static_cast<int>(NO_EXC),
    };

    // Bundles the enum types above under one name.
    struct Traits
    {
        using CompareType = SIMDImpl::CompareType;
        using ScaleFactor = SIMDImpl::ScaleFactor;
        using RoundMode   = SIMDImpl::RoundMode;
    };

    // Attribute, 4-dimensional attribute in SIMD SOA layout
    //
    // The same storage can be viewed as four Float, Integer or Double values,
    // or through the named components x/y/z/w.
    template <typename Float, typename Integer, typename Double>
    union Vec4
    {
        Float   v[4];
        Integer vi[4];
        Double  vd[4];
        struct
        {
            Float x;
            Float y;
            Float z;
            Float w;
        };

        // Indexed access to the Float view; i is not range-checked.
        SIMDINLINE Float& SIMDCALL operator[](const int i) { return v[i]; }
        SIMDINLINE Float const& SIMDCALL operator[](const int i) const { return v[i]; }

        // Copies the four components through the Float view.
        SIMDINLINE Vec4& SIMDCALL operator=(Vec4 const& in)
        {
            v[0] = in.v[0];
            v[1] = in.v[1];
            v[2] = in.v[2];
            v[3] = in.v[3];
            return *this;
        }
    };

    // 128-bit wrappers. Each type converts implicitly to and from the
    // corresponding raw intrinsic type.
    namespace SIMD128Impl
    {
        // Wrapper around __m128.
        union Float
        {
            SIMDINLINE Float() = default;
            SIMDINLINE Float(__m128 in) : v(in) {}
            SIMDINLINE Float& SIMDCALL operator=(__m128 in)
            {
                v = in;
                return *this;
            }
            SIMDINLINE Float& SIMDCALL operator=(Float const& in)
            {
                v = in.v;
                return *this;
            }
            SIMDINLINE SIMDCALL operator __m128() const { return v; }

            SIMDALIGN(__m128, 16) v;
        };

        // Wrapper around __m128i.
        union Integer
        {
            SIMDINLINE Integer() = default;
            SIMDINLINE Integer(__m128i in) : v(in) {}
            SIMDINLINE Integer& SIMDCALL operator=(__m128i in)
            {
                v = in;
                return *this;
            }
            SIMDINLINE Integer& SIMDCALL operator=(Integer const& in)
            {
                v = in.v;
                return *this;
            }
            SIMDINLINE SIMDCALL operator __m128i() const { return v; }

            SIMDALIGN(__m128i, 16) v;
        };

        // Wrapper around __m128d.
        union Double
        {
            SIMDINLINE Double() = default;
            SIMDINLINE Double(__m128d in) : v(in) {}
            SIMDINLINE Double& SIMDCALL operator=(__m128d in)
            {
                v = in;
                return *this;
            }
            SIMDINLINE Double& SIMDCALL operator=(Double const& in)
            {
                v = in.v;
                return *this;
            }
            SIMDINLINE SIMDCALL operator __m128d() const { return v; }

            SIMDALIGN(__m128d, 16) v;
        };

        using Vec4 = SIMDImpl::Vec4<Float, Integer, Double>;
        using Mask = uint8_t;

        static const uint32_t SIMD_WIDTH = 4;
    } // namespace SIMD128Impl

    // 256-bit wrappers. Each union also exposes its storage as two 128-bit
    // wrappers through v4[].
    namespace SIMD256Impl
    {
        // Wrapper around __m256.
        union Float
        {
            SIMDINLINE Float() = default;
            SIMDINLINE Float(__m256 in) : v(in) {}

            // Builds the value from two 128-bit halves: in_lo becomes the low
            // half, in_hi (zero by default) is inserted as the high half.
            SIMDINLINE Float(SIMD128Impl::Float const& in_lo,
                             SIMD128Impl::Float const& in_hi = _mm_setzero_ps())
            {
                v = _mm256_insertf128_ps(_mm256_castps128_ps256(in_lo), in_hi, 0x1);
            }
            SIMDINLINE Float& SIMDCALL operator=(__m256 in)
            {
                v = in;
                return *this;
            }
            SIMDINLINE Float& SIMDCALL operator=(Float const& in)
            {
                v = in.v;
                return *this;
            }
            SIMDINLINE SIMDCALL operator __m256() const { return v; }

            SIMDALIGN(__m256, 32) v;
            SIMD128Impl::Float v4[2];
        };

        // Wrapper around __m256i.
        union Integer
        {
            SIMDINLINE Integer() = default;
            SIMDINLINE Integer(__m256i in) : v(in) {}

            // Builds the value from two 128-bit halves: in_lo becomes the low
            // half, in_hi (zero by default) is inserted as the high half.
            SIMDINLINE Integer(SIMD128Impl::Integer const& in_lo,
                               SIMD128Impl::Integer const& in_hi = _mm_setzero_si128())
            {
                v = _mm256_insertf128_si256(_mm256_castsi128_si256(in_lo), in_hi, 0x1);
            }
            SIMDINLINE Integer& SIMDCALL operator=(__m256i in)
            {
                v = in;
                return *this;
            }
            SIMDINLINE Integer& SIMDCALL operator=(Integer const& in)
            {
                v = in.v;
                return *this;
            }
            SIMDINLINE SIMDCALL operator __m256i() const { return v; }

            SIMDALIGN(__m256i, 32) v;
            SIMD128Impl::Integer v4[2];
        };

        // Wrapper around __m256d.
        union Double
        {
            SIMDINLINE Double() = default;
            SIMDINLINE Double(__m256d const& in) : v(in) {}

            // Builds the value from two 128-bit halves: in_lo becomes the low
            // half, in_hi (zero by default) is inserted as the high half.
            SIMDINLINE Double(SIMD128Impl::Double const& in_lo,
                              SIMD128Impl::Double const& in_hi = _mm_setzero_pd())
            {
                v = _mm256_insertf128_pd(_mm256_castpd128_pd256(in_lo), in_hi, 0x1);
            }
            SIMDINLINE Double& SIMDCALL operator=(__m256d in)
            {
                v = in;
                return *this;
            }
            SIMDINLINE Double& SIMDCALL operator=(Double const& in)
            {
                v = in.v;
                return *this;
            }
            SIMDINLINE SIMDCALL operator __m256d() const { return v; }

            SIMDALIGN(__m256d, 32) v;
            SIMD128Impl::Double v4[2];
        };

        using Vec4 = SIMDImpl::Vec4<Float, Integer, Double>;
        using Mask = uint8_t;

        static const uint32_t SIMD_WIDTH = 8;
    } // namespace SIMD256Impl

    // 512-bit wrappers. Each union also exposes its storage as two 256-bit
    // wrappers through v8[].
    namespace SIMD512Impl
    {
#if !(defined(__AVX512F__) || defined(_ZMMINTRIN_H_INCLUDED))
        // Define AVX512 types if not included via immintrin.h.
        // All data members of these types are ONLY to be viewed
        // in a debugger.  Do NOT access them via code!
        union __m512
        {
        private:
            float m512_f32[16];
        };
        struct __m512d
        {
        private:
            double m512d_f64[8];
        };

        union __m512i
        {
        private:
            int8_t   m512i_i8[64];
            int16_t  m512i_i16[32];
            int32_t  m512i_i32[16];
            int64_t  m512i_i64[8];
            uint8_t  m512i_u8[64];
            uint16_t m512i_u16[32];
            uint32_t m512i_u32[16];
            uint64_t m512i_u64[8];
        };

        using __mmask16 = uint16_t;
#endif

// Alignment requested for the 512-bit members below: 64 bytes with the Intel
// compiler or when targeting AVX512, 32 bytes otherwise.
#if defined(__INTEL_COMPILER) || (SIMD_ARCH >= SIMD_ARCH_AVX512)
#define SIMD_ALIGNMENT_BYTES 64
#else
#define SIMD_ALIGNMENT_BYTES 32
#endif

        // Wrapper around __m512.
        union Float
        {
            SIMDINLINE Float() = default;
            SIMDINLINE Float(__m512 in) : v(in) {}

            // Stores in_lo into v8[0] and in_hi (zero by default) into v8[1].
            SIMDINLINE Float(SIMD256Impl::Float const& in_lo,
                             SIMD256Impl::Float const& in_hi = _mm256_setzero_ps())
            {
                v8[0] = in_lo;
                v8[1] = in_hi;
            }
            SIMDINLINE Float& SIMDCALL operator=(__m512 in)
            {
                v = in;
                return *this;
            }

            // Copies the full value when targeting AVX512; otherwise copies
            // the two 256-bit halves individually.
            SIMDINLINE Float& SIMDCALL operator=(Float const& in)
            {
#if SIMD_ARCH >= SIMD_ARCH_AVX512
                v = in.v;
#else
                v8[0] = in.v8[0];
                v8[1] = in.v8[1];
#endif
                return *this;
            }
            SIMDINLINE SIMDCALL operator __m512() const { return v; }

            SIMDALIGN(__m512, SIMD_ALIGNMENT_BYTES) v;
            SIMD256Impl::Float v8[2];
        };

        // Wrapper around __m512i.
        union Integer
        {
            SIMDINLINE Integer() = default;
            SIMDINLINE Integer(__m512i in) : v(in) {}

            // Stores in_lo into v8[0] and in_hi (zero by default) into v8[1].
            SIMDINLINE Integer(SIMD256Impl::Integer const& in_lo,
                               SIMD256Impl::Integer const& in_hi = _mm256_setzero_si256())
            {
                v8[0] = in_lo;
                v8[1] = in_hi;
            }
            SIMDINLINE Integer& SIMDCALL operator=(__m512i in)
            {
                v = in;
                return *this;
            }

            // Copies the full value when targeting AVX512; otherwise copies
            // the two 256-bit halves individually.
            SIMDINLINE Integer& SIMDCALL operator=(Integer const& in)
            {
#if SIMD_ARCH >= SIMD_ARCH_AVX512
                v = in.v;
#else
                v8[0] = in.v8[0];
                v8[1] = in.v8[1];
#endif
                return *this;
            }

            SIMDINLINE SIMDCALL operator __m512i() const { return v; }

            SIMDALIGN(__m512i, SIMD_ALIGNMENT_BYTES) v;
            SIMD256Impl::Integer v8[2];
        };

        // Wrapper around __m512d.
        union Double
        {
            SIMDINLINE Double() = default;
            SIMDINLINE Double(__m512d in) : v(in) {}

            // Stores in_lo into v8[0] and in_hi (zero by default) into v8[1].
            SIMDINLINE Double(SIMD256Impl::Double const& in_lo,
                              SIMD256Impl::Double const& in_hi = _mm256_setzero_pd())
            {
                v8[0] = in_lo;
                v8[1] = in_hi;
            }
            SIMDINLINE Double& SIMDCALL operator=(__m512d in)
            {
                v = in;
                return *this;
            }

            // Copies the full value when targeting AVX512; otherwise copies
            // the two 256-bit halves individually.
            SIMDINLINE Double& SIMDCALL operator=(Double const& in)
            {
#if SIMD_ARCH >= SIMD_ARCH_AVX512
                v = in.v;
#else
                v8[0] = in.v8[0];
                v8[1] = in.v8[1];
#endif
                return *this;
            }

            SIMDINLINE SIMDCALL operator __m512d() const { return v; }

            SIMDALIGN(__m512d, SIMD_ALIGNMENT_BYTES) v;
            SIMD256Impl::Double v8[2];
        };

        typedef SIMDImpl::Vec4<Float, Integer, Double> SIMDALIGN(Vec4, 64);
        using Mask = __mmask16;

        static const uint32_t SIMD_WIDTH = 16;

#undef SIMD_ALIGNMENT_BYTES
    } // namespace SIMD512Impl
} // namespace SIMDImpl