1 #pragma once
2
3 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
4 // See Note [Do not compile initializers with AVX]
5
6 #include <ATen/cpu/vec/intrinsics.h>
7 #include <ATen/cpu/vec/vec_base.h>
8 #include <c10/util/irange.h>
9
10 #if defined(CPU_CAPABILITY_AVX512)
11 #define SLEEF_STATIC_LIBS
12 #include <sleef.h>
13 #endif
14
15 namespace at {
16 namespace vec {
17 // See Note [CPU_CAPABILITY namespace]
18 inline namespace CPU_CAPABILITY {
19
20 #if defined(CPU_CAPABILITY_AVX512)
21
22 #ifndef SLEEF_CONST
23 #if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER)
24 #define SLEEF_CONST const
25 #else
26 #define SLEEF_CONST
27 #endif
28 #define SLEEF_CONST_OLD SLEEF_CONST
29 #else
30 #define SLEEF_CONST_OLD
31 #endif
32
33 // bfloat16 conversion
34 static inline void cvtbf16_fp32(const __m256i& a, __m512& o) {
35 o = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
36 }
37
38 static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
39 __m256i lo = _mm512_extracti32x8_epi32(a, 0);
40 __m256i hi = _mm512_extracti32x8_epi32(a, 1);
41 cvtbf16_fp32(lo, o1);
42 cvtbf16_fp32(hi, o2);
43 }
44
45 static inline __m256i cvtfp32_bf16(const __m512& src) {
46 __m512i value = _mm512_castps_si512(src);
47 __m512i nan = _mm512_set1_epi32(0xffff);
48 auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
49 __m512i ones = _mm512_set1_epi32(0x1);
50 __m512i vec_bias = _mm512_set1_epi32(0x7fff);
51 // uint32_t lsb = (input >> 16) & 1;
52 auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
53 // uint32_t rounding_bias = 0x7fff + lsb;
54 t_value = _mm512_add_epi32(t_value, vec_bias);
55 // input += rounding_bias;
56 t_value = _mm512_add_epi32(t_value, value);
57 // input = input >> 16;
58 t_value = _mm512_srli_epi32(t_value, 16);
59 // Check NaN before converting back to bf16
60 t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
61 return _mm512_cvtusepi32_epi16(t_value);
62 }
63
64 static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) {
65 __m512i lo = _mm512_castps_si512(a);
66 __m512i hi = _mm512_castps_si512(b);
67 __m512i nan = _mm512_set1_epi32(0xffff);
68 auto mask_lo = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
69 auto mask_hi = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q);
70 __m512i ones = _mm512_set1_epi32(0x1);
71 __m512i vec_bias = _mm512_set1_epi32(0x7fff);
72 // uint32_t lsb = (input >> 16) & 1;
73 auto t_lo = _mm512_and_si512(_mm512_srli_epi32(lo, 16), ones);
74 auto t_hi = _mm512_and_si512(_mm512_srli_epi32(hi, 16), ones);
75 // uint32_t rounding_bias = 0x7fff + lsb;
76 t_lo = _mm512_add_epi32(t_lo, vec_bias);
77 t_hi = _mm512_add_epi32(t_hi, vec_bias);
78 // input += rounding_bias;
79 t_lo = _mm512_add_epi32(t_lo, lo);
80 t_hi = _mm512_add_epi32(t_hi, hi);
81 // input = input >> 16;
82 t_lo = _mm512_srli_epi32(t_lo, 16);
83 t_hi = _mm512_srli_epi32(t_hi, 16);
84 // Check NaN before converting back to bf16
85 t_lo = _mm512_mask_blend_epi32(mask_lo, nan, t_lo);
86 t_hi = _mm512_mask_blend_epi32(mask_hi, nan, t_hi);
87
88 t_lo = _mm512_packus_epi32(t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-3] t_lo[0-3]
89 __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
90 return _mm512_permutexvar_epi64(idx, t_lo);
91 }
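// Illustrative scalar sketch of the round-to-nearest-even conversion that the
// two routines above vectorize (for clarity only; this helper is hypothetical
// and not part of the header):
//
//   uint16_t fp32_to_bf16_scalar(float f) {
//     uint32_t bits;
//     std::memcpy(&bits, &f, sizeof(bits));
//     if (f != f) return 0xffff;                 // NaN -> canonical all-ones payload
//     uint32_t lsb = (bits >> 16) & 1;           // parity of the kept mantissa bit
//     bits += 0x7fff + lsb;                      // round to nearest, ties to even
//     return static_cast<uint16_t>(bits >> 16);  // keep the high 16 bits
//   }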
92
93 static inline __m512i merge_compare_result(const __m512& a, const __m512& b) {
94 __m512i lo = _mm512_castps_si512(a);
95 __m512i hi = _mm512_castps_si512(b);
96 lo = _mm512_srli_epi32(lo, 16);
97 hi = _mm512_srli_epi32(hi, 16);
98 auto out = _mm512_packus_epi32(lo, hi);
99 __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
100 return _mm512_permutexvar_epi64(idx, out);
101 }
102
103 // float16 conversion
104 static inline void cvtfp16_fp32(const __m256i& a, __m512& o) {
105 o = _mm512_cvtph_ps(a);
106 }
107
108 static inline void cvtfp16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
109 __m256i lo = _mm512_extracti32x8_epi32(a, 0);
110 __m256i hi = _mm512_extracti32x8_epi32(a, 1);
111 cvtfp16_fp32(lo, o1);
112 cvtfp16_fp32(hi, o2);
113 }
114
115 static inline __m256i cvtfp32_fp16(const __m512& src) {
116 return _mm512_cvtps_ph(
117 src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
118 }
119
120 static inline __m512i cvtfp32_fp16(const __m512& a, const __m512& b) {
121 __m256i lo = _mm512_cvtps_ph(
122 a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
123 __m256i hi = _mm512_cvtps_ph(
124 b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
125 __m512 t_lo = _mm512_castsi512_ps(_mm512_castsi256_si512(lo));
126 __m256 t_hi = _mm256_castsi256_ps(hi);
127 return _mm512_castps_si512(_mm512_insertf32x8(t_lo, t_hi, 1));
128 }
129
130 // dtype conversion between float16/bfloat16 and float32
131 template <typename T, typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
132 inline void cvt_to_fp32(const __m256i& a, __m512& o);
133 template <> inline void cvt_to_fp32<BFloat16>(const __m256i& a, __m512& o) {
134 cvtbf16_fp32(a, o);
135 }
136 template <> inline void cvt_to_fp32<Half>(const __m256i& a, __m512& o) {
137 cvtfp16_fp32(a, o);
138 }
139
140 template <typename T, typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
141 inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2);
142 template <> inline void cvt_to_fp32<BFloat16>(const __m512i& a, __m512& o1, __m512& o2) {
143 cvtbf16_fp32(a, o1, o2);
144 }
145 template <> inline void cvt_to_fp32<Half>(const __m512i& a, __m512& o1, __m512& o2) {
146 cvtfp16_fp32(a, o1, o2);
147 }
148
149 template <typename T, bool is_compare_op = false,
150 typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
151 inline __m512i cvt_from_fp32(const __m512& a, const __m512& b);
152 template <> inline __m512i cvt_from_fp32<BFloat16, false>(const __m512& a, const __m512& b) {
153 return cvtfp32_bf16(a, b);
154 }
155 template <> inline __m512i cvt_from_fp32<BFloat16, true>(const __m512& a, const __m512& b) {
156 return merge_compare_result(a, b);
157 }
158 template <> inline __m512i cvt_from_fp32<Half, false>(const __m512& a, const __m512& b) {
159 return cvtfp32_fp16(a, b);
160 }
161 template <> inline __m512i cvt_from_fp32<Half, true>(const __m512& a, const __m512& b) {
162 return cvtfp32_fp16(a, b);
163 }
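// Typical use of the dispatch helpers above (an illustrative sketch only):
//
//   __m512 lo, hi;
//   cvt_to_fp32<BFloat16>(values, lo, hi);          // widen 32 bf16 lanes to 2 x 16 fp32
//   // ... do the math in fp32 on lo / hi ...
//   __m512i out = cvt_from_fp32<BFloat16>(lo, hi);  // narrow back with rounding
//
// For comparison results, cvt_from_fp32<T, /*is_compare_op=*/true> is used
// instead: it truncates rather than rounds, so all-ones lane masks survive the
// narrowing step intact.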
164
165 template <typename T>
166 class Vectorized16 {
167 static_assert(
168 is_reduced_floating_point_v<T>,
169 "Support only float16 and bfloat16.");
170 private:
171 __m512i values;
172 public:
173 using value_type = uint16_t;
174 using size_type = int;
175 static constexpr size_type size() {
176 return 32;
177 }
178 Vectorized16() {}
179 Vectorized16(__m512i v) : values(v) {}
180 Vectorized16(T val) {
181 value_type uw = val.x;
182 values = _mm512_set1_epi16(uw);
183 }
184 Vectorized16(T val1, T val2, T val3, T val4,
185 T val5, T val6, T val7, T val8,
186 T val9, T val10, T val11, T val12,
187 T val13, T val14, T val15, T val16,
188 T val17, T val18, T val19, T val20,
189 T val21, T val22, T val23, T val24,
190 T val25, T val26, T val27, T val28,
191 T val29, T val30, T val31, T val32) {
192 values = _mm512_set_epi16(
193 val32.x, val31.x, val30.x, val29.x, val28.x, val27.x, val26.x, val25.x,
194 val24.x, val23.x, val22.x, val21.x, val20.x, val19.x, val18.x, val17.x,
195 val16.x, val15.x, val14.x, val13.x, val12.x, val11.x, val10.x, val9.x,
196 val8.x, val7.x, val6.x, val5.x, val4.x, val3.x, val2.x, val1.x);
197 }
198 operator __m512i() const {
199 return values;
200 }
201 T& operator[](int idx) = delete;
202 const T& operator[](int idx) const = delete;
203 int zero_mask() const {
204 // returns an integer mask where each zero element maps to a 1 bit and each non-zero element maps to a 0 bit
205 return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0));
206 }
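// Illustrative example: for a vector whose first two lanes hold +0, zero_mask()
// returns a value with bits 0 and 1 set (one bit per 16-bit lane). Note the
// comparison is bit-wise, so a negative zero (bit pattern 0x8000) does not count as zero here.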
207 static Vectorized<T> loadu(const void* ptr, int16_t count = size()) {
208 if (count == size())
209 return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
210
211 __mmask32 mask = (1ULL << count) - 1;
212 return _mm512_maskz_loadu_epi16(mask, ptr);
213 }
214 void store(void* ptr, int count = size()) const {
215 if (count == size()) {
216 _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
217 } else if (count > 0) {
218 __mmask32 mask = (1ULL << count) - 1;
219 _mm512_mask_storeu_epi16(ptr, mask, values);
220 }
221 }
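// Illustrative tail handling with the count parameter (a sketch; `buf` is a
// hypothetical buffer of at least 40 elements):
//
//   auto v = Vectorized<BFloat16>::loadu(buf + 32, 8);  // masked load of an 8-element tail
//   v.store(buf + 32, 8);                               // masked store of the same 8 lanes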
222 template <int64_t mask>
223 static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
224 __at_align__ int16_t tmp_values[size()];
225 a.store(tmp_values);
226 if (mask & 0x01)
227 tmp_values[0] = b.values[31];
228 if (mask & 0x02)
229 tmp_values[1] = b.values[30];
230 if (mask & 0x04)
231 tmp_values[2] = b.values[29];
232 if (mask & 0x08)
233 tmp_values[3] = b.values[28];
234 if (mask & 0x10)
235 tmp_values[4] = b.values[27];
236 if (mask & 0x20)
237 tmp_values[5] = b.values[26];
238 if (mask & 0x40)
239 tmp_values[6] = b.values[25];
240 if (mask & 0x80)
241 tmp_values[7] = b.values[24];
242 if (mask & 0x100)
243 tmp_values[8] = b.values[23];
244 if (mask & 0x200)
245 tmp_values[9] = b.values[22];
246 if (mask & 0x400)
247 tmp_values[10] = b.values[21];
248 if (mask & 0x800)
249 tmp_values[11] = b.values[20];
250 if (mask & 0x1000)
251 tmp_values[12] = b.values[19];
252 if (mask & 0x2000)
253 tmp_values[13] = b.values[18];
254 if (mask & 0x4000)
255 tmp_values[14] = b.values[17];
256 if (mask & 0x8000)
257 tmp_values[15] = b.values[16];
258 if (mask & 0x10000)
259 tmp_values[16] = b.values[15];
260 if (mask & 0x20000)
261 tmp_values[17] = b.values[14];
262 if (mask & 0x40000)
263 tmp_values[18] = b.values[13];
264 if (mask & 0x80000)
265 tmp_values[19] = b.values[12];
266 if (mask & 0x100000)
267 tmp_values[20] = b.values[11];
268 if (mask & 0x200000)
269 tmp_values[21] = b.values[10];
270 if (mask & 0x400000)
271 tmp_values[22] = b.values[9];
272 if (mask & 0x800000)
273 tmp_values[23] = b.values[8];
274 if (mask & 0x1000000)
275 tmp_values[24] = b.values[7];
276 if (mask & 0x2000000)
277 tmp_values[25] = b.values[6];
278 if (mask & 0x4000000)
279 tmp_values[26] = b.values[5];
280 if (mask & 0x8000000)
281 tmp_values[27] = b.values[4];
282 if (mask & 0x10000000)
283 tmp_values[28] = b.values[3];
284 if (mask & 0x20000000)
285 tmp_values[29] = b.values[2];
286 if (mask & 0x40000000)
287 tmp_values[30] = b.values[1];
288 if (mask & 0x80000000)
289 tmp_values[31] = b.values[0];
290 return loadu(tmp_values);
291 }
292 static Vectorized<T> blendv(const Vectorized<T>& a,
293 const Vectorized<T>& b, const Vectorized<T>& mask) {
294 auto all_ones = _mm512_set1_epi16(0xFFFF);
295 auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ);
296 return _mm512_mask_blend_epi16(mask_, a.values, b.values);
297 }
298 template<typename step_t>
299 static Vectorized<T> arange(T base = 0.f, step_t step = static_cast<step_t>(1)) {
300 return Vectorized<T>(
301 base, base + step, base + 2 * step, base + 3 * step,
302 base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
303 base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
304 base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
305 base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
306 base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
307 base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
308 base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step);
309 }
310 static Vectorized<T> set(const Vectorized<T>& a,
311 const Vectorized<T>& b, int64_t count = size()) {
312 switch (count) {
313 case 0:
314 return a;
315 case 1:
316 return blend<1>(a, b);
317 case 2:
318 return blend<3>(a, b);
319 case 3:
320 return blend<7>(a, b);
321 case 4:
322 return blend<15>(a, b);
323 case 5:
324 return blend<31>(a, b);
325 case 6:
326 return blend<63>(a, b);
327 case 7:
328 return blend<127>(a, b);
329 case 8:
330 return blend<255>(a, b);
331 case 9:
332 return blend<511>(a, b);
333 case 10:
334 return blend<1023>(a, b);
335 case 11:
336 return blend<2047>(a, b);
337 case 12:
338 return blend<4095>(a, b);
339 case 13:
340 return blend<8191>(a, b);
341 case 14:
342 return blend<16383>(a, b);
343 case 15:
344 return blend<32767>(a, b);
345 case 16:
346 return blend<65535>(a, b);
347 case 17:
348 return blend<131071>(a, b);
349 case 18:
350 return blend<262143>(a, b);
351 case 19:
352 return blend<524287>(a, b);
353 case 20:
354 return blend<1048575>(a, b);
355 case 21:
356 return blend<2097151>(a, b);
357 case 22:
358 return blend<4194303>(a, b);
359 case 23:
360 return blend<8388607>(a, b);
361 case 24:
362 return blend<16777215>(a, b);
363 case 25:
364 return blend<33554431>(a, b);
365 case 26:
366 return blend<67108863>(a, b);
367 case 27:
368 return blend<134217727>(a, b);
369 case 28:
370 return blend<268435455>(a, b);
371 case 29:
372 return blend<536870911>(a, b);
373 case 30:
374 return blend<1073741823>(a, b);
375 case 31:
376 return blend<2147483647>(a, b);
377 }
378 return b;
379 }
380 #pragma clang diagnostic push
381 #pragma clang diagnostic ignored "-Wignored-qualifiers"
382
383 Vectorized<T> map(SLEEF_CONST __m512 (*SLEEF_CONST_OLD vop)(__m512)) const {
384 __m512 lo, hi;
385 cvt_to_fp32<T>(values, lo, hi);
386 const auto o1 = vop(lo);
387 const auto o2 = vop(hi);
388 return cvt_from_fp32<T>(o1, o2);
389 }
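// map() widens to fp32, applies a single Sleef vector routine to each half,
// and narrows back. Sketch of the pattern used by the transcendental wrappers
// below (my_op is hypothetical):
//
//   Vectorized<T> my_op() const { return map(Sleef_sinf16_u10); }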
390 Vectorized<T> isnan() const {
391 __m512 lo, hi;
392 cvt_to_fp32<T>(values, lo, hi);
393 __mmask16 lo_mask, hi_mask;
394 __m512 zero = _mm512_set1_ps(0.0);
395 __m512i zeroi = _mm512_castps_si512(zero);
396 lo_mask = _mm512_cmp_ps_mask(lo, zero, _CMP_UNORD_Q);
397 lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zeroi, lo_mask, 0xFFFF'FFFF));
398 hi_mask = _mm512_cmp_ps_mask(hi, zero, _CMP_UNORD_Q);
399 hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zeroi, hi_mask, 0xFFFF'FFFF));
400 return merge_compare_result(lo, hi);
401 }
402 #pragma clang diagnostic pop
403 Vectorized<T> abs() const {
404 return _mm512_andnot_si512(_mm512_set1_epi16(0x8000), values);
405 }
406 Vectorized<T> angle() const {
407 __m512 lo, hi;
408 cvt_to_fp32<T>(values, lo, hi);
409 auto angle_lambda = [](__m512 values) {
410 const auto zero_vec = _mm512_set1_ps(0.f);
411 const auto nan_vec = _mm512_set1_ps(NAN);
412 const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
413 const auto non_nan_mask_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec),
414 not_nan_mask, 0xFFFFFFFF);
415 const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(non_nan_mask_vec),
416 zero_vec, _CMP_EQ_OQ);
417 const auto pi = _mm512_set1_ps(c10::pi<float>);
418
419 const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
420 auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
421 angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
422 return angle;
423 };
424 auto o1 = angle_lambda(lo);
425 auto o2 = angle_lambda(hi);
426 return cvt_from_fp32<T>(o1, o2);
427 }
428 Vectorized<T> real() const {
429 return *this;
430 }
431 Vectorized<T> imag() const {
432 return _mm512_set1_epi16(0);
433 }
434 Vectorized<T> conj() const {
435 return *this;
436 }
437 Vectorized<T> acos() const {
438 return map(Sleef_acosf16_u10);
439 }
440 Vectorized<T> acosh() const {
441 return map(Sleef_acoshf16_u10);
442 }
443 Vectorized<T> asin() const {
444 return map(Sleef_asinf16_u10);
445 }
446 Vectorized<T> atan() const {
447 return map(Sleef_atanf16_u10);
448 }
449 Vectorized<T> atanh() const {
450 return map(Sleef_atanhf16_u10);
451 }
452 Vectorized<T> atan2(const Vectorized<T> &b) const {
453 __m512 lo, hi;
454 __m512 b1, b2;
455 cvt_to_fp32<T>(values, lo, hi);
456 cvt_to_fp32<T>(b.values, b1, b2);
457 auto o1 = Sleef_atan2f16_u10(lo, b1);
458 auto o2 = Sleef_atan2f16_u10(hi, b2);
459 return cvt_from_fp32<T>(o1, o2);
460 }
461 Vectorized<T> copysign(const Vectorized<T> &sign) const {
462 // copy sign bit (0x8000) from sign and remaining bits from values
463 __m512i mask_value = _mm512_set1_epi32(~0x80008000);
464 __m512i mask_signbit = _mm512_set1_epi32(0x80008000);
465 return Vectorized<T>(
466 _mm512_or_si512(
467 _mm512_and_si512(values, mask_value),
468 _mm512_and_si512(sign, mask_signbit)));
469 }
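// Scalar equivalent of the bit trick above (illustrative only): keep every bit
// of *this except the sign, and take the sign bit from `sign`:
//
//   uint16_t copysign_bits(uint16_t x, uint16_t s) {
//     return static_cast<uint16_t>((x & 0x7fff) | (s & 0x8000));
//   }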
470 Vectorized<T> erf() const {
471 return map(Sleef_erff16_u10);
472 }
473 Vectorized<T> erfc() const {
474 return map(Sleef_erfcf16_u15);
475 }
476 Vectorized<T> erfinv() const {
477 __m512 lo, hi;
478 cvt_to_fp32<T>(values, lo, hi);
479 __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
480 _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
481 _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
482 for (int64_t i = 0; i < size() / 2; i++) {
483 tmp1[i] = calc_erfinv(tmp1[i]);
484 tmp2[i] = calc_erfinv(tmp2[i]);
485 }
486 auto o1 = _mm512_loadu_ps(tmp1);
487 auto o2 = _mm512_loadu_ps(tmp2);
488 return cvt_from_fp32<T>(o1, o2);
489 }
490 Vectorized<T> exp() const {
491 return map(Sleef_expf16_u10);
492 }
493 Vectorized<T> exp2() const {
494 return map(Sleef_exp2f16_u10);
495 }
496 Vectorized<T> expm1() const {
497 return map(Sleef_expm1f16_u10);
498 }
499 Vectorized<T> exp_u20() const {
500 return exp();
501 }
502 Vectorized<T> fmod(const Vectorized<T> & q) const {
503 __m512 x_lo, x_hi;
504 cvt_to_fp32<T>(values, x_lo, x_hi);
505 __m512 q_lo, q_hi;
506 cvt_to_fp32<T>(q.values, q_lo, q_hi);
507 auto o1 = Sleef_fmodf16(x_lo, q_lo);
508 auto o2 = Sleef_fmodf16(x_hi, q_hi);
509 return cvt_from_fp32<T>(o1, o2);
510 }
511 Vectorized<T> hypot(const Vectorized<T> &b) const {
512 __m512 lo, hi;
513 __m512 b1, b2;
514 cvt_to_fp32<T>(values, lo, hi);
515 cvt_to_fp32<T>(b.values, b1, b2);
516 auto o1 = Sleef_hypotf16_u05(lo, b1);
517 auto o2 = Sleef_hypotf16_u05(hi, b2);
518 return cvt_from_fp32<T>(o1, o2);
519 }
520 Vectorized<T> i0() const {
521 __m512 lo, hi;
522 cvt_to_fp32<T>(values, lo, hi);
523 __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
524 _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
525 _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
526 for (int64_t i = 0; i < size() / 2; i++) {
527 tmp1[i] = calc_i0(tmp1[i]);
528 tmp2[i] = calc_i0(tmp2[i]);
529 }
530 auto o1 = _mm512_loadu_ps(tmp1);
531 auto o2 = _mm512_loadu_ps(tmp2);
532 return cvt_from_fp32<T>(o1, o2);
533 }
534 Vectorized<T> i0e() const {
535 __m512 lo, hi;
536 cvt_to_fp32<T>(values, lo, hi);
537 constexpr auto sz = size();
538 __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
539 _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
540 _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
541
542 for (auto i = decltype(sz){0}; i < sz / 2; i++) {
543 tmp1[i] = calc_i0e(tmp1[i]);
544 tmp2[i] = calc_i0e(tmp2[i]);
545 }
546 const auto o1 = _mm512_loadu_ps(tmp1);
547 const auto o2 = _mm512_loadu_ps(tmp2);
548 return cvt_from_fp32<T>(o1, o2);
549 }
550 Vectorized<T> digamma() const {
551 __m512 lo, hi;
552 cvt_to_fp32<T>(values, lo, hi);
553 constexpr auto sz = size();
554 __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
555 _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
556 _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
557
558 for (auto i = decltype(sz){0}; i < sz / 2; i++) {
559 tmp1[i] = calc_digamma(tmp1[i]);
560 tmp2[i] = calc_digamma(tmp2[i]);
561 }
562 const auto o1 = _mm512_loadu_ps(tmp1);
563 const auto o2 = _mm512_loadu_ps(tmp2);
564 return cvt_from_fp32<T>(o1, o2);
565 }
566 Vectorized<T> igamma(const Vectorized<T> &x) const {
567 __m512 lo, hi;
568 __m512 xlo, xhi;
569 cvt_to_fp32<T>(values, lo, hi);
570 cvt_to_fp32<T>(x.values, xlo, xhi);
571 __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
572 _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
573 _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
574 __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
575 _mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
576 _mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
577 for (int64_t i = 0; i < size() / 2; ++i) {
578 tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
579 tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
580 }
581 auto o1 = _mm512_loadu_ps(tmp1);
582 auto o2 = _mm512_loadu_ps(tmp2);
583 return cvt_from_fp32<T>(o1, o2);
584 }
585
586 Vectorized<T> igammac(const Vectorized<T> &x) const {
587 __m512 lo, hi;
588 __m512 xlo, xhi;
589 cvt_to_fp32<T>(values, lo, hi);
590 cvt_to_fp32<T>(x.values, xlo, xhi);
591 __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
592 _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
593 _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
594 __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
595 _mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
596 _mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
597 for (int64_t i = 0; i < size() / 2; ++i) {
598 tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]);
599 tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]);
600 }
601 auto o1 = _mm512_loadu_ps(tmp1);
602 auto o2 = _mm512_loadu_ps(tmp2);
603 return cvt_from_fp32<T>(o1, o2);
604 }
605 Vectorized<T> log() const {
606 return map(Sleef_logf16_u10);
607 }
608 Vectorized<T> log2() const {
609 return map(Sleef_log2f16_u10);
610 }
611 Vectorized<T> log10() const {
612 return map(Sleef_log10f16_u10);
613 }
614 Vectorized<T> log1p() const {
615 return map(Sleef_log1pf16_u10);
616 }
617 Vectorized<T> sin() const {
618 return map(Sleef_sinf16_u10);
619 }
620 Vectorized<T> sinh() const {
621 return map(Sleef_sinhf16_u10);
622 }
623 Vectorized<T> cos() const {
624 return map(Sleef_cosf16_u10);
625 }
626 Vectorized<T> cosh() const {
627 return map(Sleef_coshf16_u10);
628 }
629 Vectorized<T> ceil() const {
630 __m512 lo, hi;
631 cvt_to_fp32<T>(values, lo, hi);
632 auto o1 = _mm512_ceil_ps(lo);
633 auto o2 = _mm512_ceil_ps(hi);
634 return cvt_from_fp32<T>(o1, o2);
635 }
636 Vectorized<T> floor() const {
637 __m512 lo, hi;
638 cvt_to_fp32<T>(values, lo, hi);
639 auto o1 = _mm512_floor_ps(lo);
640 auto o2 = _mm512_floor_ps(hi);
641 return cvt_from_fp32<T>(o1, o2);
642 }
643 Vectorized<T> neg() const {
644 return _mm512_xor_si512(values, _mm512_set1_epi16(0x8000));
645 }
646 Vectorized<T> round() const {
647 __m512 lo, hi;
648 cvt_to_fp32<T>(values, lo, hi);
649 auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
650 auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
651 return cvt_from_fp32<T>(o1, o2);
652 }
653 Vectorized<T> tan() const {
654 return map(Sleef_tanf16_u10);
655 }
656 Vectorized<T> tanh() const {
657 return map(Sleef_tanhf16_u10);
658 }
659 Vectorized<T> trunc() const {
660 __m512 lo, hi;
661 cvt_to_fp32<T>(values, lo, hi);
662 auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
663 auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
664 return cvt_from_fp32<T>(o1, o2);
665 }
666 Vectorized<T> lgamma() const {
667 return map(Sleef_lgammaf16_u10);
668 }
669 Vectorized<T> sqrt() const {
670 __m512 lo, hi;
671 cvt_to_fp32<T>(values, lo, hi);
672 auto o1 = _mm512_sqrt_ps(lo);
673 auto o2 = _mm512_sqrt_ps(hi);
674 return cvt_from_fp32<T>(o1, o2);
675 }
676 Vectorized<T> reciprocal() const {
677 __m512 lo, hi;
678 cvt_to_fp32<T>(values, lo, hi);
679 auto ones = _mm512_set1_ps(1);
680 auto o1 = _mm512_div_ps(ones, lo);
681 auto o2 = _mm512_div_ps(ones, hi);
682 return cvt_from_fp32<T>(o1, o2);
683 }
684 Vectorized<T> rsqrt() const {
685 __m512 lo, hi;
686 cvt_to_fp32<T>(values, lo, hi);
687 auto ones = _mm512_set1_ps(1);
688 auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo));
689 auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi));
690 return cvt_from_fp32<T>(o1, o2);
691 }
692 Vectorized<T> pow(const Vectorized<T> &b) const {
693 __m512 lo, hi;
694 __m512 b1, b2;
695 cvt_to_fp32<T>(values, lo, hi);
696 cvt_to_fp32<T>(b.values, b1, b2);
697 auto o1 = Sleef_powf16_u10(lo, b1);
698 auto o2 = Sleef_powf16_u10(hi, b2);
699 return cvt_from_fp32<T>(o1, o2);
700 }
701 private:
702 template<typename Op>
703 Vectorized<T> inline binary_compare(const Vectorized<T>& b, Op op) const {
704 __m512 a_lo, a_hi;
705 __m512 b_lo, b_hi;
706 cvt_to_fp32<T>(values, a_lo, a_hi);
707 cvt_to_fp32<T>(b.values, b_lo, b_hi);
708 auto o1 = op(a_lo, b_lo);
709 auto o2 = op(a_hi, b_hi);
710 return cvt_from_fp32<T, /*is_compare_op*/true>(o1, o2);
711 }
712
713 public:
714 Vectorized<T> inline operator>(const Vectorized<T>& other) const {
715 return binary_compare(other, [](__m512 x, __m512 y) {
716 auto zero_vec = _mm512_set1_epi32(0);
717 auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ);
718 return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
719 });
720 }
721 Vectorized<T> inline operator<(const Vectorized<T>& other) const {
722 return binary_compare(other, [](__m512 x, __m512 y) {
723 auto zero_vec = _mm512_set1_epi32(0);
724 auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ);
725 return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
726 });
727 }
728 Vectorized<T> inline operator>=(const Vectorized<T>& other) const {
729 return binary_compare(other, [](__m512 x, __m512 y) {
730 auto zero_vec = _mm512_set1_epi32(0);
731 auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
732 return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
733 });
734 }
735 Vectorized<T> inline operator<=(const Vectorized<T>& other) const {
736 return binary_compare(other, [](__m512 x, __m512 y) {
737 auto zero_vec = _mm512_set1_epi32(0);
738 auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ);
739 return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
740 });
741 }
742 Vectorized<T> inline operator==(const Vectorized<T>& other) const {
743 return binary_compare(other, [](__m512 x, __m512 y) {
744 auto zero_vec = _mm512_set1_epi32(0);
745 auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ);
746 return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
747 });
748 }
749 Vectorized<T> inline operator!=(const Vectorized<T>& other) const {
750 return binary_compare(other, [](__m512 x, __m512 y) {
751 auto zero_vec = _mm512_set1_epi32(0);
752 auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ);
753 return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
754 });
755 }
756 };
757
758 template<typename T, typename Op>
759 static inline Vectorized<T> binary_op_as_fp32(const Vectorized<T>& a, const Vectorized<T>& b, Op op) {
760 __m512 a_lo, a_hi;
761 __m512 b_lo, b_hi;
762 cvt_to_fp32<T>(__m512i(a), a_lo, a_hi);
763 cvt_to_fp32<T>(__m512i(b), b_lo, b_hi);
764 auto o1 = op(a_lo, b_lo);
765 auto o2 = op(a_hi, b_hi);
766 return cvt_from_fp32<T>(o1, o2);
767 }
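// All of the arithmetic operators below are built on this helper; a sketch of
// the pattern (illustrative):
//
//   // c = a + b, computed in fp32 and rounded back to the reduced type
//   auto c = binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) {
//     return _mm512_add_ps(x, y);
//   });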
768
769 template <>
770 class Vectorized<BFloat16>: public Vectorized16<BFloat16> {
771 public:
772 using Vectorized16::Vectorized16;
773
774 Vectorized<BFloat16> frac() const;
775
776 Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
777 Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
778 Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
779 Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
780 Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
781 Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
782 };
783
784 Vectorized<BFloat16> inline operator+(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
785 return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_add_ps(x, y); });
786 }
787 Vectorized<BFloat16> inline operator-(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
788 return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_sub_ps(x, y); });
789 }
790 Vectorized<BFloat16> inline operator*(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
791 return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_mul_ps(x, y); });
792 }
793 Vectorized<BFloat16> inline operator/(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
794 return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_div_ps(x, y); });
795 }
796 Vectorized<BFloat16> inline operator&(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
797 return _mm512_and_si512(a, b);
798 }
799 Vectorized<BFloat16> inline operator|(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
800 return _mm512_or_si512(a, b);
801 }
802 Vectorized<BFloat16> inline operator^(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
803 return _mm512_xor_si512(a, b);
804 }
805
806 inline Vectorized<BFloat16> Vectorized<BFloat16>::eq(const Vectorized<BFloat16>& other) const {
807 return (*this == other) & Vectorized<BFloat16>(1.0f);
808 }
809
810 inline Vectorized<BFloat16> Vectorized<BFloat16>::ne(const Vectorized<BFloat16>& other) const {
811 return (*this != other) & Vectorized<BFloat16>(1.0f);
812 }
813
814 inline Vectorized<BFloat16> Vectorized<BFloat16>::gt(const Vectorized<BFloat16>& other) const {
815 return (*this > other) & Vectorized<BFloat16>(1.0f);
816 }
817
818 inline Vectorized<BFloat16> Vectorized<BFloat16>::ge(const Vectorized<BFloat16>& other) const {
819 return (*this >= other) & Vectorized<BFloat16>(1.0f);
820 }
821
822 inline Vectorized<BFloat16> Vectorized<BFloat16>::lt(const Vectorized<BFloat16>& other) const {
823 return (*this < other) & Vectorized<BFloat16>(1.0f);
824 }
825
826 inline Vectorized<BFloat16> Vectorized<BFloat16>::le(const Vectorized<BFloat16>& other) const {
827 return (*this <= other) & Vectorized<BFloat16>(1.0f);
828 }
829
830 // frac. Implement this here so we can use subtraction
831 inline Vectorized<BFloat16> Vectorized<BFloat16>::frac() const {
832 return *this - this->trunc();
833 }
834
835 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
836 // either input is a NaN.
837 template <>
838 Vectorized<BFloat16> inline maximum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
839 __m512 a_lo, a_hi;
840 __m512 b_lo, b_hi;
841 cvtbf16_fp32(__m512i(a), a_lo, a_hi);
842 cvtbf16_fp32(__m512i(b), b_lo, b_hi);
843 auto max_lo = _mm512_max_ps(a_lo, b_lo);
844 auto max_hi = _mm512_max_ps(a_hi, b_hi);
845 auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
846 auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
847 auto nan_lo = _mm512_castsi512_ps(_mm512_maskz_set1_epi32(nan_lo_mask, 0xFFFFFFFF));
848 auto nan_hi = _mm512_castsi512_ps(_mm512_maskz_set1_epi32(nan_hi_mask, 0xFFFFFFFF));
849 // Exploit the fact that all-ones is a NaN.
850 auto o1 = _mm512_or_ps(max_lo, nan_lo);
851 auto o2 = _mm512_or_ps(max_hi, nan_hi);
852 return cvtfp32_bf16(o1, o2);
853 }
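// Scalar sketch of the NaN propagation used above and in minimum() below
// (illustrative only): max/min can silently drop a NaN operand, so the
// unordered-compare mask is OR-ed back in as an all-ones payload, which is
// itself a NaN (scalar_max_with_nan is a hypothetical helper):
//
//   float scalar_max_with_nan(float a, float b) {
//     float m = std::max(a, b);              // may ignore a NaN operand
//     if (a != a || b != b) {
//       uint32_t q = 0xffffffff;             // all-ones bit pattern is a NaN
//       std::memcpy(&m, &q, sizeof(m));
//     }
//     return m;
//   }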
854
855 // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
856 // either input is a NaN.
857 template <>
858 Vectorized<BFloat16> inline minimum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
859 __m512 a_lo, a_hi;
860 __m512 b_lo, b_hi;
861 __m512i zero_vec = _mm512_set1_epi32(0);
862 cvtbf16_fp32(__m512i(a), a_lo, a_hi);
863 cvtbf16_fp32(__m512i(b), b_lo, b_hi);
864 auto min_lo = _mm512_min_ps(a_lo, b_lo);
865 auto min_hi = _mm512_min_ps(a_hi, b_hi);
866 auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
867 auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
868 auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
869 0xFFFFFFFF));
870 auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
871 0xFFFFFFFF));
872 // Exploit the fact that all-ones is a NaN.
873 auto o1 = _mm512_or_ps(min_lo, nan_lo);
874 auto o2 = _mm512_or_ps(min_hi, nan_hi);
875 return cvtfp32_bf16(o1, o2);
876 }
877
878 template <>
879 Vectorized<BFloat16> inline clamp(const Vectorized<BFloat16>& a,
880 const Vectorized<BFloat16>& min, const Vectorized<BFloat16>& max) {
881 __m512 a_lo, a_hi;
882 __m512 min_lo, min_hi;
883 __m512 max_lo, max_hi;
884 cvtbf16_fp32(__m512i(a), a_lo, a_hi);
885 cvtbf16_fp32(__m512i(min), min_lo, min_hi);
886 cvtbf16_fp32(__m512i(max), max_lo, max_hi);
887 auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo));
888 auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi));
889 return cvtfp32_bf16(o1, o2);
890 }
891
892 template <>
893 Vectorized<BFloat16> inline clamp_max(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& max) {
894 __m512 a_lo, a_hi;
895 __m512 max_lo, max_hi;
896 cvtbf16_fp32(__m512i(a), a_lo, a_hi);
897 cvtbf16_fp32(__m512i(max), max_lo, max_hi);
898 auto o1 = _mm512_min_ps(max_lo, a_lo);
899 auto o2 = _mm512_min_ps(max_hi, a_hi);
900 return cvtfp32_bf16(o1, o2);
901 }
902
903 template <>
904 Vectorized<BFloat16> inline clamp_min(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& min) {
905 __m512 a_lo, a_hi;
906 __m512 min_lo, min_hi;
907 cvtbf16_fp32(__m512i(a), a_lo, a_hi);
908 cvtbf16_fp32(__m512i(min), min_lo, min_hi);
909 auto o1 = _mm512_max_ps(min_lo, a_lo);
910 auto o2 = _mm512_max_ps(min_hi, a_hi);
911 return cvtfp32_bf16(o1, o2);
912 }
913
914 template <>
915 inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) {
916 int64_t i;
917 #ifndef __msvc_cl__
918 #pragma unroll
919 #endif
920 for (i = 0; i <= (n - Vectorized<BFloat16>::size()); i += Vectorized<BFloat16>::size()) {
921 auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i)));
922 _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc);
923 }
924 #ifndef __msvc_cl__
925 #pragma unroll
926 #endif
927 for (; i < n; i++) {
928 dst[i] = src[i];
929 }
930 }
931
932 template <>
933 inline void convert(const float* src, BFloat16* dst, int64_t n) {
934 int64_t i;
935 for (i = 0; i + Vectorized<BFloat16>::size() <= n; i += Vectorized<BFloat16>::size()) {
936 __m512 a = _mm512_loadu_ps(&src[i]);
937 __m512 b = _mm512_loadu_ps(&src[i + 16]);
938
939 __m512i bf = cvtfp32_bf16(a, b);
940 _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
941 }
942 for (; i < n; i++) {
943 dst[i] = c10::convert<BFloat16>(src[i]);
944 }
945 }
946
947 template <>
948 inline void convert(const double* src, BFloat16* dst, int64_t n) {
949 auto load_float = [](const double *src) -> __m512 {
950 // Load one float vector from an array of doubles
951 __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src));
952 __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8));
953 return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
954 };
955
956 int64_t i;
957 for (i = 0; i + Vectorized<BFloat16>::size() <= n; i += Vectorized<BFloat16>::size()) {
958 __m512 a = load_float(&src[i]);
959 __m512 b = load_float(&src[i + 16]);
960
961 __m512i bf = cvtfp32_bf16(a, b);
962 _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
963 }
964 for (; i < n; i++) {
965 dst[i] = c10::convert<BFloat16>(src[i]);
966 }
967 }
968
969 template <>
970 Vectorized<BFloat16> inline fmadd(const Vectorized<BFloat16>& a,
971 const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& c) {
972 __m512 a_lo, a_hi;
973 __m512 b_lo, b_hi;
974 __m512 c_lo, c_hi;
975 cvtbf16_fp32(__m512i(a), a_lo, a_hi);
976 cvtbf16_fp32(__m512i(b), b_lo, b_hi);
977 cvtbf16_fp32(__m512i(c), c_lo, c_hi);
978 auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo);
979 auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi);
980 return cvtfp32_bf16(o1, o2);
981 }
982
983 static inline void _transpose_mxn_half_16_16(__m256i t[], __m512i u[]) {
984 __m512i r[8];
985 // a0a1 a2a3 a4a5 a6a7 a8a9 a10a11 a12a13 a14a15 e0e1 e2e3 e4e5 e6e7 e8e9 e10e11 e12e13 e14e15
986 // b0-b15 f0-f15
987 // c0-c15 g0-g15
988 // d0-d15 h0-h15
989 // i0-i15 m0-m15
990 // j0-j15 n0-n15
991 // k0-k15 o0-o15
992 // l0-l15 p0-p15
993 #ifndef __msvc_cl__
994 #pragma unroll(4)
995 #endif
996 for (int i = 0; i < 4; i++) {
997 r[i] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i]), t[i + 4], 0x01);
998 r[i + 4] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i + 8]), t[i + 12], 0x01);
999 }
1000
1001 // u0: a0a1 b0b1 a2a3 b2b3 a8a9 b8b9 a10a11 b10b11 e0e1 f0f1 e2e3 f2f3 e8e9 f8f9 e10e11 f10f11
1002 // u1: a4a5 b4b5 a6a7 b6b7 a12a13 b12b13 a14a15 b14b15 e4e5 f4f5 e6e7 f6f7 e12e13 f12f13 e14e15 f14f15
1003 // u2: c0c1 d0d1 c2c3 d2d3 c8c9 d8d9 c10c11 d10d11 g0g1 h0h1 g2g3 h2h3 g8g9 h8h9 g10g11 h10h11
1004 // u3: c4c5 d4d5 c6c7 d6d7 c12c13 d12d13 c14c15 d14d15 g4g5 h4h5 g6g7 h6h7 g12g13 h12h13 g14g15 h14h15
1005 // i j m n
1006 // k l o p
1007 #ifndef __msvc_cl__
1008 #pragma unroll(4)
1009 #endif
1010 for (int i = 0; i < 8; i += 2) {
1011 u[i] = _mm512_unpacklo_epi32(r[i], r[i + 1]);
1012 u[i + 1] = _mm512_unpackhi_epi32(r[i], r[i + 1]);
1013 }
1014
1015 // r0: a0a1 b0b1 c0c1 d0d1 a8a9 b8b9 c8c9 d8d9 e0e1 f0f1 g0g1 h0h1 e8e9 f8f9 g8g9 h8h9
1016 // r1: a2a3 b2b3 c2c3 d2d3 a10a11 b10b11 c10c11 d10d11 e2e3 f2f3 g2g3 h2h3 e10e11 f10f11 g10g11 h10h11
1017 // r2: a4a5 b4b5 c4c5 d4d5 a12a13 b12b13 c12c13 d12d13
1018 // r3: a6a7 b6b7 c6c7 d6d7 a14a15 b14b15 c14c15 d14d15
1019 // r4: i j k l m n o p
1020 r[0] = _mm512_unpacklo_epi64(u[0], u[2]);
1021 r[1] = _mm512_unpackhi_epi64(u[0], u[2]);
1022 r[2] = _mm512_unpacklo_epi64(u[1], u[3]);
1023 r[3] = _mm512_unpackhi_epi64(u[1], u[3]);
1024 r[4] = _mm512_unpacklo_epi64(u[4], u[6]);
1025 r[5] = _mm512_unpackhi_epi64(u[4], u[6]);
1026 r[6] = _mm512_unpacklo_epi64(u[5], u[7]);
1027 r[7] = _mm512_unpackhi_epi64(u[5], u[7]);
1028
1029 __m512i const1 = _mm512_set_epi32(
1030 0x00370035,
1031 0x00330031,
1032 0x00270025,
1033 0x00230021,
1034 0x00170015,
1035 0x00130011,
1036 0x00070005,
1037 0x00030001,
1038 0x00360034,
1039 0x00320030,
1040 0x00260024,
1041 0x00220020,
1042 0x00160014,
1043 0x00120010,
1044 0x00060004,
1045 0x00020000);
1046 __m512i const2 = _mm512_set_epi32(
1047 0x003f003d,
1048 0x003b0039,
1049 0x002f002d,
1050 0x002b0029,
1051 0x001f001d,
1052 0x001b0019,
1053 0x000f000d,
1054 0x000b0009,
1055 0x003e003c,
1056 0x003a0038,
1057 0x002e002c,
1058 0x002a0028,
1059 0x001e001c,
1060 0x001a0018,
1061 0x000e000c,
1062 0x000a0008);
1063 // merge values from two regs
1064 // 0-- 1--
1065 // 8-- 9--
1066 // 2-- 3--
1067 // 10-- 11--
1068 // 4-- 5--
1069 // 12-- 13--
1070 // 6-- 7--
1071 // 14-- 15--
1072 #ifndef __msvc_cl__
1073 #pragma unroll(4)
1074 #endif
1075 for (int i = 0; i < 4; i++) {
1076 u[i] = _mm512_permutex2var_epi16(r[i], const1, r[i + 4]);
1077 u[i + 4] = _mm512_permutex2var_epi16(r[i], const2, r[i + 4]);
1078 }
1079 }
1080
1081 // TODO(Leslie): Add the AVX2 Version of transpose_mxn for BFloat16 and Float16
1082 // Code referred to FBGEMM:
1083 // https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607
1084 template<>
1085 inline void transpose_mxn<BFloat16, 16, 16>(
1086 const BFloat16* src,
1087 int64_t ld_src,
1088 BFloat16* dst,
1089 int64_t ld_dst) {
1090 __m256i t[16];
1091 // load from src to registers
1092 // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15
1093 // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15
1094 // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15
1095 // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15
1096 // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15
1097 // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15
1098 // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15
1099 // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15
1100 // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15
1101 // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15
1102 // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15
1103 // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15
1104 // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15
1105 // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15
1106 // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15
1107 // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15
1108 #ifndef __msvc_cl__
1109 #pragma unroll(16)
1110 #endif
1111 for (int i = 0; i < 16; i++) {
1112 t[i] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i * ld_src));
1113 }
1114
1115 __m512i u[8];
1116 _transpose_mxn_half_16_16(t, u);
1117
1118 #ifndef __msvc_cl__
1119 #pragma unroll(8)
1120 #endif
1121 for (int i = 0; i < 8; i++) {
1122 _mm256_storeu_si256(
1123 reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst),
1124 _mm512_extracti32x8_epi32(u[i], 0x0));
1125 _mm256_storeu_si256(
1126 reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst),
1127 _mm512_extracti32x8_epi32(u[i], 0x01));
1128 }
1129 }
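// Illustrative use of the fixed 16x16 kernel above (a sketch; src_mat/dst_mat
// are hypothetical row-major 16x16 buffers):
//
//   transpose_mxn<BFloat16, 16, 16>(src_mat, /*ld_src=*/16, dst_mat, /*ld_dst=*/16);
//   // afterwards dst_mat[j * 16 + i] == src_mat[i * 16 + j]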
1130
1131 // Code referred to FBGEMM:
1132 // https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607
1133 template<>
1134 inline void transpose_mxn<Half, 16, 16>(
1135 const Half* src,
1136 int64_t ld_src,
1137 Half* dst,
1138 int64_t ld_dst) {
1139 __m256i t[16];
1140 // load from src to registers
1141 // Same matrix indices as above transpose_mxn<BFloat16, 16, 16>
1142 #ifndef __msvc_cl__
1143 #pragma unroll(16)
1144 #endif
1145 for (int i = 0; i < 16; i++) {
1146 t[i] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i * ld_src));
1147 }
1148
1149 __m512i u[8];
1150 _transpose_mxn_half_16_16(t, u);
1151
1152 #ifndef __msvc_cl__
1153 #pragma unroll(8)
1154 #endif
1155 for (int i = 0; i < 8; i++) {
1156 _mm256_storeu_si256(
1157 reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst),
1158 _mm512_extracti32x8_epi32(u[i], 0x0));
1159 _mm256_storeu_si256(
1160 reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst),
1161 _mm512_extracti32x8_epi32(u[i], 0x01));
1162 }
1163 }
1164
1165 static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) {
1166 // t[0]: 0 32 1 33 2 34 3 35 8 40 9 41 10 42 11 43 16 ... 59
1167 // t[1]: 4 36 5 37 6 38 7 39 12 44 13 45 14 46 15 47 20 ... 63
1168 // t[2]: 64 96 65 97 66 98 67 99 72 104 73 105 74 106 75 ... 123
1169 // t[3]: 68 100 69 101 70 102 71 103 76 108 77 109 78 110 79 111 84 ... 127
1170 // t[4]: 128 160 129 161 130 162 131 163 136 168 137 169 138 170 139 171 144 ... 187
1171 // t[5]: 132 164 133 165 134 166 135 167 140 172 141 173 142 174 143 175 148 ... 191
1172 // t[6]: 192 224 193 225 194 226 195 227 200 232 201 233 202 234 203 235 208 ... 251
1173 // t[7]: 196 228 197 229 198 230 199 231 204 236 205 237 206 238 207 239 212 ... 255
1174 // t[8]: 256 288 257 289 258 290 259 291 264 296 265 297 266 298 267 299 272 ... 315
1175 // t[9]: 260 292 261 293 262 294 263 295 268 300 269 301 270 302 271 303 276 ... 319
1176 // t[10]: 320 352 321 353 322 354 323 355 328 360 329 361 330 362 331 363 336 ... 379
1177 // t[11]: 324 356 325 357 326 358 327 359 332 364 333 365 334 366 335 367 340 ... 383
1178 // t[12]: 384 416 385 417 386 418 387 419 392 424 393 425 394 426 395 427 400 ... 443
1179 // t[13]: 388 420 389 421 390 422 391 423 396 428 397 429 398 430 399 431 404 ... 447
1180 // t[14]: 448 480 449 481 450 482 451 483 456 488 457 489 458 490 459 491 464 ... 507
1181 // t[15]: 452 484 453 485 454 486 455 487 460 492 461 493 462 494 463 495 468 ... 511
1182 // t[16]: 512 544 513 545 514 546 515 547 520 552 521 553 522 554 523 555 528 ... 571
1183 // ...
1184 // t[31]: 964 996 965 997 966 998 967 999 972 1004 973 1005 974 1006 975 1007 980 ... 1023
1185 #ifndef __msvc_cl__
1186 #pragma unroll(16)
1187 #endif
1188 for (int i = 0; i < 16; ++i) {
1189 d[i * 2] = _mm512_unpacklo_epi16(r[i * 2], r[i * 2 + 1]);
1190 d[i * 2 + 1] = _mm512_unpackhi_epi16(r[i * 2], r[i * 2 + 1]);
1191 }
1192
1193 // t[0]: 0 32 64 96 1 33 65 97 8 40 72 104 9 41 73 105 16 ... 121
1194 // t[1]: 2 34 66 98 3 35 67 99 10 42 74 106 11 43 75 107 18 ... 123
1195 // t[2]: 4 36 68 100 5 37 69 101 12 44 76 108 13 45 77 109 20 ... 125
1196 // t[3]: 6 38 70 102 7 39 71 103 14 46 78 110 15 47 79 111 22 ... 127
1197 // t[4]: 128 160 192 224 129 161 193 225 136 168 200 232 137 169 201 233 144 ... 249
1198 // t[5]: 130 162 194 226 131 163 195 227 138 170 202 234 139 171 203 235 146 ... 251
1199 // t[6]: 132 164 196 228 133 165 197 229 140 172 204 236 141 173 205 237 148 ... 253
1200 // t[7]: 134 166 198 230 135 167 199 231 142 174 206 238 143 175 207 239 150 ... 255
1201 // t[8]: 256 288 320 352 257 289 321 353 264 296 328 360 265 297 329 361 272 ... 377
1202 // t[9]: 258 290 322 354 259 291 323 355 266 298 330 362 267 299 331 363 274 ... 379
1203 // t[10]: 260 292 324 356 261 293 325 357 268 300 332 364 269 301 333 365 276 ... 381
1204 // t[11]: 262 294 326 358 263 295 327 359 270 302 334 366 271 303 335 367 278 ... 383
1205 // t[12]: 384 416 448 480 385 417 449 481 392 424 456 488 393 425 457 489 400 ... 505
1206 // t[13]: 386 418 450 482 387 419 451 483 394 426 458 490 395 427 459 491 402 ... 507
1207 // t[14]: 388 420 452 484 389 421 453 485 396 428 460 492 397 429 461 493 404 ... 509
1208 // t[15]: 390 422 454 486 391 423 455 487 398 430 462 494 399 431 463 495 406 ... 511
1209 // t[16]: 512 544 576 608 513 545 577 609 520 552 584 616 521 553 585 617 528 ... 633
1210 // ...
1211 // t[31]: 902 934 966 998 903 935 967 999 910 942 974 1006 911 943 975 1007 918 ... 1023
1212 #ifndef __msvc_cl__
1213 #pragma unroll(8)
1214 #endif
1215 for (int i = 0; i < 8; ++i) {
1216 r[i * 4] = _mm512_unpacklo_epi32(d[i * 4], d[i * 4 + 2]);
1217 r[i * 4 + 1] = _mm512_unpackhi_epi32(d[i * 4], d[i * 4 + 2]);
1218 r[i * 4 + 2] = _mm512_unpacklo_epi32(d[i * 4 + 1], d[i * 4 + 3]);
1219 r[i * 4 + 3] = _mm512_unpackhi_epi32(d[i * 4 + 1], d[i * 4 + 3]);
1220 }
1221
1222 // t[0]: 0 32 64 96 128 160 192 224 8 40 72 104 136 168 200 232 16 ... 248
1223 // t[1]: 1 33 65 97 129 161 193 225 9 41 73 105 137 169 201 233 17 ... 249
1224 // t[2]: 2 34 66 98 130 162 194 226 10 42 74 106 138 170 202 234 18 ... 250
1225 // t[3]: 3 35 67 99 131 163 195 227 11 43 75 107 139 171 203 235 19 ... 251
1226 // t[4]: 4 36 68 100 132 164 196 228 12 44 76 108 140 172 204 236 20 ... 252
1227 // t[5]: 5 37 69 101 133 165 197 229 13 45 77 109 141 173 205 237 21 ... 253
1228 // t[6]: 6 38 70 102 134 166 198 230 14 46 78 110 142 174 206 238 22 ... 254
1229 // t[7]: 7 39 71 103 135 167 199 231 15 47 79 111 143 175 207 239 23 ... 255
1230 // t[8]: 256 288 320 352 384 416 448 480 264 296 328 360 392 424 456 488 272 ... 504
1231 // t[9]: 257 289 321 353 385 417 449 481 265 297 329 361 393 425 457 489 273 ... 505
1232 // t[10]: 258 290 322 354 386 418 450 482 266 298 330 362 394 426 458 490 274 ... 506
1233 // t[11]: 259 291 323 355 387 419 451 483 267 299 331 363 395 427 459 491 275 ... 507
1234 // t[12]: 260 292 324 356 388 420 452 484 268 300 332 364 396 428 460 492 276 ... 508
1235 // t[13]: 261 293 325 357 389 421 453 485 269 301 333 365 397 429 461 493 277 ... 509
1236 // t[14]: 262 294 326 358 390 422 454 486 270 302 334 366 398 430 462 494 278 ... 510
1237 // t[15]: 263 295 327 359 391 423 455 487 271 303 335 367 399 431 463 495 279 ... 511
1238 // t[16]: 512 544 576 608 640 672 704 736 520 552 584 616 648 680 712 744 528 ... 760
1239 // ...
1240 // t[31]: 775 807 839 871 903 935 967 999 783 815 847 879 911 943 975 1007 791 ... 1023
1241 #ifndef __msvc_cl__
1242 #pragma unroll(4)
1243 #endif
1244 for (int i = 0; i < 4; ++i) {
1245 d[i * 8] = _mm512_unpacklo_epi64(r[i * 8], r[i * 8 + 4]);
1246 d[i * 8 + 1] = _mm512_unpackhi_epi64(r[i * 8], r[i * 8 + 4]);
1247 d[i * 8 + 2] = _mm512_unpacklo_epi64(r[i * 8 + 1], r[i * 8 + 5]);
1248 d[i * 8 + 3] = _mm512_unpackhi_epi64(r[i * 8 + 1], r[i * 8 + 5]);
1249 d[i * 8 + 4] = _mm512_unpacklo_epi64(r[i * 8 + 2], r[i * 8 + 6]);
1250 d[i * 8 + 5] = _mm512_unpackhi_epi64(r[i * 8 + 2], r[i * 8 + 6]);
1251 d[i * 8 + 6] = _mm512_unpacklo_epi64(r[i * 8 + 3], r[i * 8 + 7]);
1252 d[i * 8 + 7] = _mm512_unpackhi_epi64(r[i * 8 + 3], r[i * 8 + 7]);
1253 }
1254
1255 // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 16 ... 496
1256 // t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 17 ... 497
1257 // t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482 18 ... 498
1258 // t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483 19 ... 499
1259 // t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484 20 ... 500
1260 // t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485 21 ... 501
1261 // t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486 22 ... 502
1262 // t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 487 23 ... 503
1263 // t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 456 488 24 ... 504
1264 // t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489 25 ... 505
1265 // t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 490 26 ... 506
1266 // t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 395 427 459 491 27 ... 507
1267 // t[12]: 12 44 76 108 140 172 204 236 268 300 332 364 396 428 460 492 28 ... 508
1268 // t[13]: 13 45 77 109 141 173 205 237 269 301 333 365 397 429 461 493 29 ... 509
1269 // t[14]: 14 46 78 110 142 174 206 238 270 302 334 366 398 430 462 494 30 ... 510
1270 // t[15]: 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495 31 ... 511
1271 // t[16]: 512 544 576 608 640 672 704 736 768 800 832 864 896 928 960 992 528 ... 1008
1272 // ...
1273 // t[31]: 527 559 591 623 655 687 719 751 783 815 847 879 911 943 975 1007 543 ... 1023
1274 __m512i const1 = _mm512_set_epi64(
1275 0x000000000000000d,
1276 0x000000000000000c,
1277 0x0000000000000005,
1278 0x0000000000000004,
1279 0x0000000000000009,
1280 0x0000000000000008,
1281 0x0000000000000001,
1282 0x0000000000000000);
1283 __m512i const2 = _mm512_set_epi64(
1284 0x000000000000000f,
1285 0x000000000000000e,
1286 0x0000000000000007,
1287 0x0000000000000006,
1288 0x000000000000000b,
1289 0x000000000000000a,
1290 0x0000000000000003,
1291 0x0000000000000002);
1292 #ifndef __msvc_cl__
1293 #pragma unroll(8)
1294 #endif
1295 for (int i = 0; i < 8; ++i) {
1296 r[i] = _mm512_permutex2var_epi64(d[i], /*idx*/const1, d[i + 8]);
1297 r[i + 8] = _mm512_permutex2var_epi64(d[i], /*idx*/const2, d[i + 8]);
1298 r[i + 16] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/const1, d[i + 24]);
1299 r[i + 24] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/const2, d[i + 24]);
1300 }
1301
1302 // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 512 544 ... 992
1303 // t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 513 545 ... 993
1304 // t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482 514 546 ... 994
1305 // t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483 515 547 ... 995
1306 // t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484 516 548 ... 996
1307 // t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485 517 549 ... 997
1308 // t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486 518 550 ... 998
1309 // t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 487 519 551 ... 999
1310 // t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 456 488 520 552 ... 1000
1311 // t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489 521 553 ... 1001
1312 // t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 490 522 554 ... 1002
1313 // t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 395 427 459 491 523 555 ... 1003
1314 // t[12]: 12 44 76 108 140 172 204 236 268 300 332 364 396 428 460 492 524 556 ... 1004
1315 // t[13]: 13 45 77 109 141 173 205 237 269 301 333 365 397 429 461 493 525 557 ... 1005
1316 // t[14]: 14 46 78 110 142 174 206 238 270 302 334 366 398 430 462 494 526 558 ... 1006
1317 // t[15]: 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495 527 559 ... 1007
1318 // t[16]: 16 48 80 112 144 176 208 240 272 304 336 368 400 432 464 496 528 560 ... 1008
1319 // ...
1320 // t[31]: 31 63 95 127 159 191 223 255 287 319 351 383 415 447 479 511 543 575 ... 1023
1321 __m512i const3 = _mm512_set_epi64(
1322 0x000000000000000b,
1323 0x000000000000000a,
1324 0x0000000000000009,
1325 0x0000000000000008,
1326 0x0000000000000003,
1327 0x0000000000000002,
1328 0x0000000000000001,
1329 0x0000000000000000);
1330 __m512i const4 = _mm512_set_epi64(
1331 0x000000000000000f,
1332 0x000000000000000e,
1333 0x000000000000000d,
1334 0x000000000000000c,
1335 0x0000000000000007,
1336 0x0000000000000006,
1337 0x0000000000000005,
1338 0x0000000000000004);
1339 #ifndef __msvc_cl__
1340 #pragma unroll(16)
1341 #endif
1342 for (int i = 0; i < 16; ++i) {
1343 d[i] = _mm512_permutex2var_epi64(r[i], /*idx*/const3, r[i + 16]);
1344 d[i + 16] = _mm512_permutex2var_epi64(r[i], /*idx*/const4, r[i + 16]);
1345 }
1346 }
1347
1348 // Code referred to FBGEMM:
1349 // https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#LL19C6-L19C6
1350 template<>
1351 inline void transpose_mxn<BFloat16>(const BFloat16* src, int64_t ld_src, BFloat16* dst, int64_t ld_dst, int M, int N) {
1352 // load from src
1353 TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn<BFloat16> expects M, N <= 32.");
1354 __m512i r[32];
1355 int i;
1356 if (N == 32) {
1357 for (i = 0; i < M; ++i) {
1358 r[i] = _mm512_loadu_si512(&src[i * ld_src]);
1359 }
1360 } else {
1361 __mmask32 src_mask = (1 << N) - 1;
1362 for (i = 0; i < M; ++i) {
1363 r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]);
1364 }
1365 }
1366 for (; i < 32; ++i) {
1367 r[i] = _mm512_setzero_si512();
1368 }
1369
1370 __m512i d[32];
1371 _transpose_mxn_half_32_32(r, d);
1372
1373 // store to dst
1374 if (M == 32) {
1375 for (i = 0; i < N; ++i) {
1376 _mm512_storeu_si512(&dst[i * ld_dst], d[i]);
1377 }
1378 } else {
1379 __mmask32 dst_mask = (1 << M) - 1;
1380 for (i = 0; i < N; ++i) {
1381 _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]);
1382 }
1383 }
1384 }
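// Illustrative use of the runtime-sized variant above (sketch only): transpose
// a ragged 7x20 tile embedded in larger row-major matrices with leading
// dimensions ld_src / ld_dst:
//
//   transpose_mxn<BFloat16>(src, ld_src, dst, ld_dst, /*M=*/7, /*N=*/20);
//   // rows beyond M and columns beyond N are neither read nor written, thanks
//   // to the masked loads/stores above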
1385
1386 template <typename T, int M, int N,
1387 typename std::enable_if_t<std::is_same<T, BFloat16>::value && ((M <= 32 && M != 16) || (N <= 32 && N != 16)), int> = 0>
1388 inline void transpose_mxn(const BFloat16* src, int64_t ld_src, BFloat16* dst, int64_t ld_dst) {
1389 transpose_mxn<BFloat16>(src, ld_src, dst, ld_dst, M, N);
1390 }
1391
template<>
inline void transpose_mxn<Half>(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst, int M, int N) {
  TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn<Half> expects M, N <= 32.");
  // load from src
  __m512i r[32];
  int i;
  if (N == 32) {
    for (i = 0; i < M; ++i) {
      r[i] = _mm512_loadu_si512(&src[i * ld_src]);
    }
  } else {
    __mmask32 src_mask = (1 << N) - 1;
    for (i = 0; i < M; ++i) {
      r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]);
    }
  }
  for (; i < 32; ++i) {
    r[i] = _mm512_setzero_si512();
  }

  __m512i d[32];
  _transpose_mxn_half_32_32(r, d);

  // store to dst
  if (M == 32) {
    for (i = 0; i < N; ++i) {
      _mm512_storeu_si512(&dst[i * ld_dst], d[i]);
    }
  } else {
    __mmask32 dst_mask = (1 << M) - 1;
    for (i = 0; i < N; ++i) {
      _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]);
    }
  }
}

template <typename T, int M, int N,
          typename std::enable_if_t<std::is_same<T, Half>::value && ((M <= 32 && M != 16) || (N <= 32 && N != 16)), int> = 0>
inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst) {
  transpose_mxn<Half>(src, ld_src, dst, ld_dst, M, N);
}

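// Vectorized<Half> packs 32 fp16 values into a single __m512i. Most
// operations come from the shared Vectorized16 base; frac() and the
// 0.0/1.0-returning comparison helpers are declared here and defined below.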
template <>
class Vectorized<Half>: public Vectorized16<Half> {
public:
  using Vectorized16::Vectorized16;

  Vectorized<Half> frac() const;

  Vectorized<Half> eq(const Vectorized<Half>& other) const;
  Vectorized<Half> ne(const Vectorized<Half>& other) const;
  Vectorized<Half> gt(const Vectorized<Half>& other) const;
  Vectorized<Half> ge(const Vectorized<Half>& other) const;
  Vectorized<Half> lt(const Vectorized<Half>& other) const;
  Vectorized<Half> le(const Vectorized<Half>& other) const;
};

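// Arithmetic on Half is performed in fp32: binary_op_as_fp32 widens both
// operands to two fp32 vectors, applies the given fp32 op, and converts the
// result back to fp16.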
Vectorized<Half> inline operator+(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_add_ps(x, y); });
}
Vectorized<Half> inline operator-(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_sub_ps(x, y); });
}
Vectorized<Half> inline operator*(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_mul_ps(x, y); });
}
Vectorized<Half> inline operator/(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_div_ps(x, y); });
}

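// Bitwise operators act directly on the underlying 512-bit integer
// representation; no fp32 round trip is needed.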
Vectorized<Half> inline operator&(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return _mm512_and_si512(a, b);
}
Vectorized<Half> inline operator|(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return _mm512_or_si512(a, b);
}
Vectorized<Half> inline operator^(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return _mm512_xor_si512(a, b);
}

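// Unlike operator== and friends, which produce all-ones/all-zeros lane
// masks, the member comparison helpers below return 1.0 or 0.0 per lane
// (mask & 1.0f).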
inline Vectorized<Half> Vectorized<Half>::eq(const Vectorized<Half>& other) const {
  return (*this == other) & Vectorized<Half>(1.0f);
}

inline Vectorized<Half> Vectorized<Half>::ne(const Vectorized<Half>& other) const {
  return (*this != other) & Vectorized<Half>(1.0f);
}

inline Vectorized<Half> Vectorized<Half>::gt(const Vectorized<Half>& other) const {
  return (*this > other) & Vectorized<Half>(1.0f);
}

inline Vectorized<Half> Vectorized<Half>::ge(const Vectorized<Half>& other) const {
  return (*this >= other) & Vectorized<Half>(1.0f);
}

inline Vectorized<Half> Vectorized<Half>::lt(const Vectorized<Half>& other) const {
  return (*this < other) & Vectorized<Half>(1.0f);
}

inline Vectorized<Half> Vectorized<Half>::le(const Vectorized<Half>& other) const {
  return (*this <= other) & Vectorized<Half>(1.0f);
}

// frac. Implement this here so we can use subtraction
inline Vectorized<Half> Vectorized<Half>::frac() const {
  return *this - this->trunc();
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<Half> inline maximum(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512i zero_vec = _mm512_set1_epi32(0);
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(b), b_lo, b_hi);
  auto max_lo = _mm512_max_ps(a_lo, b_lo);
  auto max_hi = _mm512_max_ps(a_hi, b_hi);
  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
  // Expand the unordered-compare masks to all-ones in the affected lanes
  // (broadcasting the raw mask value with _mm512_set1_epi32 would touch
  // every lane instead of marking only the NaN ones), mirroring `minimum`.
  auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
                                                           0xFFFFFFFF));
  auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
                                                           0xFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  auto o1 = _mm512_or_ps(max_lo, nan_lo);
  auto o2 = _mm512_or_ps(max_hi, nan_hi);
  return cvtfp32_fp16(o1, o2);
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<Half> inline minimum(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512i zero_vec = _mm512_set1_epi32(0);
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(b), b_lo, b_hi);
  auto min_lo = _mm512_min_ps(a_lo, b_lo);
  auto min_hi = _mm512_min_ps(a_hi, b_hi);
  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
  auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
                                                           0xFFFFFFFF));
  auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
                                                           0xFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  auto o1 = _mm512_or_ps(min_lo, nan_lo);
  auto o2 = _mm512_or_ps(min_hi, nan_hi);
  return cvtfp32_fp16(o1, o2);
}

template <>
Vectorized<Half> inline clamp(const Vectorized<Half>& a,
    const Vectorized<Half>& min, const Vectorized<Half>& max) {
  __m512 a_lo, a_hi;
  __m512 min_lo, min_hi;
  __m512 max_lo, max_hi;
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(min), min_lo, min_hi);
  cvtfp16_fp32(__m512i(max), max_lo, max_hi);
  auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo));
  auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi));
  return cvtfp32_fp16(o1, o2);
}

template <>
Vectorized<Half> inline clamp_max(const Vectorized<Half>& a, const Vectorized<Half>& max) {
  __m512 a_lo, a_hi;
  __m512 max_lo, max_hi;
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(max), max_lo, max_hi);
  auto o1 = _mm512_min_ps(max_lo, a_lo);
  auto o2 = _mm512_min_ps(max_hi, a_hi);
  return cvtfp32_fp16(o1, o2);
}

template <>
Vectorized<Half> inline clamp_min(const Vectorized<Half>& a, const Vectorized<Half>& min) {
  __m512 a_lo, a_hi;
  __m512 min_lo, min_hi;
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(min), min_lo, min_hi);
  auto o1 = _mm512_max_ps(min_lo, a_lo);
  auto o2 = _mm512_max_ps(min_hi, a_hi);
  return cvtfp32_fp16(o1, o2);
}

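// Half -> Half conversion is a straight copy: full vectors are moved 32
// elements at a time, with a scalar loop for the tail.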
template <>
inline void convert(const Half* src, Half* dst, int64_t n) {
  int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<Half>::size()); i += Vectorized<Half>::size()) {
    auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i)));
    _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc);
  }
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = src[i];
  }
}

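// float -> Half: each iteration packs two fp32 vectors (32 floats) into one
// 512-bit vector of halves; the tail falls back to scalar c10::convert.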
template <>
inline void convert(const float* src, Half* dst, int64_t n) {
  int64_t i;
  for (i = 0; i + Vectorized<Half>::size() <= n; i += Vectorized<Half>::size()) {
    __m512 a = _mm512_loadu_ps(&src[i]);
    __m512 b = _mm512_loadu_ps(&src[i + 16]);

    __m512i fp16 = cvtfp32_fp16(a, b);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), fp16);
  }
  for (; i < n; i++) {
    dst[i] = c10::convert<Half>(src[i]);
  }
}

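// double -> Half: load_float narrows 8 doubles at a time into a 256-bit
// float vector and glues two of those into a full fp32 vector, after which
// the fp32 -> fp16 path above is reused.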
template <>
inline void convert(const double* src, Half* dst, int64_t n) {
  auto load_float = [](const double *src) -> __m512 {
    // Load one float vector from an array of doubles
    __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src));
    __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8));
    return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
  };

  int64_t i;
  for (i = 0; i + Vectorized<Half>::size() <= n; i += Vectorized<Half>::size()) {
    __m512 a = load_float(&src[i]);
    __m512 b = load_float(&src[i + 16]);

    __m512i fp16 = cvtfp32_fp16(a, b);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), fp16);
  }
  for (; i < n; i++) {
    dst[i] = c10::convert<Half>(src[i]);
  }
}

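// fmadd is evaluated in fp32: the multiply-add is fused (single rounding) at
// fp32 precision, and the result is rounded once more when narrowing back
// to fp16.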
template <>
Vectorized<Half> inline fmadd(const Vectorized<Half>& a,
    const Vectorized<Half>& b, const Vectorized<Half>& c) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512 c_lo, c_hi;
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(b), b_lo, b_hi);
  cvtfp16_fp32(__m512i(c), c_lo, c_hi);
  auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo);
  auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi);
  return cvtfp32_fp16(o1, o2);
}

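// Generates convert_<name>_float / convert_float_<name>, which split one
// vector of 16-bit values into two fp32 vectors and merge them back via the
// cvt_to_fp32 / cvt_from_fp32 helpers defined above.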
#define CONVERT_VECTORIZED_INIT(type, name) \
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
  __m512 o1, o2; \
  cvt_to_fp32<type>(__m512i(a), o1, o2); \
  return std::make_tuple(o1, o2); \
} \
\
inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
  return cvt_from_fp32<type>(__m512(a), __m512(b)); \
}
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_VECTORIZED_INIT(Half, half);

#else //defined(CPU_CAPABILITY_AVX512)

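// Scalar fallback when AVX512 is unavailable: values are spilled to aligned
// stack buffers and converted element by element with c10::convert.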
#define CONVERT_NON_VECTORIZED_INIT(type, name) \
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
  constexpr int64_t K = Vectorized<type>::size(); \
  __at_align__ float arr[K]; \
  __at_align__ type arr2[K]; \
  a.store(arr2); \
  for (const auto k : c10::irange(K)) { \
    arr[k] = c10::convert<float>(arr2[k]); \
  } \
  return std::make_tuple( \
      Vectorized<float>::loadu(arr), \
      Vectorized<float>::loadu(arr + Vectorized<float>::size())); \
} \
\
inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
  constexpr int64_t K = Vectorized<type>::size(); \
  __at_align__ float arr[K]; \
  __at_align__ type arr2[K]; \
  a.store(arr); \
  b.store(arr + Vectorized<float>::size()); \
  for (const auto k : c10::irange(K)) { \
    arr2[k] = c10::convert<type>(arr[k]); \
  } \
  return Vectorized<type>::loadu(arr2); \
}
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_NON_VECTORIZED_INIT(Half, half);

#endif // defined(CPU_CAPABILITY_AVX512)

#if defined(CPU_CAPABILITY_AVX512)
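// load_fp32_from_<name> widens 16 (single-output overload) or 32
// (two-output overload) 16-bit values from memory into fp32 vectors.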
#define LOAD_FP32_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
  auto values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data)); \
  __m512 out_values; \
  cvt_to_fp32<type>(values, out_values); \
  out = out_values; \
} \
\
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vectorized<float>& out2) { \
  auto vec = Vectorized<type>::loadu(data); \
  __m512 out1_values, out2_values; \
  cvt_to_fp32<type>(vec, out1_values, out2_values); \
  out1 = out1_values; \
  out2 = out2_values; \
}
LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16);
LOAD_FP32_VECTORIZED_INIT(Half, fp16);

#else // defined(CPU_CAPABILITY_AVX512)
#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
  __at_align__ float values[Vectorized<float>::size()]; \
  for (const auto k : c10::irange(Vectorized<float>::size())) { \
    values[k] = data[k]; \
  } \
  out = Vectorized<float>::loadu(values); \
} \
\
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vectorized<float>& out2) { \
  load_fp32_from_##name(data, out1); \
  data += Vectorized<float>::size(); \
  load_fp32_from_##name(data, out2); \
}
LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16);
LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16);

#endif
}}}