1 #pragma once 2 3 #include <ATen/cpu/vec/intrinsics.h> 4 #include <ATen/cpu/vec/vec_base.h> 5 #include <ATen/cpu/vec/vec_mask.h> 6 7 namespace at::vec { 8 inline namespace CPU_CAPABILITY { 9 10 #if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) 11 12 template <typename T, int dst_n, typename mask_t, int mask_n> 13 struct VecMaskLoad< 14 T, 15 dst_n, 16 mask_t, 17 mask_n, 18 typename std::enable_if_t< 19 (mask_n == dst_n * 2 && dst_n >= 1) && 20 (std::is_same_v<T, float> || std::is_same_v<T, int32_t>), 21 void>> { 22 static inline VectorizedN<T, dst_n> apply( 23 const T* ptr, 24 const VecMask<mask_t, mask_n>& vec_mask) { 25 at::vec::Vectorized<T> zero_vec(0); 26 auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); 27 VectorizedN<mask_t, 2> tmp_vec; 28 VectorizedN<T, dst_n> result; 29 for (int i = 0; i < dst_n; i++) { 30 tmp_vec[0] = vec_mask[2 * i]; 31 tmp_vec[1] = vec_mask[2 * i + 1]; 32 auto int64_mask = VecMask<mask_t, 2>(tmp_vec).template cast<int64_t, 2>(); 33 auto int_mask = int64_mask.template cast<int, 1>()[0]; 34 auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); 35 if constexpr (std::is_same_v<T, float>) { 36 result[i] = Vectorized<T>(_mm512_mask_loadu_ps( 37 zero_vec, mmask, ptr + i * Vectorized<T>::size())); 38 } else { 39 result[i] = Vectorized<T>(_mm512_mask_loadu_epi32( 40 zero_vec, mmask, ptr + i * Vectorized<T>::size())); 41 } 42 } 43 return result; 44 } 45 }; 46 47 template <typename T, int dst_n, typename mask_t> 48 struct VecMaskLoad< 49 T, 50 dst_n, 51 mask_t, 52 dst_n, 53 typename std::enable_if_t< 54 std::is_same_v<T, float> || std::is_same_v<T, int32_t>, 55 void>> { 56 static inline VectorizedN<T, dst_n> apply( 57 const T* ptr, 58 const VecMask<mask_t, dst_n>& vec_mask) { 59 at::vec::Vectorized<T> zero_vec(0); 60 auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); 61 VectorizedN<T, dst_n> result; 62 #ifndef _MSC_VER 63 #pragma unroll 64 #endif 65 for (int i = 0; i < dst_n; i++) { 66 auto tmp_mask = VecMask<mask_t, 1>(vec_mask[i]); 67 auto int_mask = tmp_mask.template cast<int, 1>()[0]; 68 auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); 69 if constexpr (std::is_same_v<T, float>) { 70 result[i] = Vectorized<T>(_mm512_mask_loadu_ps( 71 zero_vec, mmask, ptr + i * Vectorized<T>::size())); 72 } else { 73 result[i] = Vectorized<T>(_mm512_mask_loadu_epi32( 74 zero_vec, mmask, ptr + i * Vectorized<T>::size())); 75 } 76 } 77 return result; 78 } 79 }; 80 81 template <typename data_t, int dst_n, typename mask_t> 82 struct VecMaskLoad< 83 data_t, 84 dst_n, 85 mask_t, 86 dst_n, 87 typename std::enable_if< 88 std::is_same_v<data_t, BFloat16> || 89 std::is_same_v<data_t, Half>>::type> { 90 static inline VectorizedN<data_t, dst_n> apply( 91 const data_t* ptr, 92 const VecMask<mask_t, dst_n>& vec_mask) { 93 auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); 94 VectorizedN<data_t, dst_n> result; 95 #ifndef _MSC_VER 96 #pragma unroll 97 #endif 98 for (int i = 0; i < dst_n; i++) { 99 auto tmp_mask = VecMask<mask_t, 1>(vec_mask[i]); 100 auto int_mask = tmp_mask.template cast<int, 2>(); 101 auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); 102 auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); 103 auto zero = _mm256_set1_epi16(0); 104 auto temp0 = _mm256_mask_loadu_epi16( 105 zero, mmask0, ptr + (2 * i) * Vectorized<int>::size()); 106 auto temp1 = _mm256_mask_loadu_epi16( 107 zero, mmask1, ptr + (2 * i + 1) * Vectorized<int>::size()); 108 result[i] = Vectorized<data_t>( 109 _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); 110 } 111 return result; 112 } 113 }; 114 115 template <typename data_t, int dst_n, typename mask_t, int mask_n> 116 struct VecMaskLoad< 117 data_t, 118 dst_n, 119 mask_t, 120 mask_n, 121 typename std::enable_if_t< 122 (mask_n == 2 * dst_n && dst_n >= 1) && 123 (std::is_same_v<data_t, BFloat16> || std::is_same_v<data_t, Half>)>> { 124 static inline VectorizedN<data_t, dst_n> apply( 125 const data_t* ptr, 126 const VecMask<mask_t, mask_n>& vec_mask) { 127 auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); 128 VectorizedN<data_t, dst_n> result; 129 VectorizedN<mask_t, 2> tmp_vec; 130 for (int i = 0; i < dst_n; i++) { 131 tmp_vec[0] = vec_mask[2 * i]; 132 tmp_vec[1] = vec_mask[2 * i + 1]; 133 auto int_mask = VecMask<mask_t, 2>(tmp_vec).template cast<int, 2>(); 134 auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); 135 auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); 136 auto zero = _mm256_set1_epi16(0); 137 auto temp0 = _mm256_mask_loadu_epi16( 138 zero, mmask0, ptr + (2 * i) * Vectorized<int>::size()); 139 auto temp1 = _mm256_mask_loadu_epi16( 140 zero, mmask1, ptr + (2 * i + 1) * Vectorized<int>::size()); 141 result[i] = Vectorized<data_t>( 142 _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); 143 } 144 return result; 145 } 146 }; 147 148 template <typename data_t, typename mask_t> 149 struct VecMaskLoad< 150 data_t, 151 1, 152 mask_t, 153 1, 154 typename std::enable_if< 155 std::is_same_v<data_t, int8_t> || 156 std::is_same_v<data_t, uint8_t>>::type> { 157 static inline VectorizedN<data_t, 1> apply( 158 const data_t* ptr, 159 const VecMask<mask_t, 1>& vec_mask) { 160 auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); 161 auto int_mask = vec_mask.template cast<int, 1>()[0]; 162 auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); 163 auto zero = _mm_set1_epi8(0); 164 auto temp = _mm_mask_loadu_epi8(zero, mmask, ptr); 165 return Vectorized<data_t>( 166 _mm512_inserti64x2(_mm512_set1_epi32(0), temp, 0)); 167 } 168 }; 169 170 template <typename data_t, typename mask_t> 171 struct VecMaskLoad< 172 data_t, 173 2, 174 mask_t, 175 1, 176 typename std::enable_if< 177 std::is_same_v<data_t, int64_t> || 178 std::is_same_v<data_t, double>>::type> { 179 static inline VectorizedN<data_t, 2> apply( 180 const data_t* ptr, 181 const VecMask<mask_t, 1>& vec_mask) { 182 auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); 183 at::vec::Vectorized<data_t> zero_vec(0); 184 auto int_mask = vec_mask.template cast<int, 1>()[0]; 185 auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); 186 at::vec::VectorizedN<data_t, 2> result; 187 if constexpr (std::is_same_v<data_t, double>) { 188 result[0] = _mm512_mask_loadu_pd(zero_vec, (__mmask8)mmask, ptr); 189 result[1] = 190 _mm512_mask_loadu_pd(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); 191 } else { 192 result[0] = _mm512_mask_loadu_epi64(zero_vec, (__mmask8)mmask, ptr); 193 result[1] = 194 _mm512_mask_loadu_epi64(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); 195 } 196 return result; 197 } 198 }; 199 200 template <int N> 201 struct VecMaskCast<float, N, int, N> { 202 static inline VecMask<float, N> apply(const VecMask<int, N>& vec_mask) { 203 VectorizedN<float, N> result; 204 #ifndef _MSC_VER 205 #pragma unroll 206 #endif 207 for (int i = 0; i < N; ++i) { 208 result[i] = _mm512_castsi512_ps(vec_mask[i]); 209 } 210 return result; 211 } 212 }; 213 214 template <int N> 215 struct VecMaskCast<int, N, float, N> { 216 static inline VecMask<int, N> apply(const VecMask<float, N>& vec_mask) { 217 VectorizedN<int, N> result; 218 #ifndef _MSC_VER 219 #pragma unroll 220 #endif 221 for (int i = 0; i < N; ++i) { 222 result[i] = _mm512_castps_si512(vec_mask[i]); 223 } 224 return result; 225 } 226 }; 227 228 template <int N> 229 struct VecMaskCast<int64_t, N, double, N> { 230 static inline VecMask<int64_t, N> apply(const VecMask<double, N>& vec_mask) { 231 VectorizedN<int64_t, N> result; 232 #ifndef _MSC_VER 233 #pragma unroll 234 #endif 235 for (int i = 0; i < N; ++i) { 236 result[i] = _mm512_castpd_si512(vec_mask[i]); 237 } 238 return result; 239 } 240 }; 241 242 template <int N> 243 struct VecMaskCast<double, N, int64_t, N> { 244 static inline VecMask<double, N> apply(const VecMask<int64_t, N>& vec_mask) { 245 VectorizedN<double, N> result; 246 #ifndef _MSC_VER 247 #pragma unroll 248 #endif 249 for (int i = 0; i < N; ++i) { 250 result[i] = _mm512_castsi512_pd(vec_mask[i]); 251 } 252 return result; 253 } 254 }; 255 256 template <int dst_n, typename mask_t, int mask_n> 257 struct VecMaskCast< 258 int64_t, 259 dst_n, 260 mask_t, 261 mask_n, 262 typename std::enable_if_t< 263 (dst_n == 2 * mask_n) && 264 (std::is_same_v<mask_t, float> || std::is_same_v<mask_t, int>), 265 void>> { 266 static inline VecMask<int64_t, dst_n> apply( 267 const VecMask<mask_t, mask_n>& vec_mask) { 268 VectorizedN<int64_t, dst_n> result; 269 auto int_mask = vec_mask.template cast<int, mask_n>(); 270 #ifndef _MSC_VER 271 #pragma unroll 272 #endif 273 for (int i = 0; i < mask_n; ++i) { 274 auto int64_vec = 275 convert<int64_t, 2, int, 1>(VectorizedN<int, 1>(int_mask[i])); 276 result[2 * i] = int64_vec[0]; 277 result[2 * i + 1] = int64_vec[1]; 278 } 279 return VecMask<int64_t, dst_n>(result); 280 } 281 }; 282 283 template <typename dst_t, int dst_n, int mask_n> 284 struct VecMaskCast< 285 dst_t, 286 dst_n, 287 int64_t, 288 mask_n, 289 typename std::enable_if_t< 290 (mask_n == 2 * dst_n) && 291 (std::is_same_v<dst_t, float> || std::is_same_v<dst_t, int>), 292 void>> { 293 static inline VecMask<dst_t, dst_n> apply( 294 const VecMask<int64_t, mask_n>& vec_mask) { 295 VectorizedN<int, dst_n> result; 296 VectorizedN<int64_t, 2> int64_vec; 297 for (int i = 0; i < dst_n; ++i) { 298 int64_vec[0] = vec_mask[2 * i]; 299 int64_vec[1] = vec_mask[2 * i + 1]; 300 result[i] = convert<int, 1, int64_t, 2>(int64_vec); 301 } 302 return VecMask<int, dst_n>(result).template cast<dst_t, dst_n>(); 303 } 304 }; 305 306 template <> 307 struct VecMaskCast<double, 2, float, 1> { 308 static inline VecMask<double, 2> apply(const VecMask<float, 1>& vec_mask) { 309 auto int64_mask = VecMaskCast<int64_t, 2, float, 1>::apply(vec_mask); 310 return VecMaskCast<double, 2, int64_t, 2>::apply(int64_mask); 311 } 312 }; 313 314 template <> 315 struct VecMaskCast<float, 1, double, 2> { 316 static inline VecMask<float, 1> apply(const VecMask<double, 2>& vec_mask) { 317 auto int64_mask = VecMaskCast<int64_t, 2, double, 2>::apply(vec_mask); 318 return VecMaskCast<float, 1, int64_t, 2>::apply(int64_mask); 319 } 320 }; 321 322 template <> 323 inline bool VecMask<int, 1>::all_zero() const { 324 __mmask16 mask = _mm512_test_epi32_mask(mask_[0], mask_[0]); 325 return mask == 0; 326 } 327 328 template <> 329 inline bool VecMask<int, 1>::is_masked(int i) const { 330 return _mm512_movepi32_mask(mask_[0]) & (1 << i); 331 } 332 333 template <> 334 inline bool VecMask<int, 1>::all_masked() const { 335 __mmask16 mask = _mm512_movepi32_mask(mask_[0]); 336 return mask == 0xffff; 337 } 338 339 template <int N> 340 struct VecMaskCheck<int64_t, N> { 341 static inline bool all_zero(const VectorizedN<int64_t, N>& vec_mask) { 342 bool all_zero = true; 343 for (int i = 0; i < N; ++i) { 344 all_zero = 345 all_zero && (_mm512_test_epi64_mask(vec_mask[i], vec_mask[i]) == 0); 346 if (!all_zero) { 347 return all_zero; 348 } 349 } 350 return all_zero; 351 } 352 353 static inline bool is_masked(const VectorizedN<int64_t, N>& vec_mask, int i) { 354 for (int j = 0; j < N; ++j) { 355 if (i < (j + 1) * 8) { 356 return _mm512_movepi64_mask(vec_mask[j]) & (1 << (i - j * 8)); 357 } 358 } 359 return false; 360 } 361 362 static inline bool all_masked(const VectorizedN<int64_t, N>& vec_mask) { 363 bool all_masked = true; 364 for (int i = 0; i < N; ++i) { 365 all_masked = all_masked && (_mm512_movepi64_mask(vec_mask[i]) == 0xff); 366 if (!all_masked) { 367 return all_masked; 368 } 369 } 370 return all_masked; 371 } 372 }; 373 374 #define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ 375 T, N, return_type, method, args_def, args) \ 376 template <> \ 377 inline return_type VecMask<T, N>::method args_def const { \ 378 return cast<int, 1>().method args; \ 379 } 380 381 VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_zero, (), ()) 382 VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_zero, (), ()) 383 VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, is_masked, (int i), (i)) 384 VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, is_masked, (int i), (i)) 385 VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_masked, (), ()) 386 VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_masked, (), ()) 387 388 #undef VEC_MASK_DEFINE_METHOD_WITH_CAST_TO_INT 389 390 #endif 391 392 } // namespace CPU_CAPABILITY 393 } // namespace at::vec 394