#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>

#if defined(CPU_CAPABILITY_AVX512)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif

namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX512)

#ifndef SLEEF_CONST
#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER)
#define SLEEF_CONST const
#else
#define SLEEF_CONST
#endif
#define SLEEF_CONST_OLD SLEEF_CONST
#else
#define SLEEF_CONST_OLD
#endif

// bfloat16 conversion
static inline void cvtbf16_fp32(const __m256i& a, __m512& o) {
  o = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
}
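
// bf16 -> fp32 widening is exact: each 16-bit value is zero-extended and
// shifted into the upper half of a 32-bit lane, which is exactly where bf16's
// sign, exponent, and mantissa sit inside an IEEE-754 binary32.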

static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
  __m256i lo = _mm512_extracti32x8_epi32(a, 0);
  __m256i hi = _mm512_extracti32x8_epi32(a, 1);
  cvtbf16_fp32(lo, o1);
  cvtbf16_fp32(hi, o2);
}

static inline __m256i cvtfp32_bf16(const __m512& src) {
  __m512i value = _mm512_castps_si512(src);
  __m512i nan = _mm512_set1_epi32(0xffff);
  auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
  __m512i ones = _mm512_set1_epi32(0x1);
  __m512i vec_bias = _mm512_set1_epi32(0x7fff);
  // uint32_t lsb = (input >> 16) & 1;
  auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
  // uint32_t rounding_bias = 0x7fff + lsb;
  t_value = _mm512_add_epi32(t_value, vec_bias);
  // input += rounding_bias;
  t_value = _mm512_add_epi32(t_value, value);
  // input = input >> 16;
  t_value = _mm512_srli_epi32(t_value, 16);
  // Check NaN before converting back to bf16
  t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
  return _mm512_cvtusepi32_epi16(t_value);
}
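
// A scalar sketch of the rounding done by cvtfp32_bf16 above, assuming an
// IEEE-754 binary32 input whose top 16 bits form the bf16 result (the helper
// name is hypothetical, for illustration only):
//
//   uint16_t fp32_to_bf16_scalar(uint32_t bits) {
//     uint32_t lsb = (bits >> 16) & 1;            // lowest bit that survives
//     uint32_t rounding_bias = 0x7fff + lsb;      // ties round to even
//     return (uint16_t)((bits + rounding_bias) >> 16);
//   }
//
// NaN inputs bypass the addition entirely: the _CMP_ORD_Q mask substitutes
// 0xffff (a bf16 NaN) for every unordered lane before the final narrowing.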

static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) {
  __m512i lo = _mm512_castps_si512(a);
  __m512i hi = _mm512_castps_si512(b);
  __m512i nan = _mm512_set1_epi32(0xffff);
  auto mask_lo = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
  auto mask_hi = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q);
  __m512i ones = _mm512_set1_epi32(0x1);
  __m512i vec_bias = _mm512_set1_epi32(0x7fff);
  // uint32_t lsb = (input >> 16) & 1;
  auto t_lo = _mm512_and_si512(_mm512_srli_epi32(lo, 16), ones);
  auto t_hi = _mm512_and_si512(_mm512_srli_epi32(hi, 16), ones);
  // uint32_t rounding_bias = 0x7fff + lsb;
  t_lo = _mm512_add_epi32(t_lo, vec_bias);
  t_hi = _mm512_add_epi32(t_hi, vec_bias);
  // input += rounding_bias;
  t_lo = _mm512_add_epi32(t_lo, lo);
  t_hi = _mm512_add_epi32(t_hi, hi);
  // input = input >> 16;
  t_lo = _mm512_srli_epi32(t_lo, 16);
  t_hi = _mm512_srli_epi32(t_hi, 16);
  // Check NaN before converting back to bf16
  t_lo = _mm512_mask_blend_epi32(mask_lo, nan, t_lo);
  t_hi = _mm512_mask_blend_epi32(mask_hi, nan, t_hi);

  t_lo = _mm512_packus_epi32(t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4]
  __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
  return _mm512_permutexvar_epi64(idx, t_lo);
}
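
// Why the final permute: _mm512_packus_epi32 packs within each 128-bit lane,
// so the 16-bit results come out interleaved lane by lane. The quadword index
// vector (elements 0, 2, 4, 6, 1, 3, 5, 7 from low to high) gathers the t_lo
// quadwords first and the t_hi quadwords second, restoring element order
// across the full 512-bit register.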

static inline __m512i merge_compare_result(const __m512& a, const __m512& b) {
  __m512i lo = _mm512_castps_si512(a);
  __m512i hi = _mm512_castps_si512(b);
  lo = _mm512_srli_epi32(lo, 16);
  hi = _mm512_srli_epi32(hi, 16);
  auto out = _mm512_packus_epi32(lo, hi);
  __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
  return _mm512_permutexvar_epi64(idx, out);
}
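
// merge_compare_result can truncate instead of round: comparison results are
// all-ones or all-zeros bit patterns per lane, and both survive a plain
// shift-right by 16 exactly, so the rounding-bias machinery of cvtfp32_bf16
// is unnecessary here.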

// float16 conversion
static inline void cvtfp16_fp32(const __m256i& a, __m512& o) {
  o = _mm512_cvtph_ps(a);
}

static inline void cvtfp16_fp32(const __m512i& a, __m512& o1, __m512& o2) {
  __m256i lo = _mm512_extracti32x8_epi32(a, 0);
  __m256i hi = _mm512_extracti32x8_epi32(a, 1);
  cvtfp16_fp32(lo, o1);
  cvtfp16_fp32(hi, o2);
}

static inline __m256i cvtfp32_fp16(const __m512& src) {
  return _mm512_cvtps_ph(
      src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}

static inline __m512i cvtfp32_fp16(const __m512& a, const __m512& b) {
  __m256i lo = _mm512_cvtps_ph(
      a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  __m256i hi = _mm512_cvtps_ph(
      b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  __m512 t_lo = _mm512_castsi512_ps(_mm512_castsi256_si512(lo));
  __m256 t_hi = _mm256_castsi256_ps(hi);
  return _mm512_castps_si512(_mm512_insertf32x8(t_lo, t_hi, 1));
}
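
// Unlike the bf16 path, float16 needs no manual rounding or NaN fixup:
// _mm512_cvtph_ps and _mm512_cvtps_ph are hardware conversions that already
// round to nearest-even and map NaN inputs to NaN outputs.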

// dtype conversion between float16/bfloat16 and float32
template <typename T, typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline void cvt_to_fp32(const __m256i& a, __m512& o);
template <> inline void cvt_to_fp32<BFloat16>(const __m256i& a, __m512& o) {
  cvtbf16_fp32(a, o);
}
template <> inline void cvt_to_fp32<Half>(const __m256i& a, __m512& o) {
  cvtfp16_fp32(a, o);
}

template <typename T, typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline void cvt_to_fp32(const __m512i& a, __m512& o1, __m512& o2);
template <> inline void cvt_to_fp32<BFloat16>(const __m512i& a, __m512& o1, __m512& o2) {
  cvtbf16_fp32(a, o1, o2);
}
template <> inline void cvt_to_fp32<Half>(const __m512i& a, __m512& o1, __m512& o2) {
  cvtfp16_fp32(a, o1, o2);
}

template <typename T, bool is_compare_op = false,
          typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline __m512i cvt_from_fp32(const __m512& a, const __m512& b);
template <> inline __m512i cvt_from_fp32<BFloat16, false>(const __m512& a, const __m512& b) {
  return cvtfp32_bf16(a, b);
}
template <> inline __m512i cvt_from_fp32<BFloat16, true>(const __m512& a, const __m512& b) {
  return merge_compare_result(a, b);
}
template <> inline __m512i cvt_from_fp32<Half, false>(const __m512& a, const __m512& b) {
  return cvtfp32_fp16(a, b);
}
template <> inline __m512i cvt_from_fp32<Half, true>(const __m512& a, const __m512& b) {
  return cvtfp32_fp16(a, b);
}
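
// The is_compare_op flag only changes behavior for BFloat16: compare results
// must be truncated (merge_compare_result) so all-ones lane masks survive
// bit-exactly, while Half can use the hardware converter in both cases.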

template <typename T>
class Vectorized16 {
static_assert(
  is_reduced_floating_point_v<T>,
  "Support only float16 and bfloat16.");
private:
  __m512i values;
public:
  using value_type = uint16_t;
  using size_type = int;
  static constexpr size_type size() {
    return 32;
  }
  Vectorized16() {}
  Vectorized16(__m512i v) : values(v) {}
  Vectorized16(T val) {
    value_type uw = val.x;
    values = _mm512_set1_epi16(uw);
  }
  Vectorized16(T val1, T val2, T val3, T val4,
         T val5, T val6, T val7, T val8,
         T val9, T val10, T val11, T val12,
         T val13, T val14, T val15, T val16,
         T val17, T val18, T val19, T val20,
         T val21, T val22, T val23, T val24,
         T val25, T val26, T val27, T val28,
         T val29, T val30, T val31, T val32) {
    values = _mm512_set_epi16(
        val32.x, val31.x, val30.x, val29.x, val28.x, val27.x, val26.x, val25.x,
        val24.x, val23.x, val22.x, val21.x, val20.x, val19.x, val18.x, val17.x,
        val16.x, val15.x, val14.x, val13.x, val12.x, val11.x, val10.x, val9.x,
        val8.x, val7.x, val6.x, val5.x, val4.x, val3.x, val2.x, val1.x);
  }
  operator __m512i() const {
    return values;
  }
  T& operator[](int idx) = delete;
  const T& operator[](int idx) const = delete;
  int zero_mask() const {
    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
    return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0));
  }
  static Vectorized<T> loadu(const void* ptr, int16_t count = size()) {
    if (count == size())
      return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));

    __mmask32 mask = (1ULL << count) - 1;
    return _mm512_maskz_loadu_epi16(mask, ptr);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
    } else if (count > 0) {
      __mmask32 mask = (1ULL << count) - 1;
      _mm512_mask_storeu_epi16(ptr, mask, values);
    }
  }
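
  // Partial loads and stores build a k-mask with the low `count` bits set:
  // _mm512_maskz_loadu_epi16 zero-fills the unselected lanes, and the masked
  // store leaves memory beyond `count` elements untouched. A minimal sketch
  // of handling a 5-element tail (hypothetical buffer):
  //
  //   BFloat16 tail[5];
  //   auto v = Vectorized<BFloat16>::loadu(tail, 5);  // lanes 5..31 are zero
  //   v.store(tail, 5);                               // writes only 5 elements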
  template <int64_t mask>
  static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
    __at_align__ int16_t tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = b.values[31];
    if (mask & 0x02)
      tmp_values[1] = b.values[30];
    if (mask & 0x04)
      tmp_values[2] = b.values[29];
    if (mask & 0x08)
      tmp_values[3] = b.values[28];
    if (mask & 0x10)
      tmp_values[4] = b.values[27];
    if (mask & 0x20)
      tmp_values[5] = b.values[26];
    if (mask & 0x40)
      tmp_values[6] = b.values[25];
    if (mask & 0x80)
      tmp_values[7] = b.values[24];
    if (mask & 0x100)
      tmp_values[8] = b.values[23];
    if (mask & 0x200)
      tmp_values[9] = b.values[22];
    if (mask & 0x400)
      tmp_values[10] = b.values[21];
    if (mask & 0x800)
      tmp_values[11] = b.values[20];
    if (mask & 0x1000)
      tmp_values[12] = b.values[19];
    if (mask & 0x2000)
      tmp_values[13] = b.values[18];
    if (mask & 0x4000)
      tmp_values[14] = b.values[17];
    if (mask & 0x8000)
      tmp_values[15] = b.values[16];
    if (mask & 0x10000)
      tmp_values[16] = b.values[15];
    if (mask & 0x20000)
      tmp_values[17] = b.values[14];
    if (mask & 0x40000)
      tmp_values[18] = b.values[13];
    if (mask & 0x80000)
      tmp_values[19] = b.values[12];
    if (mask & 0x100000)
      tmp_values[20] = b.values[11];
    if (mask & 0x200000)
      tmp_values[21] = b.values[10];
    if (mask & 0x400000)
      tmp_values[22] = b.values[9];
    if (mask & 0x800000)
      tmp_values[23] = b.values[8];
    if (mask & 0x1000000)
      tmp_values[24] = b.values[7];
    if (mask & 0x2000000)
      tmp_values[25] = b.values[6];
    if (mask & 0x4000000)
      tmp_values[26] = b.values[5];
    if (mask & 0x8000000)
      tmp_values[27] = b.values[4];
    if (mask & 0x10000000)
      tmp_values[28] = b.values[3];
    if (mask & 0x20000000)
      tmp_values[29] = b.values[2];
    if (mask & 0x40000000)
      tmp_values[30] = b.values[1];
    if (mask & 0x80000000)
      tmp_values[31] = b.values[0];
    return loadu(tmp_values);
  }
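
  // Contract: bit i of the compile-time `mask` selects lane i of the result
  // from b, a clear bit keeps lane i from a; set() below relies on this to
  // splice the first `count` lanes. A small usage sketch:
  //
  //   auto a = Vectorized<BFloat16>(1.0f);
  //   auto b = Vectorized<BFloat16>(2.0f);
  //   auto c = Vectorized<BFloat16>::blend<0x3>(a, b);  // lanes 0,1 from b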
  static Vectorized<T> blendv(const Vectorized<T>& a,
      const Vectorized<T>& b, const Vectorized<T>& mask) {
    auto all_ones = _mm512_set1_epi16(0xFFFF);
    auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ);
    return _mm512_mask_blend_epi16(mask_, a.values, b.values);
  }
  template<typename step_t>
  static Vectorized<T> arange(T base = 0.f, step_t step = static_cast<step_t>(1)) {
    return Vectorized<T>(
      base,             base +      step, base +  2 * step, base +  3 * step,
      base +  4 * step, base +  5 * step, base +  6 * step, base +  7 * step,
      base +  8 * step, base +  9 * step, base + 10 * step, base + 11 * step,
      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
      base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
      base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
      base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
      base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step);
  }
  static Vectorized<T> set(const Vectorized<T>& a,
      const Vectorized<T>& b, int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
      case 8:
        return blend<255>(a, b);
      case 9:
        return blend<511>(a, b);
      case 10:
        return blend<1023>(a, b);
      case 11:
        return blend<2047>(a, b);
      case 12:
        return blend<4095>(a, b);
      case 13:
        return blend<8191>(a, b);
      case 14:
        return blend<16383>(a, b);
      case 15:
        return blend<32767>(a, b);
      case 16:
        return blend<65535>(a, b);
      case 17:
        return blend<131071>(a, b);
      case 18:
        return blend<262143>(a, b);
      case 19:
        return blend<524287>(a, b);
      case 20:
        return blend<1048575>(a, b);
      case 21:
        return blend<2097151>(a, b);
      case 22:
        return blend<4194303>(a, b);
      case 23:
        return blend<8388607>(a, b);
      case 24:
        return blend<16777215>(a, b);
      case 25:
        return blend<33554431>(a, b);
      case 26:
        return blend<67108863>(a, b);
      case 27:
        return blend<134217727>(a, b);
      case 28:
        return blend<268435455>(a, b);
      case 29:
        return blend<536870911>(a, b);
      case 30:
        return blend<1073741823>(a, b);
      case 31:
        return blend<2147483647>(a, b);
    }
    return b;
  }
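
  // set(a, b, count) thus returns b's first `count` lanes followed by a's
  // remaining lanes; any count outside 0..31 falls through to plain b. It is
  // the blend-based counterpart of the masked loadu/store tail handling above.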
  #pragma clang diagnostic push
  #pragma clang diagnostic ignored "-Wignored-qualifiers"

  Vectorized<T> map(SLEEF_CONST __m512 (*SLEEF_CONST_OLD vop)(__m512)) const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    const auto o1 = vop(lo);
    const auto o2 = vop(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
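
  // map() is the template for most unary math ops below: widen the 32
  // reduced-precision lanes to two fp32 vectors, run one Sleef kernel per
  // half (e.g. Sleef_expf16_u10, where "f16" means 16 floats and "u10" a
  // 1.0-ulp error bound), then narrow back with the usual rounding.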
  Vectorized<T> isnan() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    __mmask16 lo_mask, hi_mask;
    __m512 zero = _mm512_set1_ps(0.0);
    __m512i zeroi = _mm512_castps_si512(zero);
    lo_mask = _mm512_cmp_ps_mask(lo, zero, _CMP_UNORD_Q);
    lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zeroi, lo_mask, 0xFFFF'FFFF));
    hi_mask = _mm512_cmp_ps_mask(hi, zero, _CMP_UNORD_Q);
    hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zeroi, hi_mask, 0xFFFF'FFFF));
    return merge_compare_result(lo, hi);
  }
  #pragma clang diagnostic pop
  Vectorized<T> abs() const {
    return _mm512_andnot_si512(_mm512_set1_epi16(0x8000), values);
  }
  Vectorized<T> angle() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto angle_lambda = [](__m512 values) {
      const auto zero_vec = _mm512_set1_ps(0.f);
      const auto nan_vec = _mm512_set1_ps(NAN);
      const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
      const auto non_nan_mask_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec),
                                                           not_nan_mask, 0xFFFFFFFF);
      const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(non_nan_mask_vec),
                                               zero_vec, _CMP_EQ_OQ);
      const auto pi = _mm512_set1_ps(c10::pi<float>);

      const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
      auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
      angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
      return angle;
    };
    auto o1 = angle_lambda(lo);
    auto o2 = angle_lambda(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> real() const {
    return *this;
  }
  Vectorized<T> imag() const {
    return _mm512_set1_epi16(0);
  }
  Vectorized<T> conj() const {
    return *this;
  }
  Vectorized<T> acos() const {
    return map(Sleef_acosf16_u10);
  }
  Vectorized<T> acosh() const {
    return map(Sleef_acoshf16_u10);
  }
  Vectorized<T> asin() const {
    return map(Sleef_asinf16_u10);
  }
  Vectorized<T> atan() const {
    return map(Sleef_atanf16_u10);
  }
  Vectorized<T> atanh() const {
    return map(Sleef_atanhf16_u10);
  }
  Vectorized<T> atan2(const Vectorized<T> &b) const {
    __m512 lo, hi;
    __m512 b1, b2;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(b.values, b1, b2);
    auto o1 = Sleef_atan2f16_u10(lo, b1);
    auto o2 = Sleef_atan2f16_u10(hi, b2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> copysign(const Vectorized<T> &sign) const {
    // copy sign bit (0x8000) from sign and remaining bits from values
    __m512i mask_value = _mm512_set1_epi32(~0x80008000);
    __m512i mask_signbit = _mm512_set1_epi32(0x80008000);
    return Vectorized<T>(
      _mm512_or_si512(
        _mm512_and_si512(values, mask_value),
        _mm512_and_si512(sign, mask_signbit)));
  }
  Vectorized<T> erf() const {
    return map(Sleef_erff16_u10);
  }
  Vectorized<T> erfc() const {
    return map(Sleef_erfcf16_u15);
  }
  Vectorized<T> erfinv() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (int64_t i = 0; i < size() / 2; i++) {
      tmp1[i] = calc_erfinv(tmp1[i]);
      tmp2[i] = calc_erfinv(tmp2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
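
  // erfinv, i0, i0e, digamma, igamma, and igammac have no Sleef vector
  // kernels, so they spill each fp32 half to an aligned scalar buffer, apply
  // the scalar calc_* helper elementwise, and reload the results.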
  Vectorized<T> exp() const {
    return map(Sleef_expf16_u10);
  }
  Vectorized<T> exp2() const {
    return map(Sleef_exp2f16_u10);
  }
  Vectorized<T> expm1() const {
    return map(Sleef_expm1f16_u10);
  }
  // Hook for a faster, lower-precision exp; no specialized kernel here, so
  // it falls back to the full-precision exp().
  Vectorized<T> exp_u20() const {
    return exp();
  }
  Vectorized<T> fmod(const Vectorized<T> & q) const {
    __m512 x_lo, x_hi;
    cvt_to_fp32<T>(values, x_lo, x_hi);
    __m512 q_lo, q_hi;
    cvt_to_fp32<T>(q.values, q_lo, q_hi);
    auto o1 = Sleef_fmodf16(x_lo, q_lo);
    auto o2 = Sleef_fmodf16(x_hi, q_hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> hypot(const Vectorized<T> &b) const {
    __m512 lo, hi;
    __m512 b1, b2;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(b.values, b1, b2);
    auto o1 = Sleef_hypotf16_u05(lo, b1);
    auto o2 = Sleef_hypotf16_u05(hi, b2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> i0() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (int64_t i = 0; i < size() / 2; i++) {
      tmp1[i] = calc_i0(tmp1[i]);
      tmp2[i] = calc_i0(tmp2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> i0e() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    constexpr auto sz = size();
    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);

    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
      tmp1[i] = calc_i0e(tmp1[i]);
      tmp2[i] = calc_i0e(tmp2[i]);
    }
    const auto o1 = _mm512_loadu_ps(tmp1);
    const auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> digamma() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    constexpr auto sz = size();
    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);

    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
      tmp1[i] = calc_digamma(tmp1[i]);
      tmp2[i] = calc_digamma(tmp2[i]);
    }
    const auto o1 = _mm512_loadu_ps(tmp1);
    const auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> igamma(const Vectorized<T> &x) const {
    __m512 lo, hi;
    __m512 xlo, xhi;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(x.values, xlo, xhi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
    for (int64_t i = 0; i < size() / 2; ++i) {
      tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
      tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }

  Vectorized<T> igammac(const Vectorized<T> &x) const {
    __m512 lo, hi;
    __m512 xlo, xhi;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(x.values, xlo, xhi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
    for (int64_t i = 0; i < size() / 2; ++i) {
      tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]);
      tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> log() const {
    return map(Sleef_logf16_u10);
  }
  Vectorized<T> log2() const {
    return map(Sleef_log2f16_u10);
  }
  Vectorized<T> log10() const {
    return map(Sleef_log10f16_u10);
  }
  Vectorized<T> log1p() const {
    return map(Sleef_log1pf16_u10);
  }
  Vectorized<T> sin() const {
    return map(Sleef_sinf16_u10);
  }
  Vectorized<T> sinh() const {
    return map(Sleef_sinhf16_u10);
  }
  Vectorized<T> cos() const {
    return map(Sleef_cosf16_u10);
  }
  Vectorized<T> cosh() const {
    return map(Sleef_coshf16_u10);
  }
  Vectorized<T> ceil() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_ceil_ps(lo);
    auto o2 = _mm512_ceil_ps(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> floor() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_floor_ps(lo);
    auto o2 = _mm512_floor_ps(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> neg() const {
    return _mm512_xor_si512(values, _mm512_set1_epi16(0x8000));
  }
  Vectorized<T> round() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> tan() const {
    return map(Sleef_tanf16_u10);
  }
  Vectorized<T> tanh() const {
    return map(Sleef_tanhf16_u10);
  }
  Vectorized<T> trunc() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
    auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> lgamma() const {
    return map(Sleef_lgammaf16_u10);
  }
  Vectorized<T> sqrt() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_sqrt_ps(lo);
    auto o2 = _mm512_sqrt_ps(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> reciprocal() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto ones = _mm512_set1_ps(1);
    auto o1 = _mm512_div_ps(ones, lo);
    auto o2 = _mm512_div_ps(ones, hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> rsqrt() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto ones = _mm512_set1_ps(1);
    auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo));
    auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi));
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> pow(const Vectorized<T> &b) const {
    __m512 lo, hi;
    __m512 b1, b2;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(b.values, b1, b2);
    auto o1 = Sleef_powf16_u10(lo, b1);
    auto o2 = Sleef_powf16_u10(hi, b2);
    return cvt_from_fp32<T>(o1, o2);
  }
private:
  template<typename Op>
  Vectorized<T> inline binary_compare(const Vectorized<T>& b, Op op) const {
    __m512 a_lo, a_hi;
    __m512 b_lo, b_hi;
    cvt_to_fp32<T>(values, a_lo, a_hi);
    cvt_to_fp32<T>(b.values, b_lo, b_hi);
    auto o1 = op(a_lo, b_lo);
    auto o2 = op(a_hi, b_hi);
    return cvt_from_fp32<T, /*is_compare_op*/true>(o1, o2);
  }
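
  // Each comparison lambda below produces all-ones (0xFFFFFFFF) or all-zeros
  // per fp32 lane; cvt_from_fp32<T, true> then truncates each lane to 16
  // bits, yielding the 0xFFFF / 0x0000 per-element masks blendv() expects.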

public:
  Vectorized<T> inline operator>(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ);
      return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator<(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ);
      return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator>=(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
      return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator<=(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ);
      return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator==(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ);
      return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator!=(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ);
      return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
};

template<typename T, typename Op>
static inline Vectorized<T> binary_op_as_fp32(const Vectorized<T>& a, const Vectorized<T>& b, Op op) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  cvt_to_fp32<T>(__m512i(a), a_lo, a_hi);
  cvt_to_fp32<T>(__m512i(b), b_lo, b_hi);
  auto o1 = op(a_lo, b_lo);
  auto o2 = op(a_hi, b_hi);
  return cvt_from_fp32<T>(o1, o2);
}
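
// binary_op_as_fp32 is the arithmetic analogue of map(): every +, -, *, and /
// below widens both operands to fp32, applies one packed operation per half,
// and rounds back once, so each reduced-precision op pays a single rounding.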

template <>
class Vectorized<BFloat16>: public Vectorized16<BFloat16> {
public:
  using Vectorized16::Vectorized16;

  Vectorized<BFloat16> frac() const;

  Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
};

Vectorized<BFloat16> inline operator+(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_add_ps(x, y); });
}
Vectorized<BFloat16> inline operator-(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_sub_ps(x, y); });
}
Vectorized<BFloat16> inline operator*(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_mul_ps(x, y); });
}
Vectorized<BFloat16> inline operator/(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_div_ps(x, y); });
}
Vectorized<BFloat16> inline operator&(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return _mm512_and_si512(a, b);
}
Vectorized<BFloat16> inline operator|(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return _mm512_or_si512(a, b);
}
Vectorized<BFloat16> inline operator^(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return _mm512_xor_si512(a, b);
}

inline Vectorized<BFloat16> Vectorized<BFloat16>::eq(const Vectorized<BFloat16>& other) const {
  return (*this == other) & Vectorized<BFloat16>(1.0f);
}

inline Vectorized<BFloat16> Vectorized<BFloat16>::ne(const Vectorized<BFloat16>& other) const {
  return (*this != other) & Vectorized<BFloat16>(1.0f);
}

inline Vectorized<BFloat16> Vectorized<BFloat16>::gt(const Vectorized<BFloat16>& other) const {
  return (*this > other) & Vectorized<BFloat16>(1.0f);
}

inline Vectorized<BFloat16> Vectorized<BFloat16>::ge(const Vectorized<BFloat16>& other) const {
  return (*this >= other) & Vectorized<BFloat16>(1.0f);
}

inline Vectorized<BFloat16> Vectorized<BFloat16>::lt(const Vectorized<BFloat16>& other) const {
  return (*this < other) & Vectorized<BFloat16>(1.0f);
}

inline Vectorized<BFloat16> Vectorized<BFloat16>::le(const Vectorized<BFloat16>& other) const {
  return (*this <= other) & Vectorized<BFloat16>(1.0f);
}
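
// The eq/ne/gt/ge/lt/le variants turn all-ones comparison masks into numeric
// 0.0/1.0 results: the bit pattern of 1.0f survives an AND against an
// all-ones lane and is cleared by an all-zeros lane.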

// frac. Implement this here so we can use subtraction
inline Vectorized<BFloat16> Vectorized<BFloat16>::frac() const {
  return *this - this->trunc();
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<BFloat16> inline maximum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512i zero_vec = _mm512_set1_epi32(0);
  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
  cvtbf16_fp32(__m512i(b), b_lo, b_hi);
  auto max_lo = _mm512_max_ps(a_lo, b_lo);
  auto max_hi = _mm512_max_ps(a_hi, b_hi);
  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
  // Expand the unordered masks into all-ones lanes (mirrors `minimum` below);
  // broadcasting the raw mask value would instead OR stray low bits into
  // every lane.
  auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
                                                           0xFFFFFFFF));
  auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
                                                           0xFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  auto o1 = _mm512_or_ps(max_lo, nan_lo);
  auto o2 = _mm512_or_ps(max_hi, nan_hi);
  return cvtfp32_bf16(o1, o2);
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<BFloat16> inline minimum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512i zero_vec = _mm512_set1_epi32(0);
  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
  cvtbf16_fp32(__m512i(b), b_lo, b_hi);
  auto min_lo = _mm512_min_ps(a_lo, b_lo);
  auto min_hi = _mm512_min_ps(a_hi, b_hi);
  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
  auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
                                                           0xFFFFFFFF));
  auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
                                                           0xFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  auto o1 = _mm512_or_ps(min_lo, nan_lo);
  auto o2 = _mm512_or_ps(min_hi, nan_hi);
  return cvtfp32_bf16(o1, o2);
}

template <>
Vectorized<BFloat16> inline clamp(const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& min, const Vectorized<BFloat16>& max) {
  __m512 a_lo, a_hi;
  __m512 min_lo, min_hi;
  __m512 max_lo, max_hi;
  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
  cvtbf16_fp32(__m512i(min), min_lo, min_hi);
  cvtbf16_fp32(__m512i(max), max_lo, max_hi);
  auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo));
  auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi));
  return cvtfp32_bf16(o1, o2);
}

template <>
Vectorized<BFloat16> inline clamp_max(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& max) {
  __m512 a_lo, a_hi;
  __m512 max_lo, max_hi;
  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
  cvtbf16_fp32(__m512i(max), max_lo, max_hi);
  auto o1 = _mm512_min_ps(max_lo, a_lo);
  auto o2 = _mm512_min_ps(max_hi, a_hi);
  return cvtfp32_bf16(o1, o2);
}

template <>
Vectorized<BFloat16> inline clamp_min(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& min) {
  __m512 a_lo, a_hi;
  __m512 min_lo, min_hi;
  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
  cvtbf16_fp32(__m512i(min), min_lo, min_hi);
  auto o1 = _mm512_max_ps(min_lo, a_lo);
  auto o2 = _mm512_max_ps(min_hi, a_hi);
  return cvtfp32_bf16(o1, o2);
}

template <>
inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) {
  int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<BFloat16>::size()); i += Vectorized<BFloat16>::size()) {
    auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i)));
    _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc);
  }
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = src[i];
  }
}

template <>
inline void convert(const float* src, BFloat16* dst, int64_t n) {
  int64_t i;
  for (i = 0; i + Vectorized<BFloat16>::size() <= n; i += Vectorized<BFloat16>::size()) {
    __m512 a = _mm512_loadu_ps(&src[i]);
    __m512 b = _mm512_loadu_ps(&src[i + 16]);

    __m512i bf = cvtfp32_bf16(a, b);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
  }
  for (; i < n; i++) {
    dst[i] = c10::convert<BFloat16>(src[i]);
  }
}
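
// Each iteration above consumes one full Vectorized<BFloat16> worth of
// output: two 16-float loads (src[i..i+15] and src[i+16..i+31]) pack into a
// single 512-bit bf16 store, and the scalar loop converts any remainder.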

template <>
inline void convert(const double* src, BFloat16* dst, int64_t n) {
  auto load_float = [](const double *src) -> __m512 {
    // Load one float vector from an array of doubles
    __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src));
    __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8));
    return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
  };

  int64_t i;
  for (i = 0; i + Vectorized<BFloat16>::size() <= n; i += Vectorized<BFloat16>::size()) {
    __m512 a = load_float(&src[i]);
    __m512 b = load_float(&src[i + 16]);

    __m512i bf = cvtfp32_bf16(a, b);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), bf);
  }
  for (; i < n; i++) {
    dst[i] = c10::convert<BFloat16>(src[i]);
  }
}

template <>
Vectorized<BFloat16> inline fmadd(const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& c) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512 c_lo, c_hi;
  cvtbf16_fp32(__m512i(a), a_lo, a_hi);
  cvtbf16_fp32(__m512i(b), b_lo, b_hi);
  cvtbf16_fp32(__m512i(c), c_lo, c_hi);
  auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo);
  auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi);
  return cvtfp32_bf16(o1, o2);
}
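
// fmadd computes a * b + c with a true fp32 FMA (a single fp32 rounding)
// followed by one rounding to bf16, so it is both faster and slightly more
// accurate than a separate bf16 multiply and add.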

static inline void _transpose_mxn_half_16_16(__m256i t[], __m512i u[]) {
  __m512i r[8];
  // a0a1 a2a3 a4a5 a6a7 a8a9 a10a11 a12a13 a14a15   e0e1 e2e3 e4e5 e6e7 e8e9 e10e11 e12e13 e14e15
  // b0-b15  f0-f15
  // c0-c15  g0-g15
  // d0-d15  h0-h15
  // i0-i15  m0-m15
  // j0-j15  n0-n15
  // k0-k15  o0-o15
  // l0-l15  p0-p15
#ifndef __msvc_cl__
#pragma unroll(4)
#endif
  for (int i = 0; i < 4; i++) {
    r[i] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i]), t[i + 4], 0x01);
    r[i + 4] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i + 8]), t[i + 12], 0x01);
  }

  // u0: a0a1 b0b1 a2a3 b2b3 a8a9 b8b9 a10a11 b10b11   e0e1 f0f1 e2e3 f2f3 e8e9 f8f9 e10e11 f10f11
  // u1: a4a5 b4b5 a6a7 b6b7 a12a13 b12b13 a14a15 b14b15   e4e5 f4f5 e6e7 f6f7 e12e13 f12f13 e14e15 f14f15
  // u2: c0c1 d0d1 c2c3 d2d3 c8c9 d8d9 c10c11 d10d11   g0g1 h0h1 g2g3 h2h3 g8g9 h8h9 g10g11 h10h11
  // u3: c4c5 d4b5 c6c7 d6b7 c12c13 d12d13 c14c15 d14d15   g4g5 h4h5 g6g7 h6h7 g12g13 h12h13 g14g15 h14h15
  // i j  m n
  // k l  o p
#ifndef __msvc_cl__
#pragma unroll(4)
#endif
  for (int i = 0; i < 8; i += 2) {
    u[i] = _mm512_unpacklo_epi32(r[i], r[i + 1]);
    u[i + 1] = _mm512_unpackhi_epi32(r[i], r[i + 1]);
  }

  // r0: a0a1 b0b1 c0c1 d0d1 a8a9 b8b9 c8c9 d8d9  e0e1 f0f1 g0g1 h0h1 e8e9 f8f9 g8g9 h8h9
  // r1: a2a3 b2b3 c2c3 d2d3 a10a11 b10b11 c10c11 d10d11  e2e3 f2f3 g2g3 h2h3 e10e11 f10f11 g10g11 h10h11
  // r2: a4a5 b4b5 c4c5 d4b5 a12a13 b12b13 c12c13 d12d13
  // r3: a6a7 b6b7 c6c7 d6b7 a14a15 b14b15 c14c15 d14d15
  // r4: i j k l m n o p
  r[0] = _mm512_unpacklo_epi64(u[0], u[2]);
  r[1] = _mm512_unpackhi_epi64(u[0], u[2]);
  r[2] = _mm512_unpacklo_epi64(u[1], u[3]);
  r[3] = _mm512_unpackhi_epi64(u[1], u[3]);
  r[4] = _mm512_unpacklo_epi64(u[4], u[6]);
  r[5] = _mm512_unpackhi_epi64(u[4], u[6]);
  r[6] = _mm512_unpacklo_epi64(u[5], u[7]);
  r[7] = _mm512_unpackhi_epi64(u[5], u[7]);

  __m512i const1 = _mm512_set_epi32(
      0x00370035,
      0x00330031,
      0x00270025,
      0x00230021,
      0x00170015,
      0x00130011,
      0x00070005,
      0x00030001,
      0x00360034,
      0x00320030,
      0x00260024,
      0x00220020,
      0x00160014,
      0x00120010,
      0x00060004,
      0x00020000);
  __m512i const2 = _mm512_set_epi32(
      0x003f003d,
      0x003b0039,
      0x002f002d,
      0x002b0029,
      0x001f001d,
      0x001b0019,
      0x000f000d,
      0x000b0009,
      0x003e003c,
      0x003a0038,
      0x002e002c,
      0x002a0028,
      0x001e001c,
      0x001a0018,
      0x000e000c,
      0x000a0008);
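  // Each 16-bit element of const1/const2 indexes the 64-lane table formed by
  // concatenating the two permutex2var sources; bit 5 selects the source.
  // E.g. 0x00030001 packs the two indices 0x0003 and 0x0001: values below
  // 0x20 pick lanes of r[i], values 0x20 and above pick lanes of r[i + 4].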
  // merge values from two regs
  // 0-- 1--
  // 8-- 9--
  // 2-- 3--
  // 10-- 11--
  // 4-- 5--
  // 12-- 13--
  // 6-- 7--
  // 14-- 15--
#ifndef __msvc_cl__
#pragma unroll(4)
#endif
  for (int i = 0; i < 4; i++) {
    u[i] = _mm512_permutex2var_epi16(r[i], const1, r[i + 4]);
    u[i + 4] = _mm512_permutex2var_epi16(r[i], const2, r[i + 4]);
  }
}

// TODO(Leslie): Add the AVX2 Version of transpose_mxn for BFloat16 and Float16
// Code referred to FBGEMM:
// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607
template<>
inline void transpose_mxn<BFloat16, 16, 16>(
    const BFloat16* src,
    int64_t ld_src,
    BFloat16* dst,
    int64_t ld_dst) {
  __m256i t[16];
  // load from src to registers
  // a: a0  a1  a2  a3  a4  a5  a6  a7  a8  a9  a10 a11 a12 a13 a14 a15
  // b: b0  b1  b2  b3  b4  b5  b6  b7  b8  b9  b10 b11 b12 b13 b14 b15
  // c: c0  c1  c2  c3  c4  c5  c6  c7  c8  c9  c10 c11 c12 c13 c14 c15
  // d: d0  d1  d2  d3  d4  d5  d6  d7  d8  d9  d10 d11 d12 d13 d14 d15
  // e: e0  e1  e2  e3  e4  e5  e6  e7  e8  e9  e10 e11 e12 e13 e14 e15
  // f: f0  f1  f2  f3  f4  f5  f6  f7  f8  f9  f10 f11 f12 f13 f14 f15
  // g: g0  g1  g2  g3  g4  g5  g6  g7  g8  g9  g10 g11 g12 g13 g14 g15
  // h: h0  h1  h2  h3  h4  h5  h6  h7  h8  h9  h10 h11 h12 h13 h14 h15
  // i: i0  i1  i2  i3  i4  i5  i6  i7  i8  i9  i10 i11 i12 i13 i14 i15
  // j: j0  j1  j2  j3  j4  j5  j6  j7  j8  j9  j10 j11 j12 j13 j14 j15
  // k: k0  k1  k2  k3  k4  k5  k6  k7  k8  k9  k10 k11 k12 k13 k14 k15
  // l: l0  l1  l2  l3  l4  l5  l6  l7  l8  l9  l10 l11 l12 l13 l14 l15
  // m: m0  m1  m2  m3  m4  m5  m6  m7  m8  m9  m10 m11 m12 m13 m14 m15
  // n: n0  n1  n2  n3  n4  n5  n6  n7  n8  n9  n10 n11 n12 n13 n14 n15
  // o: o0  o1  o2  o3  o4  o5  o6  o7  o8  o9  o10 o11 o12 o13 o14 o15
  // p: p0  p1  p2  p3  p4  p5  p6  p7  p8  p9  p10 p11 p12 p13 p14 p15
#ifndef __msvc_cl__
#pragma unroll(16)
#endif
  for (int i = 0; i < 16; i++) {
    t[i] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i * ld_src));
  }

  __m512i u[8];
  _transpose_mxn_half_16_16(t, u);

#ifndef __msvc_cl__
#pragma unroll(8)
#endif
  for (int i = 0; i < 8; i++) {
    _mm256_storeu_si256(
      reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst),
      _mm512_extracti32x8_epi32(u[i], 0x0));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst),
        _mm512_extracti32x8_epi32(u[i], 0x01));
  }
}
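
// A minimal usage sketch (hypothetical buffers): transpose a contiguous
// 16x16 bf16 tile whose leading dimensions equal the tile width:
//
//   BFloat16 src[16 * 16], dst[16 * 16];
//   transpose_mxn<BFloat16, 16, 16>(src, /*ld_src=*/16, dst, /*ld_dst=*/16);
//   // afterwards dst[j * 16 + i] == src[i * 16 + j]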

// Code referred to FBGEMM:
// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#L1483-L1607
template<>
inline void transpose_mxn<Half, 16, 16>(
    const Half* src,
    int64_t ld_src,
    Half* dst,
    int64_t ld_dst) {
  __m256i t[16];
  // load from src to registers
  // Same matrix indices as above transpose_mxn<BFloat16, 16, 16>
#ifndef __msvc_cl__
#pragma unroll(16)
#endif
  for (int i = 0; i < 16; i++) {
    t[i] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i * ld_src));
  }

  __m512i u[8];
  _transpose_mxn_half_16_16(t, u);

#ifndef __msvc_cl__
#pragma unroll(8)
#endif
  for (int i = 0; i < 8; i++) {
    _mm256_storeu_si256(
      reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst),
      _mm512_extracti32x8_epi32(u[i], 0x0));
    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(dst + (i * 2 + 1) * ld_dst),
        _mm512_extracti32x8_epi32(u[i], 0x01));
  }
}
1164 
_transpose_mxn_half_32_32(__m512i r[],__m512i d[])1165 static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) {
1166   // t[0]: 0 32 1 33 2 34 3 35 8 40 9 41 10 42 11 43 16 ... 59
1167   // t[1]: 4 36 5 37 6 38 7 39 12 44 13 45 14 46 15 47 20 ... 63
1168   // t[2]: 64 96 65 97 66 98 67 99 72 104 73 105 74 106 75 ... 123
1169   // t[3]: 68 100 69 101 70 102 71 103 76 108 77 109 78 110 79 111 84 ... 127
1170   // t[4]: 128 160 129 161 130 162 131 163 136 168 137 169 138 170 139 171 144 ... 187
1171   // t[5]: 132 164 133 165 134 166 135 167 140 172 141 173 142 174 143 175 148 ... 191
1172   // t[6]: 192 224 193 225 194 226 195 227 200 232 201 233 202 234 203 235 208 ... 251
1173   // t[7]: 196 228 197 229 198 230 199 231 204 236 205 237 206 238 207 239 212 ... 255
1174   // t[8]: 256 288 257 289 258 290 259 291 264 296 265 297 266 298 267 299 272 ... 315
1175   // t[9]: 260 292 261 293 262 294 263 295 268 300 269 301 270 302 271 303 276 ... 319
1176   // t[10]: 320 352 321 353 322 354 323 355 328 360 329 361 330 362 331 363 336 ... 379
1177   // t[11]: 324 356 325 357 326 358 327 359 332 364 333 365 334 366 335 367 340 ... 383
1178   // t[12]: 384 416 385 417 386 418 387 419 392 424 393 425 394 426 395 427 400 ... 443
1179   // t[13]: 388 420 389 421 390 422 391 423 396 428 397 429 398 430 399 431 404 ... 447
1180   // t[14]: 448 480 449 481 450 482 451 483 456 488 457 489 458 490 459 491 464 ... 507
1181   // t[15]: 452 484 453 485 454 486 455 487 460 492 461 493 462 494 463 495 468 ... 511
1182   // t[16]: 512 544 513 545 514 546 515 547 520 552 521 553 522 554 523 555 528 ... 571
1183   // ...
1184   // t[31]: 964 996 965 997 966 998 967 999 972 1004 973 1005 974 1006 975 1007 980 ... 1023
1185 #ifndef __msvc_cl__
1186 #pragma unroll(16)
1187 #endif
1188   for (int i = 0; i < 16; ++i) {
1189     d[i * 2] = _mm512_unpacklo_epi16(r[i * 2], r[i * 2 + 1]);
1190     d[i * 2 + 1] = _mm512_unpackhi_epi16(r[i * 2], r[i * 2 + 1]);
1191   }
1192 
1193   // t[0]: 0 32 64 96 1 33 65 97 8 40 72 104 9 41 73 105 16 ... 121
1194   // t[1]: 2 34 66 98 3 35 67 99 10 42 74 106 11 43 75 107 18 ... 123
1195   // t[2]: 4 36 68 100 5 37 69 101 12 44 76 108 13 45 77 109 20 ... 125
1196   // t[3]: 6 38 70 102 7 39 71 103 14 46 78 110 15 47 79 111 22 ... 127
1197   // t[4]: 128 160 192 224 129 161 193 225 136 168 200 232 137 169 201 233 144 ... 249
1198   // t[5]: 130 162 194 226 131 163 195 227 138 170 202 234 139 171 203 235 146 ... 251
1199   // t[6]: 132 164 196 228 133 165 197 229 140 172 204 236 141 173 205 237 148 ... 253
1200   // t[7]: 134 166 198 230 135 167 199 231 142 174 206 238 143 175 207 239 150 ... 255
1201   // t[8]: 256 288 320 352 257 289 321 353 264 296 328 360 265 297 329 361 272 ... 377
1202   // t[9]: 258 290 322 354 259 291 323 355 266 298 330 362 267 299 331 363 274 ... 379
1203   // t[10]: 260 292 324 356 261 293 325 357 268 300 332 364 269 301 333 365 276 ... 381
1204   // t[11]: 262 294 326 358 263 295 327 359 270 302 334 366 271 303 335 367 278 ... 383
1205   // t[12]: 384 416 448 480 385 417 449 481 392 424 456 488 393 425 457 489 400 ... 505
1206   // t[13]: 386 418 450 482 387 419 451 483 394 426 458 490 395 427 459 491 402 ... 507
1207   // t[14]: 388 420 452 484 389 421 453 485 396 428 460 492 397 429 461 493 404 ... 509
1208   // t[15]: 390 422 454 486 391 423 455 487 398 430 462 494 399 431 463 495 406 ... 511
1209   // t[16]: 512 544 576 608 513 545 577 609 520 552 584 616 521 553 585 617 528 ... 633
1210   // ...
1211   // t[31]: 902 934 966 998 903 935 967 999 910 942 974 1006 911 943 975 1007 918 ... 1023
1212 #ifndef __msvc_cl__
1213 #pragma unroll(8)
1214 #endif
1215   for (int i = 0; i < 8; ++i) {
1216     r[i * 4] = _mm512_unpacklo_epi32(d[i * 4], d[i * 4 + 2]);
1217     r[i * 4 + 1] = _mm512_unpackhi_epi32(d[i * 4], d[i * 4 + 2]);
1218     r[i * 4 + 2] = _mm512_unpacklo_epi32(d[i * 4 + 1], d[i * 4 + 3]);
1219     r[i * 4 + 3] = _mm512_unpackhi_epi32(d[i * 4 + 1], d[i * 4 + 3]);
1220   }
1221 
1222   // t[0]: 0 32 64 96 128 160 192 224 8 40 72 104 136 168 200 232 16 ... 248
1223   // t[1]: 1 33 65 97 129 161 193 225 9 41 73 105 137 169 201 233 17 ... 249
1224   // t[2]: 2 34 66 98 130 162 194 226 10 42 74 106 138 170 202 234 18 ... 250
1225   // t[3]: 3 35 67 99 131 163 195 227 11 43 75 107 139 171 203 235 19 ... 251
1226   // t[4]: 4 36 68 100 132 164 196 228 12 44 76 108 140 172 204 236 20 ... 252
1227   // t[5]: 5 37 69 101 133 165 197 229 13 45 77 109 141 173 205 237 21 ... 253
1228   // t[6]: 6 38 70 102 134 166 198 230 14 46 78 110 142 174 206 238 22 ... 254
1229   // t[7]: 7 39 71 103 135 167 199 231 15 47 79 111 143 175 207 239 23 ... 255
1230   // t[8]: 256 288 320 352 384 416 448 480 264 296 328 360 392 424 456 488 272 ... 504
1231   // t[9]: 257 289 321 353 385 417 449 481 265 297 329 361 393 425 457 489 273 ... 505
1232   // t[10]: 258 290 322 354 386 418 450 482 266 298 330 362 394 426 458 490 274 ... 506
1233   // t[11]: 259 291 323 355 387 419 451 483 267 299 331 363 395 427 459 491 275 ... 507
1234   // t[12]: 260 292 324 356 388 420 452 484 268 300 332 364 396 428 460 492 276 ... 508
1235   // t[13]: 261 293 325 357 389 421 453 485 269 301 333 365 397 429 461 493 277 ... 509
1236   // t[14]: 262 294 326 358 390 422 454 486 270 302 334 366 398 430 462 494 278 ... 510
1237   // t[15]: 263 295 327 359 391 423 455 487 271 303 335 367 399 431 463 495 279 ... 511
1238   // t[16]: 512 544 576 608 640 672 704 736 520 552 584 616 648 680 712 744 528 ... 760
1239   // ...
1240   // t[31]: 775 807 839 871 903 935 967 999 783 815 847 879 911 943 975 1007 791 ... 1023
1241 #ifndef __msvc_cl__
1242 #pragma unroll(4)
1243 #endif
1244   for (int i = 0; i < 4; ++i) {
    d[i * 8] = _mm512_unpacklo_epi64(r[i * 8], r[i * 8 + 4]);
    d[i * 8 + 1] = _mm512_unpackhi_epi64(r[i * 8], r[i * 8 + 4]);
    d[i * 8 + 2] = _mm512_unpacklo_epi64(r[i * 8 + 1], r[i * 8 + 5]);
    d[i * 8 + 3] = _mm512_unpackhi_epi64(r[i * 8 + 1], r[i * 8 + 5]);
    d[i * 8 + 4] = _mm512_unpacklo_epi64(r[i * 8 + 2], r[i * 8 + 6]);
    d[i * 8 + 5] = _mm512_unpackhi_epi64(r[i * 8 + 2], r[i * 8 + 6]);
    d[i * 8 + 6] = _mm512_unpacklo_epi64(r[i * 8 + 3], r[i * 8 + 7]);
    d[i * 8 + 7] = _mm512_unpackhi_epi64(r[i * 8 + 3], r[i * 8 + 7]);
  }

  // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 16 ... 496
  // t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 17 ... 497
  // t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482 18 ... 498
  // t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483 19 ... 499
  // t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484 20 ... 500
  // t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485 21 ... 501
  // t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486 22 ... 502
  // t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 487 23 ... 503
  // t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 456 488 24 ... 504
  // t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489 25 ... 505
  // t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 490 26 ... 506
  // t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 395 427 459 491 27 ... 507
  // t[12]: 12 44 76 108 140 172 204 236 268 300 332 364 396 428 460 492 28 ... 508
  // t[13]: 13 45 77 109 141 173 205 237 269 301 333 365 397 429 461 493 29 ... 509
  // t[14]: 14 46 78 110 142 174 206 238 270 302 334 366 398 430 462 494 30 ... 510
  // t[15]: 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495 31 ... 511
  // t[16]: 512 544 576 608 640 672 704 736 768 800 832 864 896 928 960 992 528 ... 1008
  // ...
  // t[31]: 527 559 591 623 655 687 719 751 783 815 847 879 911 943 975 1007 543 ... 1023
  __m512i const1 = _mm512_set_epi64(
      0x000000000000000d,
      0x000000000000000c,
      0x0000000000000005,
      0x0000000000000004,
      0x0000000000000009,
      0x0000000000000008,
      0x0000000000000001,
      0x0000000000000000);
  __m512i const2 = _mm512_set_epi64(
      0x000000000000000f,
      0x000000000000000e,
      0x0000000000000007,
      0x0000000000000006,
      0x000000000000000b,
      0x000000000000000a,
      0x0000000000000003,
      0x0000000000000002);
#ifndef __msvc_cl__
#pragma unroll(8)
#endif
  for (int i = 0; i < 8; ++i) {
    r[i] = _mm512_permutex2var_epi64(d[i], /*idx*/const1, d[i + 8]);
    r[i + 8] = _mm512_permutex2var_epi64(d[i], /*idx*/const2, d[i + 8]);
    r[i + 16] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/const1, d[i + 24]);
    r[i + 24] = _mm512_permutex2var_epi64(d[i + 16], /*idx*/const2, d[i + 24]);
  }

  // t[0]: 0 32 64 96 128 160 192 224 256 288 320 352 384 416 448 480 512 544 ... 992
  // t[1]: 1 33 65 97 129 161 193 225 257 289 321 353 385 417 449 481 513 545 ... 993
  // t[2]: 2 34 66 98 130 162 194 226 258 290 322 354 386 418 450 482 514 546 ... 994
  // t[3]: 3 35 67 99 131 163 195 227 259 291 323 355 387 419 451 483 515 547 ... 995
  // t[4]: 4 36 68 100 132 164 196 228 260 292 324 356 388 420 452 484 516 548 ... 996
  // t[5]: 5 37 69 101 133 165 197 229 261 293 325 357 389 421 453 485 517 549 ... 997
  // t[6]: 6 38 70 102 134 166 198 230 262 294 326 358 390 422 454 486 518 550 ... 998
  // t[7]: 7 39 71 103 135 167 199 231 263 295 327 359 391 423 455 487 519 551 ... 999
  // t[8]: 8 40 72 104 136 168 200 232 264 296 328 360 392 424 456 488 520 552 ... 1000
  // t[9]: 9 41 73 105 137 169 201 233 265 297 329 361 393 425 457 489 521 553 ... 1001
  // t[10]: 10 42 74 106 138 170 202 234 266 298 330 362 394 426 458 490 522 554 ... 1002
  // t[11]: 11 43 75 107 139 171 203 235 267 299 331 363 395 427 459 491 523 555 ... 1003
  // t[12]: 12 44 76 108 140 172 204 236 268 300 332 364 396 428 460 492 524 556 ... 1004
  // t[13]: 13 45 77 109 141 173 205 237 269 301 333 365 397 429 461 493 525 557 ... 1005
  // t[14]: 14 46 78 110 142 174 206 238 270 302 334 366 398 430 462 494 526 558 ... 1006
  // t[15]: 15 47 79 111 143 175 207 239 271 303 335 367 399 431 463 495 527 559 ... 1007
  // t[16]: 16 48 80 112 144 176 208 240 272 304 336 368 400 432 464 496 528 560 ... 1008
  // ...
  // t[31]: 31 63 95 127 159 191 223 255 287 319 351 383 415 447 479 511 543 575 ... 1023
  __m512i const3 = _mm512_set_epi64(
      0x000000000000000b,
      0x000000000000000a,
      0x0000000000000009,
      0x0000000000000008,
      0x0000000000000003,
      0x0000000000000002,
      0x0000000000000001,
      0x0000000000000000);
  __m512i const4 = _mm512_set_epi64(
      0x000000000000000f,
      0x000000000000000e,
      0x000000000000000d,
      0x000000000000000c,
      0x0000000000000007,
      0x0000000000000006,
      0x0000000000000005,
      0x0000000000000004);
#ifndef __msvc_cl__
#pragma unroll(16)
#endif
  for (int i = 0; i < 16; ++i) {
    d[i] = _mm512_permutex2var_epi64(r[i], /*idx*/const3, r[i + 16]);
    d[i + 16] = _mm512_permutex2var_epi64(r[i], /*idx*/const4, r[i + 16]);
  }
}
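
// For reference, the net effect of the shuffle network above is a plain
// 32x32 transpose of 16-bit lanes. A scalar sketch one could test against
// (hypothetical helper, not part of this header):
//
//   void transpose_32x32_ref(const uint16_t* in, uint16_t* out) {
//     for (int i = 0; i < 32; ++i) {
//       for (int j = 0; j < 32; ++j) {
//         out[j * 32 + i] = in[i * 32 + j];  // out[j][i] = in[i][j]
//       }
//     }
//   }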

// Adapted from FBGEMM:
// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#LL19C6-L19C6
template<>
inline void transpose_mxn<BFloat16>(const BFloat16* src, int64_t ld_src, BFloat16* dst, int64_t ld_dst, int M, int N) {
  TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn<BFloat16> expects M, N <= 32.");
  // load from src
  __m512i r[32];
  int i;
  if (N == 32) {
    for (i = 0; i < M; ++i) {
      r[i] = _mm512_loadu_si512(&src[i * ld_src]);
    }
  } else {
    // Unsigned shift avoids signed overflow when N == 31.
    __mmask32 src_mask = (1U << N) - 1;
    for (i = 0; i < M; ++i) {
      r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]);
    }
  }
  for (; i < 32; ++i) {
    r[i] = _mm512_setzero_si512();
  }

  __m512i d[32];
  _transpose_mxn_half_32_32(r, d);

  // store to dst
  if (M == 32) {
    for (i = 0; i < N; ++i) {
      _mm512_storeu_si512(&dst[i * ld_dst], d[i]);
    }
  } else {
    __mmask32 dst_mask = (1U << M) - 1;
    for (i = 0; i < N; ++i) {
      _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]);
    }
  }
}
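
// Usage sketch (hypothetical buffers, under the M, N <= 32 constraint):
//
//   at::BFloat16 src[24 * 32];  // 24 rows x 32 cols, row-major
//   at::BFloat16 dst[32 * 24];  // receives the 32 x 24 transpose
//   at::vec::transpose_mxn<at::BFloat16>(src, /*ld_src=*/32, dst, /*ld_dst=*/24,
//                                        /*M=*/24, /*N=*/32);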

template <typename T, int M, int N,
          typename std::enable_if_t<std::is_same<T, BFloat16>::value && ((M <= 32 && M != 16) || (N <= 32 && N != 16)), int> = 0>
inline void transpose_mxn(const BFloat16* src, int64_t ld_src, BFloat16* dst, int64_t ld_dst) {
  transpose_mxn<BFloat16>(src, ld_src, dst, ld_dst, M, N);
}

template<>
inline void transpose_mxn<Half>(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst, int M, int N) {
  TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn<Half> expects M, N <= 32.");
  // load from src
  __m512i r[32];
  int i;
  if (N == 32) {
    for (i = 0; i < M; ++i) {
      r[i] = _mm512_loadu_si512(&src[i * ld_src]);
    }
  } else {
    // Unsigned shift avoids signed overflow when N == 31.
    __mmask32 src_mask = (1U << N) - 1;
    for (i = 0; i < M; ++i) {
      r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]);
    }
  }
  for (; i < 32; ++i) {
    r[i] = _mm512_setzero_si512();
  }

  __m512i d[32];
  _transpose_mxn_half_32_32(r, d);

  // store to dst
  if (M == 32) {
    for (i = 0; i < N; ++i) {
      _mm512_storeu_si512(&dst[i * ld_dst], d[i]);
    }
  } else {
    __mmask32 dst_mask = (1U << M) - 1;
    for (i = 0; i < N; ++i) {
      _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]);
    }
  }
}

template <typename T, int M, int N,
          typename std::enable_if_t<std::is_same<T, Half>::value && ((M <= 32 && M != 16) || (N <= 32 && N != 16)), int> = 0>
inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst) {
  transpose_mxn<Half>(src, ld_src, dst, ld_dst, M, N);
}

template <>
class Vectorized<Half>: public Vectorized16<Half> {
public:
  using Vectorized16::Vectorized16;

  Vectorized<Half> frac() const;

  Vectorized<Half> eq(const Vectorized<Half>& other) const;
  Vectorized<Half> ne(const Vectorized<Half>& other) const;
  Vectorized<Half> gt(const Vectorized<Half>& other) const;
  Vectorized<Half> ge(const Vectorized<Half>& other) const;
  Vectorized<Half> lt(const Vectorized<Half>& other) const;
  Vectorized<Half> le(const Vectorized<Half>& other) const;
};

Vectorized<Half> inline operator+(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_add_ps(x, y); });
}
Vectorized<Half> inline operator-(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_sub_ps(x, y); });
}
Vectorized<Half> inline operator*(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_mul_ps(x, y); });
}
Vectorized<Half> inline operator/(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_div_ps(x, y); });
}
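
// binary_op_as_fp32 widens each Half lane to fp32, applies the op, and rounds
// back to Half once. A per-lane scalar sketch of that contract (illustrative
// helper, not part of this header):
//
//   Half add_as_fp32(Half a, Half b) {
//     return Half(static_cast<float>(a) + static_cast<float>(b));
//   }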

Vectorized<Half> inline operator&(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return _mm512_and_si512(a, b);
}
Vectorized<Half> inline operator|(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return _mm512_or_si512(a, b);
}
Vectorized<Half> inline operator^(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  return _mm512_xor_si512(a, b);
}

inline Vectorized<Half> Vectorized<Half>::eq(const Vectorized<Half>& other) const {
  return (*this == other) & Vectorized<Half>(1.0f);
}

inline Vectorized<Half> Vectorized<Half>::ne(const Vectorized<Half>& other) const {
  return (*this != other) & Vectorized<Half>(1.0f);
}

inline Vectorized<Half> Vectorized<Half>::gt(const Vectorized<Half>& other) const {
  return (*this > other) & Vectorized<Half>(1.0f);
}

inline Vectorized<Half> Vectorized<Half>::ge(const Vectorized<Half>& other) const {
  return (*this >= other) & Vectorized<Half>(1.0f);
}

inline Vectorized<Half> Vectorized<Half>::lt(const Vectorized<Half>& other) const {
  return (*this < other) & Vectorized<Half>(1.0f);
}

inline Vectorized<Half> Vectorized<Half>::le(const Vectorized<Half>& other) const {
  return (*this <= other) & Vectorized<Half>(1.0f);
}
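
// The comparison operators yield an all-ones bit pattern in each true lane;
// ANDing with 1.0f turns that into a numeric 0.0/1.0 result. Illustrative
// values:
//
//   Vectorized<Half>(2.0f).gt(Vectorized<Half>(1.0f));  // 1.0 in every lane
//   Vectorized<Half>(2.0f) > Vectorized<Half>(1.0f);    // all-ones bits per lane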

// frac. Implemented here so it can reuse the subtraction defined above.
inline Vectorized<Half> Vectorized<Half>::frac() const {
  return *this - this->trunc();
}
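
// For example, frac(2.7) == 0.7 and frac(-2.7) == -0.7, since trunc() rounds
// toward zero.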

// Implements the IEEE 754-2019 `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<Half> inline maximum(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512i zero_vec = _mm512_set1_epi32(0);
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(b), b_lo, b_hi);
  auto max_lo = _mm512_max_ps(a_lo, b_lo);
  auto max_hi = _mm512_max_ps(a_hi, b_hi);
  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
  // Expand each mask bit into a full all-ones lane, matching minimum() below.
  // (_mm512_set1_epi32(nan_lo_mask) would instead broadcast the mask bits into
  // every lane, corrupting non-NaN lanes.)
  auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
                                                           0xFFFFFFFF));
  auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
                                                           0xFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  auto o1 = _mm512_or_ps(max_lo, nan_lo);
  auto o2 = _mm512_or_ps(max_hi, nan_hi);
  return cvtfp32_fp16(o1, o2);
}
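
// Illustrative NaN behavior: maximum(Vectorized<Half>(NAN), Vectorized<Half>(1.0f))
// yields NaN in every lane, whereas bare _mm512_max_ps returns its second
// operand for unordered lanes and would silently drop the NaN.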

// Implements the IEEE 754-2019 `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<Half> inline minimum(const Vectorized<Half>& a, const Vectorized<Half>& b) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512i zero_vec = _mm512_set1_epi32(0);
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(b), b_lo, b_hi);
  auto min_lo = _mm512_min_ps(a_lo, b_lo);
  auto min_hi = _mm512_min_ps(a_hi, b_hi);
  auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q);
  auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q);
  auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask,
                                                           0xFFFFFFFF));
  auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask,
                                                           0xFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  auto o1 = _mm512_or_ps(min_lo, nan_lo);
  auto o2 = _mm512_or_ps(min_hi, nan_hi);
  return cvtfp32_fp16(o1, o2);
}

template <>
Vectorized<Half> inline clamp(const Vectorized<Half>& a,
    const Vectorized<Half>& min, const Vectorized<Half>& max) {
  __m512 a_lo, a_hi;
  __m512 min_lo, min_hi;
  __m512 max_lo, max_hi;
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(min), min_lo, min_hi);
  cvtfp16_fp32(__m512i(max), max_lo, max_hi);
  auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo));
  auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi));
  return cvtfp32_fp16(o1, o2);
}

template <>
Vectorized<Half> inline clamp_max(const Vectorized<Half>& a, const Vectorized<Half>& max) {
  __m512 a_lo, a_hi;
  __m512 max_lo, max_hi;
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(max), max_lo, max_hi);
  auto o1 = _mm512_min_ps(max_lo, a_lo);
  auto o2 = _mm512_min_ps(max_hi, a_hi);
  return cvtfp32_fp16(o1, o2);
}

template <>
Vectorized<Half> inline clamp_min(const Vectorized<Half>& a, const Vectorized<Half>& min) {
  __m512 a_lo, a_hi;
  __m512 min_lo, min_hi;
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(min), min_lo, min_hi);
  auto o1 = _mm512_max_ps(min_lo, a_lo);
  auto o2 = _mm512_max_ps(min_hi, a_hi);
  return cvtfp32_fp16(o1, o2);
}

template <>
inline void convert(const Half* src, Half* dst, int64_t n) {
  int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<Half>::size()); i += Vectorized<Half>::size()) {
    auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i)));
    _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc);
  }
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = src[i];
  }
}

template <>
inline void convert(const float* src, Half* dst, int64_t n) {
  int64_t i;
  for (i = 0; i + Vectorized<Half>::size() <= n; i += Vectorized<Half>::size()) {
    __m512 a = _mm512_loadu_ps(&src[i]);
    __m512 b = _mm512_loadu_ps(&src[i + 16]);

    __m512i fp16 = cvtfp32_fp16(a, b);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), fp16);
  }
  for (; i < n; i++) {
    dst[i] = c10::convert<Half>(src[i]);
  }
}

template <>
inline void convert(const double* src, Half* dst, int64_t n) {
  auto load_float = [](const double *src) -> __m512 {
    // Load one float vector from an array of doubles
    __m256 a = _mm512_cvtpd_ps(_mm512_loadu_pd(src));
    __m256 b = _mm512_cvtpd_ps(_mm512_loadu_pd(src + 8));
    return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
  };

  int64_t i;
  for (i = 0; i + Vectorized<Half>::size() <= n; i += Vectorized<Half>::size()) {
    __m512 a = load_float(&src[i]);
    __m512 b = load_float(&src[i + 16]);

    __m512i fp16 = cvtfp32_fp16(a, b);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(&dst[i]), fp16);
  }
  for (; i < n; i++) {
    dst[i] = c10::convert<Half>(src[i]);
  }
}

template <>
Vectorized<Half> inline fmadd(const Vectorized<Half>& a,
    const Vectorized<Half>& b, const Vectorized<Half>& c) {
  __m512 a_lo, a_hi;
  __m512 b_lo, b_hi;
  __m512 c_lo, c_hi;
  cvtfp16_fp32(__m512i(a), a_lo, a_hi);
  cvtfp16_fp32(__m512i(b), b_lo, b_hi);
  cvtfp16_fp32(__m512i(c), c_lo, c_hi);
  auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo);
  auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi);
  return cvtfp32_fp16(o1, o2);
}
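
// Note: the multiply-add is fused in fp32 with a single rounding to Half at
// the end, so per-lane results can differ slightly from computing a * b and
// then + c as two separately rounded Half operations.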

#define CONVERT_VECTORIZED_INIT(type, name) \
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
  __m512 o1, o2; \
  cvt_to_fp32<type>(__m512i(a), o1, o2); \
  return std::make_tuple(o1, o2); \
} \
\
inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
  return cvt_from_fp32<type>(__m512(a), __m512(b)); \
}
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_VECTORIZED_INIT(Half, half);
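
// Usage sketch for the generated helpers (hypothetical values):
//
//   Vectorized<BFloat16> v(1.5f);
//   Vectorized<float> lo, hi;
//   std::tie(lo, hi) = convert_bfloat16_float(v);  // widen to 2x16 floats
//   Vectorized<BFloat16> back = convert_float_bfloat16(lo, hi);  // narrow back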

#else // defined(CPU_CAPABILITY_AVX512)

#define CONVERT_NON_VECTORIZED_INIT(type, name) \
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
  constexpr int64_t K = Vectorized<type>::size(); \
  __at_align__ float arr[K]; \
  __at_align__ type arr2[K]; \
  a.store(arr2); \
  for (const auto k : c10::irange(K)) { \
    arr[k] = c10::convert<float>(arr2[k]); \
  } \
  return std::make_tuple( \
      Vectorized<float>::loadu(arr), \
      Vectorized<float>::loadu(arr + Vectorized<float>::size())); \
} \
\
inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const Vectorized<float>& b) { \
  constexpr int64_t K = Vectorized<type>::size(); \
  __at_align__ float arr[K]; \
  __at_align__ type arr2[K]; \
  a.store(arr); \
  b.store(arr + Vectorized<float>::size()); \
  for (const auto k : c10::irange(K)) { \
    arr2[k] = c10::convert<type>(arr[k]); \
  } \
  return Vectorized<type>::loadu(arr2); \
}
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_NON_VECTORIZED_INIT(Half, half);

#endif // defined(CPU_CAPABILITY_AVX512)

#if defined(CPU_CAPABILITY_AVX512)
#define LOAD_FP32_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
  auto values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data)); \
  __m512 out_values; \
  cvt_to_fp32<type>(values, out_values); \
  out = out_values; \
} \
\
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vectorized<float>& out2) { \
  auto vec = Vectorized<type>::loadu(data); \
  __m512 out1_values, out2_values; \
  cvt_to_fp32<type>(vec, out1_values, out2_values); \
  out1 = out1_values; \
  out2 = out2_values; \
}
LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16);
LOAD_FP32_VECTORIZED_INIT(Half, fp16);
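
// Usage sketch (hypothetical buffer): the one-output overload reads 16
// elements; the two-output overload reads a full Vectorized<BFloat16> (32).
//
//   __at_align__ BFloat16 buf[32] = {};
//   Vectorized<float> f0, f1;
//   load_fp32_from_bf16(buf, f0, f1);  // widen 32 bf16 values to 2x16 floats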

#else // defined(CPU_CAPABILITY_AVX512)
#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
  __at_align__ float values[Vectorized<float>::size()]; \
  for (const auto k : c10::irange(Vectorized<float>::size())) { \
    values[k] = data[k]; \
  } \
  out = Vectorized<float>::loadu(values); \
} \
\
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vectorized<float>& out2) { \
  load_fp32_from_##name(data, out1); \
  data += Vectorized<float>::size(); \
  load_fp32_from_##name(data, out2); \
}
LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16);
LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16);

#endif // defined(CPU_CAPABILITY_AVX512)

} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace at