• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <ATen/cpu/vec/intrinsics.h>
4 #include <ATen/cpu/vec/vec_base.h>
5 #include <ATen/cpu/vec/vec_mask.h>
6 
7 namespace at::vec {
8 inline namespace CPU_CAPABILITY {
9 
10 #if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
11 
12 template <typename T, int dst_n, typename mask_t, int mask_n>
13 struct VecMaskLoad<
14     T,
15     dst_n,
16     mask_t,
17     mask_n,
18     typename std::enable_if_t<
19         (mask_n == dst_n * 2 && dst_n >= 1) &&
20             (std::is_same_v<T, float> || std::is_same_v<T, int32_t>),
21         void>> {
22   static inline VectorizedN<T, dst_n> apply(
23       const T* ptr,
24       const VecMask<mask_t, mask_n>& vec_mask) {
25     at::vec::Vectorized<T> zero_vec(0);
26     auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
27     VectorizedN<mask_t, 2> tmp_vec;
28     VectorizedN<T, dst_n> result;
29     for (int i = 0; i < dst_n; i++) {
30       tmp_vec[0] = vec_mask[2 * i];
31       tmp_vec[1] = vec_mask[2 * i + 1];
32       auto int64_mask = VecMask<mask_t, 2>(tmp_vec).template cast<int64_t, 2>();
33       auto int_mask = int64_mask.template cast<int, 1>()[0];
34       auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ);
35       if constexpr (std::is_same_v<T, float>) {
36         result[i] = Vectorized<T>(_mm512_mask_loadu_ps(
37             zero_vec, mmask, ptr + i * Vectorized<T>::size()));
38       } else {
39         result[i] = Vectorized<T>(_mm512_mask_loadu_epi32(
40             zero_vec, mmask, ptr + i * Vectorized<T>::size()));
41       }
42     }
43     return result;
44   }
45 };
46 
47 template <typename T, int dst_n, typename mask_t>
48 struct VecMaskLoad<
49     T,
50     dst_n,
51     mask_t,
52     dst_n,
53     typename std::enable_if_t<
54         std::is_same_v<T, float> || std::is_same_v<T, int32_t>,
55         void>> {
56   static inline VectorizedN<T, dst_n> apply(
57       const T* ptr,
58       const VecMask<mask_t, dst_n>& vec_mask) {
59     at::vec::Vectorized<T> zero_vec(0);
60     auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
61     VectorizedN<T, dst_n> result;
62 #ifndef _MSC_VER
63 #pragma unroll
64 #endif
65     for (int i = 0; i < dst_n; i++) {
66       auto tmp_mask = VecMask<mask_t, 1>(vec_mask[i]);
67       auto int_mask = tmp_mask.template cast<int, 1>()[0];
68       auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ);
69       if constexpr (std::is_same_v<T, float>) {
70         result[i] = Vectorized<T>(_mm512_mask_loadu_ps(
71             zero_vec, mmask, ptr + i * Vectorized<T>::size()));
72       } else {
73         result[i] = Vectorized<T>(_mm512_mask_loadu_epi32(
74             zero_vec, mmask, ptr + i * Vectorized<T>::size()));
75       }
76     }
77     return result;
78   }
79 };
80 
81 template <typename data_t, int dst_n, typename mask_t>
82 struct VecMaskLoad<
83     data_t,
84     dst_n,
85     mask_t,
86     dst_n,
87     typename std::enable_if<
88         std::is_same_v<data_t, BFloat16> ||
89         std::is_same_v<data_t, Half>>::type> {
90   static inline VectorizedN<data_t, dst_n> apply(
91       const data_t* ptr,
92       const VecMask<mask_t, dst_n>& vec_mask) {
93     auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
94     VectorizedN<data_t, dst_n> result;
95 #ifndef _MSC_VER
96 #pragma unroll
97 #endif
98     for (int i = 0; i < dst_n; i++) {
99       auto tmp_mask = VecMask<mask_t, 1>(vec_mask[i]);
100       auto int_mask = tmp_mask.template cast<int, 2>();
101       auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ);
102       auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ);
103       auto zero = _mm256_set1_epi16(0);
104       auto temp0 = _mm256_mask_loadu_epi16(
105           zero, mmask0, ptr + (2 * i) * Vectorized<int>::size());
106       auto temp1 = _mm256_mask_loadu_epi16(
107           zero, mmask1, ptr + (2 * i + 1) * Vectorized<int>::size());
108       result[i] = Vectorized<data_t>(
109           _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1));
110     }
111     return result;
112   }
113 };
114 
115 template <typename data_t, int dst_n, typename mask_t, int mask_n>
116 struct VecMaskLoad<
117     data_t,
118     dst_n,
119     mask_t,
120     mask_n,
121     typename std::enable_if_t<
122         (mask_n == 2 * dst_n && dst_n >= 1) &&
123         (std::is_same_v<data_t, BFloat16> || std::is_same_v<data_t, Half>)>> {
124   static inline VectorizedN<data_t, dst_n> apply(
125       const data_t* ptr,
126       const VecMask<mask_t, mask_n>& vec_mask) {
127     auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
128     VectorizedN<data_t, dst_n> result;
129     VectorizedN<mask_t, 2> tmp_vec;
130     for (int i = 0; i < dst_n; i++) {
131       tmp_vec[0] = vec_mask[2 * i];
132       tmp_vec[1] = vec_mask[2 * i + 1];
133       auto int_mask = VecMask<mask_t, 2>(tmp_vec).template cast<int, 2>();
134       auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ);
135       auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ);
136       auto zero = _mm256_set1_epi16(0);
137       auto temp0 = _mm256_mask_loadu_epi16(
138           zero, mmask0, ptr + (2 * i) * Vectorized<int>::size());
139       auto temp1 = _mm256_mask_loadu_epi16(
140           zero, mmask1, ptr + (2 * i + 1) * Vectorized<int>::size());
141       result[i] = Vectorized<data_t>(
142           _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1));
143     }
144     return result;
145   }
146 };
147 
148 template <typename data_t, typename mask_t>
149 struct VecMaskLoad<
150     data_t,
151     1,
152     mask_t,
153     1,
154     typename std::enable_if<
155         std::is_same_v<data_t, int8_t> ||
156         std::is_same_v<data_t, uint8_t>>::type> {
157   static inline VectorizedN<data_t, 1> apply(
158       const data_t* ptr,
159       const VecMask<mask_t, 1>& vec_mask) {
160     auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
161     auto int_mask = vec_mask.template cast<int, 1>()[0];
162     auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ);
163     auto zero = _mm_set1_epi8(0);
164     auto temp = _mm_mask_loadu_epi8(zero, mmask, ptr);
165     return Vectorized<data_t>(
166         _mm512_inserti64x2(_mm512_set1_epi32(0), temp, 0));
167   }
168 };
169 
170 template <typename data_t, typename mask_t>
171 struct VecMaskLoad<
172     data_t,
173     2,
174     mask_t,
175     1,
176     typename std::enable_if<
177         std::is_same_v<data_t, int64_t> ||
178         std::is_same_v<data_t, double>>::type> {
179   static inline VectorizedN<data_t, 2> apply(
180       const data_t* ptr,
181       const VecMask<mask_t, 1>& vec_mask) {
182     auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
183     at::vec::Vectorized<data_t> zero_vec(0);
184     auto int_mask = vec_mask.template cast<int, 1>()[0];
185     auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ);
186     at::vec::VectorizedN<data_t, 2> result;
187     if constexpr (std::is_same_v<data_t, double>) {
188       result[0] = _mm512_mask_loadu_pd(zero_vec, (__mmask8)mmask, ptr);
189       result[1] =
190           _mm512_mask_loadu_pd(zero_vec, (__mmask8)(mmask >> 8), ptr + 8);
191     } else {
192       result[0] = _mm512_mask_loadu_epi64(zero_vec, (__mmask8)mmask, ptr);
193       result[1] =
194           _mm512_mask_loadu_epi64(zero_vec, (__mmask8)(mmask >> 8), ptr + 8);
195     }
196     return result;
197   }
198 };
199 
200 template <int N>
201 struct VecMaskCast<float, N, int, N> {
202   static inline VecMask<float, N> apply(const VecMask<int, N>& vec_mask) {
203     VectorizedN<float, N> result;
204 #ifndef _MSC_VER
205 #pragma unroll
206 #endif
207     for (int i = 0; i < N; ++i) {
208       result[i] = _mm512_castsi512_ps(vec_mask[i]);
209     }
210     return result;
211   }
212 };
213 
214 template <int N>
215 struct VecMaskCast<int, N, float, N> {
216   static inline VecMask<int, N> apply(const VecMask<float, N>& vec_mask) {
217     VectorizedN<int, N> result;
218 #ifndef _MSC_VER
219 #pragma unroll
220 #endif
221     for (int i = 0; i < N; ++i) {
222       result[i] = _mm512_castps_si512(vec_mask[i]);
223     }
224     return result;
225   }
226 };
227 
228 template <int N>
229 struct VecMaskCast<int64_t, N, double, N> {
230   static inline VecMask<int64_t, N> apply(const VecMask<double, N>& vec_mask) {
231     VectorizedN<int64_t, N> result;
232 #ifndef _MSC_VER
233 #pragma unroll
234 #endif
235     for (int i = 0; i < N; ++i) {
236       result[i] = _mm512_castpd_si512(vec_mask[i]);
237     }
238     return result;
239   }
240 };
241 
242 template <int N>
243 struct VecMaskCast<double, N, int64_t, N> {
244   static inline VecMask<double, N> apply(const VecMask<int64_t, N>& vec_mask) {
245     VectorizedN<double, N> result;
246 #ifndef _MSC_VER
247 #pragma unroll
248 #endif
249     for (int i = 0; i < N; ++i) {
250       result[i] = _mm512_castsi512_pd(vec_mask[i]);
251     }
252     return result;
253   }
254 };
255 
256 template <int dst_n, typename mask_t, int mask_n>
257 struct VecMaskCast<
258     int64_t,
259     dst_n,
260     mask_t,
261     mask_n,
262     typename std::enable_if_t<
263         (dst_n == 2 * mask_n) &&
264             (std::is_same_v<mask_t, float> || std::is_same_v<mask_t, int>),
265         void>> {
266   static inline VecMask<int64_t, dst_n> apply(
267       const VecMask<mask_t, mask_n>& vec_mask) {
268     VectorizedN<int64_t, dst_n> result;
269     auto int_mask = vec_mask.template cast<int, mask_n>();
270 #ifndef _MSC_VER
271 #pragma unroll
272 #endif
273     for (int i = 0; i < mask_n; ++i) {
274       auto int64_vec =
275           convert<int64_t, 2, int, 1>(VectorizedN<int, 1>(int_mask[i]));
276       result[2 * i] = int64_vec[0];
277       result[2 * i + 1] = int64_vec[1];
278     }
279     return VecMask<int64_t, dst_n>(result);
280   }
281 };
282 
283 template <typename dst_t, int dst_n, int mask_n>
284 struct VecMaskCast<
285     dst_t,
286     dst_n,
287     int64_t,
288     mask_n,
289     typename std::enable_if_t<
290         (mask_n == 2 * dst_n) &&
291             (std::is_same_v<dst_t, float> || std::is_same_v<dst_t, int>),
292         void>> {
293   static inline VecMask<dst_t, dst_n> apply(
294       const VecMask<int64_t, mask_n>& vec_mask) {
295     VectorizedN<int, dst_n> result;
296     VectorizedN<int64_t, 2> int64_vec;
297     for (int i = 0; i < dst_n; ++i) {
298       int64_vec[0] = vec_mask[2 * i];
299       int64_vec[1] = vec_mask[2 * i + 1];
300       result[i] = convert<int, 1, int64_t, 2>(int64_vec);
301     }
302     return VecMask<int, dst_n>(result).template cast<dst_t, dst_n>();
303   }
304 };
305 
306 template <>
307 struct VecMaskCast<double, 2, float, 1> {
308   static inline VecMask<double, 2> apply(const VecMask<float, 1>& vec_mask) {
309     auto int64_mask = VecMaskCast<int64_t, 2, float, 1>::apply(vec_mask);
310     return VecMaskCast<double, 2, int64_t, 2>::apply(int64_mask);
311   }
312 };
313 
314 template <>
315 struct VecMaskCast<float, 1, double, 2> {
316   static inline VecMask<float, 1> apply(const VecMask<double, 2>& vec_mask) {
317     auto int64_mask = VecMaskCast<int64_t, 2, double, 2>::apply(vec_mask);
318     return VecMaskCast<float, 1, int64_t, 2>::apply(int64_mask);
319   }
320 };
321 
322 template <>
323 inline bool VecMask<int, 1>::all_zero() const {
324   __mmask16 mask = _mm512_test_epi32_mask(mask_[0], mask_[0]);
325   return mask == 0;
326 }
327 
328 template <>
329 inline bool VecMask<int, 1>::is_masked(int i) const {
330   return _mm512_movepi32_mask(mask_[0]) & (1 << i);
331 }
332 
333 template <>
334 inline bool VecMask<int, 1>::all_masked() const {
335   __mmask16 mask = _mm512_movepi32_mask(mask_[0]);
336   return mask == 0xffff;
337 }
338 
339 template <int N>
340 struct VecMaskCheck<int64_t, N> {
341   static inline bool all_zero(const VectorizedN<int64_t, N>& vec_mask) {
342     bool all_zero = true;
343     for (int i = 0; i < N; ++i) {
344       all_zero =
345           all_zero && (_mm512_test_epi64_mask(vec_mask[i], vec_mask[i]) == 0);
346       if (!all_zero) {
347         return all_zero;
348       }
349     }
350     return all_zero;
351   }
352 
353   static inline bool is_masked(const VectorizedN<int64_t, N>& vec_mask, int i) {
354     for (int j = 0; j < N; ++j) {
355       if (i < (j + 1) * 8) {
356         return _mm512_movepi64_mask(vec_mask[j]) & (1 << (i - j * 8));
357       }
358     }
359     return false;
360   }
361 
362   static inline bool all_masked(const VectorizedN<int64_t, N>& vec_mask) {
363     bool all_masked = true;
364     for (int i = 0; i < N; ++i) {
365       all_masked = all_masked && (_mm512_movepi64_mask(vec_mask[i]) == 0xff);
366       if (!all_masked) {
367         return all_masked;
368       }
369     }
370     return all_masked;
371   }
372 };
373 
374 #define VEC_MASK_METHOD_WITH_CAST_TO_INT(                   \
375     T, N, return_type, method, args_def, args)              \
376   template <>                                               \
377   inline return_type VecMask<T, N>::method args_def const { \
378     return cast<int, 1>().method args;                      \
379   }
380 
381 VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_zero, (), ())
382 VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_zero, (), ())
383 VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, is_masked, (int i), (i))
384 VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, is_masked, (int i), (i))
385 VEC_MASK_METHOD_WITH_CAST_TO_INT(float, 1, bool, all_masked, (), ())
386 VEC_MASK_METHOD_WITH_CAST_TO_INT(int64_t, 2, bool, all_masked, (), ())
387 
388 #undef VEC_MASK_DEFINE_METHOD_WITH_CAST_TO_INT
389 
390 #endif
391 
392 } // namespace CPU_CAPABILITY
393 } // namespace at::vec
394