// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/dwconv.h>
#include <xnnpack/gemm.h>
#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/lut.h>
#include <xnnpack/math.h>
#include <xnnpack/vaddsub.h>
#include <xnnpack/vcvt.h>

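// AVX512-SKX micro-kernels: F16<->F32 vector conversions, F32->QS8 and
// F32->QU8 quantizing conversions, and channelwise-quantized (QC8)
// depthwise-convolution kernels with fp32 requantization.
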
void xnn_f16_f32_vcvt_ukernel__avx512skx_x16(
    size_t n,
    const void* input,
    float* output,
    const union xnn_f16_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n != 0);
  assert(n % sizeof(uint16_t) == 0);
  assert(input != NULL);
  assert(output != NULL);

  const uint16_t* i = (const uint16_t*) input;
  for (; n >= 16 * sizeof(uint16_t); n -= 16 * sizeof(uint16_t)) {
    const __m512 vacc = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*) i));
    i += 16;

    _mm512_storeu_ps(output, vacc);
    output += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(uint16_t));
    assert(n <= 15 * sizeof(uint16_t));

    // Prepare mask for valid 32-bit elements (depends on n).
    n >>= 1 /* log2(sizeof(uint16_t)) */;
    const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
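    // For example, 5 leftover halfwords give n == 5 after the shift above,
    // so vmask == 0b0000000000011111 and only 5 lanes are loaded and stored.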

    const __m512 vacc = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(vmask, i));

    _mm512_mask_storeu_ps(output, vmask, vacc);
  }
}

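// Converts F32 to F16 with VCVTPS2PH; the tail uses the same masking trick as
// the F16->F32 kernel above, shifting n by log2(sizeof(float)) instead.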
void xnn_f32_f16_vcvt_ukernel__avx512skx_x16(
    size_t n,
    const float* input,
    void* output,
    const union xnn_f32_f16_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(input != NULL);
  assert(output != NULL);

  uint16_t* o = (uint16_t*) output;
  for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
    const __m512 vf = _mm512_loadu_ps(input);
    input += 16;

    _mm256_storeu_si256((__m256i*) o, _mm512_cvtps_ph(vf, _MM_FROUND_NO_EXC));
    o += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 15 * sizeof(float));

    // Prepare mask for valid elements (depends on n).
    n >>= 2 /* log2(sizeof(float)) */;
    const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));

    const __m512 vf = _mm512_maskz_loadu_ps(vmask, input);
    const __m256i vh = _mm512_cvtps_ph(vf, _MM_FROUND_NO_EXC);
    _mm256_mask_storeu_epi16(o, vmask, vh);
  }
}

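// Quantizes F32 to signed 8-bit: multiply by the scale, clamp against
// (output_max - zero_point) in float, convert to int32, add the zero point
// while saturating-packing down to int16, pack to int8, then clamp to
// output_min.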
void xnn_f32_qs8_vcvt_ukernel__avx512skx_x128(
    size_t n,
    const float* x,
    int8_t* y,
    const union xnn_f32_qs8_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(x != NULL);
  assert(y != NULL);

  const __m512 vscale = _mm512_load_ps(params->avx512.scale);
  const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->avx512.output_max_less_zero_point);
  const __m512i voutput_zero_point = _mm512_load_si512(params->avx512.output_zero_point);
  const __m512i vshuffle512_mask = _mm512_load_si512(params->avx512.shuffle512_mask);
  const __m512i voutput_min = _mm512_load_si512(params->avx512.output_min);
  for (; n >= 128 * sizeof(float); n -= 128 * sizeof(float)) {
    __m512 vx0123 = _mm512_loadu_ps(x);
    __m512 vx4567 = _mm512_loadu_ps(x + 16);
    __m512 vx89AB = _mm512_loadu_ps(x + 32);
    __m512 vxCDEF = _mm512_loadu_ps(x + 48);
    __m512 vxGHIJ = _mm512_loadu_ps(x + 64);
    __m512 vxKLMN = _mm512_loadu_ps(x + 80);
    __m512 vxOPQR = _mm512_loadu_ps(x + 96);
    __m512 vxSTUV = _mm512_loadu_ps(x + 112);
    x += 128;

    vx0123 = _mm512_mul_ps(vx0123, vscale);
    vx4567 = _mm512_mul_ps(vx4567, vscale);
    vx89AB = _mm512_mul_ps(vx89AB, vscale);
    vxCDEF = _mm512_mul_ps(vxCDEF, vscale);
    vxGHIJ = _mm512_mul_ps(vxGHIJ, vscale);
    vxKLMN = _mm512_mul_ps(vxKLMN, vscale);
    vxOPQR = _mm512_mul_ps(vxOPQR, vscale);
    vxSTUV = _mm512_mul_ps(vxSTUV, vscale);

    vx0123 = _mm512_min_ps(vx0123, voutput_max_less_zero_point);
    vx4567 = _mm512_min_ps(vx4567, voutput_max_less_zero_point);
    vx89AB = _mm512_min_ps(vx89AB, voutput_max_less_zero_point);
    vxCDEF = _mm512_min_ps(vxCDEF, voutput_max_less_zero_point);
    vxGHIJ = _mm512_min_ps(vxGHIJ, voutput_max_less_zero_point);
    vxKLMN = _mm512_min_ps(vxKLMN, voutput_max_less_zero_point);
    vxOPQR = _mm512_min_ps(vxOPQR, voutput_max_less_zero_point);
    vxSTUV = _mm512_min_ps(vxSTUV, voutput_max_less_zero_point);

    const __m512i vacc0123 = _mm512_cvtps_epi32(vx0123);
    const __m512i vacc4567 = _mm512_cvtps_epi32(vx4567);
    const __m512i vacc89AB = _mm512_cvtps_epi32(vx89AB);
    const __m512i vaccCDEF = _mm512_cvtps_epi32(vxCDEF);
    const __m512i vaccGHIJ = _mm512_cvtps_epi32(vxGHIJ);
    const __m512i vaccKLMN = _mm512_cvtps_epi32(vxKLMN);
    const __m512i vaccOPQR = _mm512_cvtps_epi32(vxOPQR);
    const __m512i vaccSTUV = _mm512_cvtps_epi32(vxSTUV);

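    // _mm512_packs_epi32/_mm512_packs_epi16 interleave within 128-bit lanes,
    // which scrambles element order (tracked in the variable names); the
    // final _mm512_permutexvar_epi32 restores sequential order.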
    __m512i vacc04152637 = _mm512_packs_epi32(vacc0123, vacc4567);
    __m512i vacc8C9DAEBF = _mm512_packs_epi32(vacc89AB, vaccCDEF);
    __m512i vaccGKHLIMJN = _mm512_packs_epi32(vaccGHIJ, vaccKLMN);
    __m512i vaccOSPTQURV = _mm512_packs_epi32(vaccOPQR, vaccSTUV);

    vacc04152637 = _mm512_adds_epi16(vacc04152637, voutput_zero_point);
    vacc8C9DAEBF = _mm512_adds_epi16(vacc8C9DAEBF, voutput_zero_point);
    vaccGKHLIMJN = _mm512_adds_epi16(vaccGKHLIMJN, voutput_zero_point);
    vaccOSPTQURV = _mm512_adds_epi16(vaccOSPTQURV, voutput_zero_point);

    __m512i vy048C159D26AE37BF = _mm512_packs_epi16(vacc04152637, vacc8C9DAEBF);
    __m512i vyGKOSHLPTIMQUJNRV = _mm512_packs_epi16(vaccGKHLIMJN, vaccOSPTQURV);

    vy048C159D26AE37BF = _mm512_max_epi8(vy048C159D26AE37BF, voutput_min);
    vyGKOSHLPTIMQUJNRV = _mm512_max_epi8(vyGKOSHLPTIMQUJNRV, voutput_min);

    const __m512i vy0123456789ABCDEF = _mm512_permutexvar_epi32(vshuffle512_mask, vy048C159D26AE37BF);
    const __m512i vyGHIJKLMNOPQRSTUV = _mm512_permutexvar_epi32(vshuffle512_mask, vyGKOSHLPTIMQUJNRV);

    _mm512_storeu_si512(y, vy0123456789ABCDEF);
    _mm512_storeu_si512(y + 64, vyGHIJKLMNOPQRSTUV);
    y += 128;
  }
  for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
    __m512 vx0123 = _mm512_loadu_ps(x);
    vx0123 = _mm512_mul_ps(vx0123, vscale);
    vx0123 = _mm512_min_ps(vx0123, voutput_max_less_zero_point);
    x += 16;

    const __m512i vacc0123 = _mm512_cvtps_epi32(vx0123);

    __m256i vacc0213 = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0123), _mm512_extracti32x8_epi32(vacc0123, 1));
    vacc0213 = _mm256_adds_epi16(vacc0213, _mm512_castsi512_si256(voutput_zero_point));
    const __m128i vy0213 = _mm_packs_epi16(_mm256_castsi256_si128(vacc0213), _mm256_extracti128_si256(vacc0213, 1));
    __m128i vy0123 = _mm_shuffle_epi32(vy0213, _MM_SHUFFLE(3, 1, 2, 0));
    vy0123 = _mm_max_epi8(vy0123, _mm512_castsi512_si128(voutput_min));

    _mm_storeu_si128((__m128i*) y, vy0123);
    y += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 15 * sizeof(float));

    // Prepare mask for valid elements (depends on n).
    n >>= 2 /* log2(sizeof(float)) */;
    const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));

    __m512 vx0123 = _mm512_maskz_loadu_ps(vmask, x);
    vx0123 = _mm512_mul_ps(vx0123, vscale);
    vx0123 = _mm512_min_ps(vx0123, voutput_max_less_zero_point);

    const __m512i vacc0123 = _mm512_cvtps_epi32(vx0123);

    __m256i vacc0213 = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0123), _mm512_extracti32x8_epi32(vacc0123, 1));
    vacc0213 = _mm256_adds_epi16(vacc0213, _mm512_castsi512_si256(voutput_zero_point));
    const __m128i vy0213 = _mm_packs_epi16(_mm256_castsi256_si128(vacc0213), _mm256_extracti128_si256(vacc0213, 1));
    __m128i vy0123 = _mm_shuffle_epi32(vy0213, _MM_SHUFFLE(3, 1, 2, 0));
    vy0123 = _mm_max_epi8(vy0123, _mm512_castsi512_si128(voutput_min));

    _mm_mask_storeu_epi8(y, vmask, vy0123);
  }
}

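// Unsigned variant of the QS8 kernel above: same scaling and packing, but
// with unsigned saturation (_mm512_packus_epi16 / _mm512_max_epu8).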
void xnn_f32_qu8_vcvt_ukernel__avx512skx_x128(
    size_t n,
    const float* x,
    uint8_t* y,
    const union xnn_f32_qu8_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(x != NULL);
  assert(y != NULL);

  const __m512 vscale = _mm512_load_ps(params->avx512.scale);
  const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->avx512.output_max_less_zero_point);
  const __m512i voutput_zero_point = _mm512_load_si512(params->avx512.output_zero_point);
  const __m512i vshuffle512_mask = _mm512_load_si512(params->avx512.shuffle512_mask);
  const __m512i voutput_min = _mm512_load_si512(params->avx512.output_min);
  for (; n >= 128 * sizeof(float); n -= 128 * sizeof(float)) {
    __m512 vx0123 = _mm512_loadu_ps(x);
    __m512 vx4567 = _mm512_loadu_ps(x + 16);
    __m512 vx89AB = _mm512_loadu_ps(x + 32);
    __m512 vxCDEF = _mm512_loadu_ps(x + 48);
    __m512 vxGHIJ = _mm512_loadu_ps(x + 64);
    __m512 vxKLMN = _mm512_loadu_ps(x + 80);
    __m512 vxOPQR = _mm512_loadu_ps(x + 96);
    __m512 vxSTUV = _mm512_loadu_ps(x + 112);
    x += 128;

    vx0123 = _mm512_mul_ps(vx0123, vscale);
    vx4567 = _mm512_mul_ps(vx4567, vscale);
    vx89AB = _mm512_mul_ps(vx89AB, vscale);
    vxCDEF = _mm512_mul_ps(vxCDEF, vscale);
    vxGHIJ = _mm512_mul_ps(vxGHIJ, vscale);
    vxKLMN = _mm512_mul_ps(vxKLMN, vscale);
    vxOPQR = _mm512_mul_ps(vxOPQR, vscale);
    vxSTUV = _mm512_mul_ps(vxSTUV, vscale);

    vx0123 = _mm512_min_ps(vx0123, voutput_max_less_zero_point);
    vx4567 = _mm512_min_ps(vx4567, voutput_max_less_zero_point);
    vx89AB = _mm512_min_ps(vx89AB, voutput_max_less_zero_point);
    vxCDEF = _mm512_min_ps(vxCDEF, voutput_max_less_zero_point);
    vxGHIJ = _mm512_min_ps(vxGHIJ, voutput_max_less_zero_point);
    vxKLMN = _mm512_min_ps(vxKLMN, voutput_max_less_zero_point);
    vxOPQR = _mm512_min_ps(vxOPQR, voutput_max_less_zero_point);
    vxSTUV = _mm512_min_ps(vxSTUV, voutput_max_less_zero_point);

    const __m512i vacc0123 = _mm512_cvtps_epi32(vx0123);
    const __m512i vacc4567 = _mm512_cvtps_epi32(vx4567);
    const __m512i vacc89AB = _mm512_cvtps_epi32(vx89AB);
    const __m512i vaccCDEF = _mm512_cvtps_epi32(vxCDEF);
    const __m512i vaccGHIJ = _mm512_cvtps_epi32(vxGHIJ);
    const __m512i vaccKLMN = _mm512_cvtps_epi32(vxKLMN);
    const __m512i vaccOPQR = _mm512_cvtps_epi32(vxOPQR);
    const __m512i vaccSTUV = _mm512_cvtps_epi32(vxSTUV);

    __m512i vacc04152637 = _mm512_packs_epi32(vacc0123, vacc4567);
    __m512i vacc8C9DAEBF = _mm512_packs_epi32(vacc89AB, vaccCDEF);
    __m512i vaccGKHLIMJN = _mm512_packs_epi32(vaccGHIJ, vaccKLMN);
    __m512i vaccOSPTQURV = _mm512_packs_epi32(vaccOPQR, vaccSTUV);

    vacc04152637 = _mm512_adds_epi16(vacc04152637, voutput_zero_point);
    vacc8C9DAEBF = _mm512_adds_epi16(vacc8C9DAEBF, voutput_zero_point);
    vaccGKHLIMJN = _mm512_adds_epi16(vaccGKHLIMJN, voutput_zero_point);
    vaccOSPTQURV = _mm512_adds_epi16(vaccOSPTQURV, voutput_zero_point);

    __m512i vy048C159D26AE37BF = _mm512_packus_epi16(vacc04152637, vacc8C9DAEBF);
    __m512i vyGKOSHLPTIMQUJNRV = _mm512_packus_epi16(vaccGKHLIMJN, vaccOSPTQURV);

    vy048C159D26AE37BF = _mm512_max_epu8(vy048C159D26AE37BF, voutput_min);
    vyGKOSHLPTIMQUJNRV = _mm512_max_epu8(vyGKOSHLPTIMQUJNRV, voutput_min);

    const __m512i vy0123456789ABCDEF = _mm512_permutexvar_epi32(vshuffle512_mask, vy048C159D26AE37BF);
    const __m512i vyGHIJKLMNOPQRSTUV = _mm512_permutexvar_epi32(vshuffle512_mask, vyGKOSHLPTIMQUJNRV);

    _mm512_storeu_si512(y, vy0123456789ABCDEF);
    _mm512_storeu_si512(y + 64, vyGHIJKLMNOPQRSTUV);
    y += 128;
  }
  for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
    __m512 vx0123 = _mm512_loadu_ps(x);
    vx0123 = _mm512_mul_ps(vx0123, vscale);
    vx0123 = _mm512_min_ps(vx0123, voutput_max_less_zero_point);
    x += 16;

    const __m512i vacc0123 = _mm512_cvtps_epi32(vx0123);

    __m256i vacc0213 = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0123), _mm512_extracti32x8_epi32(vacc0123, 1));
    vacc0213 = _mm256_adds_epi16(vacc0213, _mm512_castsi512_si256(voutput_zero_point));
    const __m128i vy0213 = _mm_packus_epi16(_mm256_castsi256_si128(vacc0213), _mm256_extracti128_si256(vacc0213, 1));
    __m128i vy0123 = _mm_shuffle_epi32(vy0213, _MM_SHUFFLE(3, 1, 2, 0));
    vy0123 = _mm_max_epu8(vy0123, _mm512_castsi512_si128(voutput_min));

    _mm_storeu_si128((__m128i*) y, vy0123);
    y += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 15 * sizeof(float));

    // Prepare mask for valid elements (depends on n).
    n >>= 2 /* log2(sizeof(float)) */;
    const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));

    __m512 vx0123 = _mm512_maskz_loadu_ps(vmask, x);
    vx0123 = _mm512_mul_ps(vx0123, vscale);
    vx0123 = _mm512_min_ps(vx0123, voutput_max_less_zero_point);

    const __m512i vacc0123 = _mm512_cvtps_epi32(vx0123);

    __m256i vacc0213 = _mm256_packs_epi32(_mm512_castsi512_si256(vacc0123), _mm512_extracti32x8_epi32(vacc0123, 1));
    vacc0213 = _mm256_adds_epi16(vacc0213, _mm512_castsi512_si256(voutput_zero_point));
    const __m128i vy0213 = _mm_packus_epi16(_mm256_castsi256_si128(vacc0213), _mm256_extracti128_si256(vacc0213, 1));
    __m128i vy0123 = _mm_shuffle_epi32(vy0213, _MM_SHUFFLE(3, 1, 2, 0));
    vy0123 = _mm_max_epu8(vy0123, _mm512_castsi512_si128(voutput_min));

    _mm_mask_storeu_epi8(y, vmask, vy0123);
  }
}

void xnn_qc8_dwconv_minmax_fp32_ukernel_up32x25__avx512skx_mul32(
    size_t channels,
    size_t output_width,
    const int8_t** input,
    const void* weights,
    int8_t* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const int8_t* zero,
    const union xnn_qs8_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
{
  assert(channels != 0);
  assert(output_width != 0);

  const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->avx512.output_max_less_zero_point);
  const __m512i voutput_zero_point = _mm512_load_si512(params->avx512.output_zero_point);
  const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->avx512.output_min);
  const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);

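  // Each outer iteration computes one output pixel from 25 input rows (one
  // per kernel tap, e.g. a 5x5 window). Rows equal to `zero` point at a
  // shared zero buffer and skip the input_offset adjustment.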
  do {
    const int8_t* i0 = input[0];
    assert(i0 != NULL);
    if XNN_UNPREDICTABLE(i0 != zero) {
      i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
    }
    const int8_t* i1 = input[1];
    assert(i1 != NULL);
    if XNN_UNPREDICTABLE(i1 != zero) {
      i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
    }
    const int8_t* i2 = input[2];
    assert(i2 != NULL);
    if XNN_UNPREDICTABLE(i2 != zero) {
      i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
    }
    const int8_t* i3 = input[3];
    assert(i3 != NULL);
    if XNN_UNPREDICTABLE(i3 != zero) {
      i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
    }
    const int8_t* i4 = input[4];
    assert(i4 != NULL);
    if XNN_UNPREDICTABLE(i4 != zero) {
      i4 = (const int8_t*) ((uintptr_t) i4 + input_offset);
    }
    const int8_t* i5 = input[5];
    assert(i5 != NULL);
    if XNN_UNPREDICTABLE(i5 != zero) {
      i5 = (const int8_t*) ((uintptr_t) i5 + input_offset);
    }
    const int8_t* i6 = input[6];
    assert(i6 != NULL);
    if XNN_UNPREDICTABLE(i6 != zero) {
      i6 = (const int8_t*) ((uintptr_t) i6 + input_offset);
    }
    const int8_t* i7 = input[7];
    assert(i7 != NULL);
    if XNN_UNPREDICTABLE(i7 != zero) {
      i7 = (const int8_t*) ((uintptr_t) i7 + input_offset);
    }
    const int8_t* i8 = input[8];
    assert(i8 != NULL);
    if XNN_UNPREDICTABLE(i8 != zero) {
      i8 = (const int8_t*) ((uintptr_t) i8 + input_offset);
    }
    const int8_t* i9 = input[9];
    assert(i9 != NULL);
    if XNN_UNPREDICTABLE(i9 != zero) {
      i9 = (const int8_t*) ((uintptr_t) i9 + input_offset);
    }
    const int8_t* i10 = input[10];
    assert(i10 != NULL);
    if XNN_UNPREDICTABLE(i10 != zero) {
      i10 = (const int8_t*) ((uintptr_t) i10 + input_offset);
    }
    const int8_t* i11 = input[11];
    assert(i11 != NULL);
    if XNN_UNPREDICTABLE(i11 != zero) {
      i11 = (const int8_t*) ((uintptr_t) i11 + input_offset);
    }
    const int8_t* i12 = input[12];
    assert(i12 != NULL);
    if XNN_UNPREDICTABLE(i12 != zero) {
      i12 = (const int8_t*) ((uintptr_t) i12 + input_offset);
    }
    const int8_t* i13 = input[13];
    assert(i13 != NULL);
    if XNN_UNPREDICTABLE(i13 != zero) {
      i13 = (const int8_t*) ((uintptr_t) i13 + input_offset);
    }
    const int8_t* i14 = input[14];
    assert(i14 != NULL);
    if XNN_UNPREDICTABLE(i14 != zero) {
      i14 = (const int8_t*) ((uintptr_t) i14 + input_offset);
    }
    const int8_t* i15 = input[15];
    assert(i15 != NULL);
    if XNN_UNPREDICTABLE(i15 != zero) {
      i15 = (const int8_t*) ((uintptr_t) i15 + input_offset);
    }
    const int8_t* i16 = input[16];
    assert(i16 != NULL);
    if XNN_UNPREDICTABLE(i16 != zero) {
      i16 = (const int8_t*) ((uintptr_t) i16 + input_offset);
    }
    const int8_t* i17 = input[17];
    assert(i17 != NULL);
    if XNN_UNPREDICTABLE(i17 != zero) {
      i17 = (const int8_t*) ((uintptr_t) i17 + input_offset);
    }
    const int8_t* i18 = input[18];
    assert(i18 != NULL);
    if XNN_UNPREDICTABLE(i18 != zero) {
      i18 = (const int8_t*) ((uintptr_t) i18 + input_offset);
    }
    const int8_t* i19 = input[19];
    assert(i19 != NULL);
    if XNN_UNPREDICTABLE(i19 != zero) {
      i19 = (const int8_t*) ((uintptr_t) i19 + input_offset);
    }
    const int8_t* i20 = input[20];
    assert(i20 != NULL);
    if XNN_UNPREDICTABLE(i20 != zero) {
      i20 = (const int8_t*) ((uintptr_t) i20 + input_offset);
    }
    const int8_t* i21 = input[21];
    assert(i21 != NULL);
    if XNN_UNPREDICTABLE(i21 != zero) {
      i21 = (const int8_t*) ((uintptr_t) i21 + input_offset);
    }
    const int8_t* i22 = input[22];
    assert(i22 != NULL);
    if XNN_UNPREDICTABLE(i22 != zero) {
      i22 = (const int8_t*) ((uintptr_t) i22 + input_offset);
    }
    const int8_t* i23 = input[23];
    assert(i23 != NULL);
    if XNN_UNPREDICTABLE(i23 != zero) {
      i23 = (const int8_t*) ((uintptr_t) i23 + input_offset);
    }
    const int8_t* i24 = input[24];
    assert(i24 != NULL);
    if XNN_UNPREDICTABLE(i24 != zero) {
      i24 = (const int8_t*) ((uintptr_t) i24 + input_offset);
    }
    input = (const int8_t**) ((uintptr_t) input + input_stride);

    size_t c = channels;
    const void* w = weights;
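    // Packed weight layout per 32-channel group: 32 int32 bias values, then
    // 25 rounds of 32 int8 kernel taps, then 32 float per-channel scales.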
    for (; c >= 32; c -= 32) {
      __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
      __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));


      const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
      const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(int8_t))));
      const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
      const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(int8_t))));
      i0 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));

      const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
      const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(int8_t))));
      const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
      const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(int8_t))));
      i1 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));

      const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
      const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(int8_t))));
      const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
      const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(int8_t))));
      i2 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));

      const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
      const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(int8_t))));
      const __m512i vi3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i3 + 16)));
      const __m512i vk3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 112 * sizeof(int8_t))));
      i3 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi3xGHIJKLMNOPQRSTUV, vk3xGHIJKLMNOPQRSTUV));

      const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
      const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 128 * sizeof(int8_t))));
      const __m512i vi4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i4 + 16)));
      const __m512i vk4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 144 * sizeof(int8_t))));
      i4 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi4xGHIJKLMNOPQRSTUV, vk4xGHIJKLMNOPQRSTUV));

      const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
      const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 160 * sizeof(int8_t))));
      const __m512i vi5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i5 + 16)));
      const __m512i vk5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 176 * sizeof(int8_t))));
      i5 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi5xGHIJKLMNOPQRSTUV, vk5xGHIJKLMNOPQRSTUV));

      const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
      const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 192 * sizeof(int8_t))));
      const __m512i vi6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i6 + 16)));
      const __m512i vk6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 208 * sizeof(int8_t))));
      i6 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi6xGHIJKLMNOPQRSTUV, vk6xGHIJKLMNOPQRSTUV));

      const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
      const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 224 * sizeof(int8_t))));
      const __m512i vi7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i7 + 16)));
      const __m512i vk7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 240 * sizeof(int8_t))));
      i7 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi7xGHIJKLMNOPQRSTUV, vk7xGHIJKLMNOPQRSTUV));

      const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
      const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 256 * sizeof(int8_t))));
      const __m512i vi8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i8 + 16)));
      const __m512i vk8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 272 * sizeof(int8_t))));
      i8 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi8xGHIJKLMNOPQRSTUV, vk8xGHIJKLMNOPQRSTUV));

      const __m512i vi9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i9));
      const __m512i vk9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(int8_t))));
      const __m512i vi9xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i9 + 16)));
      const __m512i vk9xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 304 * sizeof(int8_t))));
      i9 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi9xGHIJKLMNOPQRSTUV, vk9xGHIJKLMNOPQRSTUV));

      const __m512i vi10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i10));
      const __m512i vk10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 320 * sizeof(int8_t))));
      const __m512i vi10xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i10 + 16)));
      const __m512i vk10xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 336 * sizeof(int8_t))));
      i10 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi10xGHIJKLMNOPQRSTUV, vk10xGHIJKLMNOPQRSTUV));

      const __m512i vi11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i11));
      const __m512i vk11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 352 * sizeof(int8_t))));
      const __m512i vi11xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i11 + 16)));
      const __m512i vk11xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 368 * sizeof(int8_t))));
      i11 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi11xGHIJKLMNOPQRSTUV, vk11xGHIJKLMNOPQRSTUV));

      const __m512i vi12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i12));
      const __m512i vk12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 384 * sizeof(int8_t))));
      const __m512i vi12xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i12 + 16)));
      const __m512i vk12xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 400 * sizeof(int8_t))));
      i12 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi12xGHIJKLMNOPQRSTUV, vk12xGHIJKLMNOPQRSTUV));

      const __m512i vi13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i13));
      const __m512i vk13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 416 * sizeof(int8_t))));
      const __m512i vi13xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i13 + 16)));
      const __m512i vk13xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 432 * sizeof(int8_t))));
      i13 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi13xGHIJKLMNOPQRSTUV, vk13xGHIJKLMNOPQRSTUV));

      const __m512i vi14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i14));
      const __m512i vk14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 448 * sizeof(int8_t))));
      const __m512i vi14xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i14 + 16)));
      const __m512i vk14xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 464 * sizeof(int8_t))));
      i14 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi14xGHIJKLMNOPQRSTUV, vk14xGHIJKLMNOPQRSTUV));

      const __m512i vi15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i15));
      const __m512i vk15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 480 * sizeof(int8_t))));
      const __m512i vi15xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i15 + 16)));
      const __m512i vk15xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 496 * sizeof(int8_t))));
      i15 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi15xGHIJKLMNOPQRSTUV, vk15xGHIJKLMNOPQRSTUV));

      const __m512i vi16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i16));
      const __m512i vk16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 512 * sizeof(int8_t))));
      const __m512i vi16xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i16 + 16)));
      const __m512i vk16xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 528 * sizeof(int8_t))));
      i16 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi16xGHIJKLMNOPQRSTUV, vk16xGHIJKLMNOPQRSTUV));

      const __m512i vi17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i17));
      const __m512i vk17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 544 * sizeof(int8_t))));
      const __m512i vi17xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i17 + 16)));
      const __m512i vk17xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 560 * sizeof(int8_t))));
      i17 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi17xGHIJKLMNOPQRSTUV, vk17xGHIJKLMNOPQRSTUV));

      const __m512i vi18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i18));
      const __m512i vk18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 576 * sizeof(int8_t))));
      const __m512i vi18xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i18 + 16)));
      const __m512i vk18xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 592 * sizeof(int8_t))));
      i18 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi18xGHIJKLMNOPQRSTUV, vk18xGHIJKLMNOPQRSTUV));

      const __m512i vi19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i19));
      const __m512i vk19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 608 * sizeof(int8_t))));
      const __m512i vi19xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i19 + 16)));
      const __m512i vk19xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 624 * sizeof(int8_t))));
      i19 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi19xGHIJKLMNOPQRSTUV, vk19xGHIJKLMNOPQRSTUV));

      const __m512i vi20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i20));
      const __m512i vk20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 640 * sizeof(int8_t))));
      const __m512i vi20xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i20 + 16)));
      const __m512i vk20xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 656 * sizeof(int8_t))));
      i20 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi20xGHIJKLMNOPQRSTUV, vk20xGHIJKLMNOPQRSTUV));

      const __m512i vi21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i21));
      const __m512i vk21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 672 * sizeof(int8_t))));
      const __m512i vi21xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i21 + 16)));
      const __m512i vk21xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 688 * sizeof(int8_t))));
      i21 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi21xGHIJKLMNOPQRSTUV, vk21xGHIJKLMNOPQRSTUV));

      const __m512i vi22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i22));
      const __m512i vk22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 704 * sizeof(int8_t))));
      const __m512i vi22xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i22 + 16)));
      const __m512i vk22xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 720 * sizeof(int8_t))));
      i22 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi22xGHIJKLMNOPQRSTUV, vk22xGHIJKLMNOPQRSTUV));

      const __m512i vi23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i23));
      const __m512i vk23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 736 * sizeof(int8_t))));
      const __m512i vi23xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i23 + 16)));
      const __m512i vk23xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 752 * sizeof(int8_t))));
      i23 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi23xGHIJKLMNOPQRSTUV, vk23xGHIJKLMNOPQRSTUV));

      const __m512i vi24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i24));
      const __m512i vk24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 768 * sizeof(int8_t))));
      const __m512i vi24xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i24 + 16)));
      const __m512i vk24xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 784 * sizeof(int8_t))));
      i24 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi24xGHIJKLMNOPQRSTUV, vk24xGHIJKLMNOPQRSTUV));

      w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 800 * sizeof(int8_t));

      __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
      __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);

      const __m512 vscale0123456789ABCDEF = _mm512_loadu_ps(w);
      const __m512 vscaleGHIJKLMNOPQRSTUV = _mm512_loadu_ps((const void*) ((uintptr_t) w + 16 * sizeof(float)));
      w = (const void*) ((uintptr_t) w + 32 * sizeof(float));
      vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale0123456789ABCDEF);
      vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscaleGHIJKLMNOPQRSTUV);

      vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
      vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);

      vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
      vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);

      __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);
      __m256i voutGHIJOPQRKLMNSTUV = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vaccGHIJKLMNOPQRSTUV), _mm512_extracti32x8_epi32(vaccGHIJKLMNOPQRSTUV, 1)), _mm512_castsi512_si256(voutput_zero_point));

      const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
      const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
      const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packs_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
      __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);
      const __m128i voutGHIJOPQR = _mm256_castsi256_si128(voutGHIJOPQRKLMNSTUV);
      const __m128i voutKLMNSTUV = _mm256_extracti128_si256(voutGHIJOPQRKLMNSTUV, 1);
      __m128i voutGHIJKLMNOPQRSTUV = _mm_shuffle_epi32(_mm_packs_epi16(voutGHIJOPQR, voutKLMNSTUV), _MM_SHUFFLE(3, 1, 2, 0));

      vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epi8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);
      voutGHIJKLMNOPQRSTUV = _mm_max_epi8(voutGHIJKLMNOPQRSTUV, _mm256_castsi256_si128(voutput_min));

      _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
      _mm_storeu_si128((__m128i*) (output + 16), voutGHIJKLMNOPQRSTUV);
      output += 32;
    }
    if XNN_UNLIKELY(c != 0) {
      // Prepare mask for valid 8-bit elements (depends on c).
      const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
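      // For example, c == 20 takes one full 16-channel pass below, leaving
      // c == 4; vmask == (1 << (20 & 15)) - 1 == 0b1111 masks the final store.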
      const int8_t* k = (const int8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
      do {
        __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);


        const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
        const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) k));
        i0 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));

        const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
        const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 32)));
        i1 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));

        const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
        const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 64)));
        i2 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));

        const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
        const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 96)));
        i3 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));

        const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
        const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 128)));
        i4 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));

        const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
        const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 160)));
        i5 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));

        const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
        const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 192)));
        i6 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));

        const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
        const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 224)));
        i7 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));

        const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
        const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 256)));
        i8 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));

        const __m512i vi9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i9));
        const __m512i vk9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 288)));
        i9 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF));

        const __m512i vi10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i10));
        const __m512i vk10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 320)));
        i10 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF));

        const __m512i vi11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i11));
        const __m512i vk11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 352)));
        i11 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF));

        const __m512i vi12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i12));
        const __m512i vk12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 384)));
        i12 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF));

        const __m512i vi13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i13));
        const __m512i vk13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 416)));
        i13 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF));

        const __m512i vi14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i14));
        const __m512i vk14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 448)));
        i14 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF));

        const __m512i vi15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i15));
        const __m512i vk15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 480)));
        i15 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF));

        const __m512i vi16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i16));
        const __m512i vk16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 512)));
        i16 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF));

        const __m512i vi17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i17));
        const __m512i vk17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 544)));
        i17 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF));

        const __m512i vi18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i18));
        const __m512i vk18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 576)));
        i18 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF));

        const __m512i vi19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i19));
        const __m512i vk19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 608)));
        i19 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF));

        const __m512i vi20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i20));
        const __m512i vk20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 640)));
        i20 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF));

        const __m512i vi21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i21));
        const __m512i vk21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 672)));
        i21 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF));

        const __m512i vi22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i22));
        const __m512i vk22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 704)));
        i22 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF));

        const __m512i vi23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i23));
        const __m512i vk23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 736)));
        i23 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF));

        const __m512i vi24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i24));
        const __m512i vk24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 768)));
        i24 += 16;

        vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF));

        k += 16;

        __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
        const __m512 vscale0123456789ABCDEF = _mm512_loadu_ps((const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 800 * sizeof(int8_t)));
        vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale0123456789ABCDEF);
        vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
        vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);

        w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));

        __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));

        const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
        const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
        __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
        vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));

        if XNN_LIKELY(c >= 16) {
          _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
          output += 16;
          c -= 16;
        } else {
          _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
          output = (int8_t*) ((uintptr_t) output + c);
          c = 0;
        }
      } while (c != 0);
    }

    output = (int8_t*) ((uintptr_t) output + output_increment);
  } while (--output_width != 0);
}

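// 9-tap (e.g. 3x3) variant of the up32x25 depthwise kernel above, with the
// same structure over input rows i0-i8.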
void xnn_qc8_dwconv_minmax_fp32_ukernel_up32x9__avx512skx_mul32(
    size_t channels,
    size_t output_width,
    const int8_t** input,
    const void* weights,
    int8_t* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const int8_t* zero,
    const union xnn_qs8_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
{
  assert(channels != 0);
  assert(output_width != 0);

  const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->avx512.output_max_less_zero_point);
  const __m512i voutput_zero_point = _mm512_load_si512(params->avx512.output_zero_point);
  const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->avx512.output_min);
  const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);

  do {
    const int8_t* i0 = input[0];
    assert(i0 != NULL);
    if XNN_UNPREDICTABLE(i0 != zero) {
      i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
    }
    const int8_t* i1 = input[1];
    assert(i1 != NULL);
    if XNN_UNPREDICTABLE(i1 != zero) {
      i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
    }
    const int8_t* i2 = input[2];
    assert(i2 != NULL);
    if XNN_UNPREDICTABLE(i2 != zero) {
      i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
    }
    const int8_t* i3 = input[3];
    assert(i3 != NULL);
    if XNN_UNPREDICTABLE(i3 != zero) {
      i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
    }
    const int8_t* i4 = input[4];
    assert(i4 != NULL);
    if XNN_UNPREDICTABLE(i4 != zero) {
      i4 = (const int8_t*) ((uintptr_t) i4 + input_offset);
    }
    const int8_t* i5 = input[5];
    assert(i5 != NULL);
    if XNN_UNPREDICTABLE(i5 != zero) {
      i5 = (const int8_t*) ((uintptr_t) i5 + input_offset);
    }
    const int8_t* i6 = input[6];
    assert(i6 != NULL);
    if XNN_UNPREDICTABLE(i6 != zero) {
      i6 = (const int8_t*) ((uintptr_t) i6 + input_offset);
    }
    const int8_t* i7 = input[7];
    assert(i7 != NULL);
    if XNN_UNPREDICTABLE(i7 != zero) {
      i7 = (const int8_t*) ((uintptr_t) i7 + input_offset);
    }
    const int8_t* i8 = input[8];
    assert(i8 != NULL);
    if XNN_UNPREDICTABLE(i8 != zero) {
      i8 = (const int8_t*) ((uintptr_t) i8 + input_offset);
    }
    input = (const int8_t**) ((uintptr_t) input + input_stride);

    size_t c = channels;
    const void* w = weights;
    for (; c >= 32; c -= 32) {
      __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
      __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));


      const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
      const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(int8_t))));
      const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
      const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(int8_t))));
      i0 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));

      const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
      const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(int8_t))));
      const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
      const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(int8_t))));
      i1 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));

      const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
      const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(int8_t))));
      const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
      const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(int8_t))));
      i2 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));

      const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
      const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(int8_t))));
      const __m512i vi3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i3 + 16)));
      const __m512i vk3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 112 * sizeof(int8_t))));
      i3 += 32;

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
      vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi3xGHIJKLMNOPQRSTUV, vk3xGHIJKLMNOPQRSTUV));

      const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
      const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 128 * sizeof(int8_t))));
      const __m512i vi4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i4 + 16)));
      const __m512i vk4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 144 * sizeof(int8_t))));
1044 i4 += 32;
1045
1046 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
1047 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi4xGHIJKLMNOPQRSTUV, vk4xGHIJKLMNOPQRSTUV));
1048
1049 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
1050 const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 160 * sizeof(int8_t))));
1051 const __m512i vi5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i5 + 16)));
1052 const __m512i vk5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 176 * sizeof(int8_t))));
1053 i5 += 32;
1054
1055 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
1056 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi5xGHIJKLMNOPQRSTUV, vk5xGHIJKLMNOPQRSTUV));
1057
1058 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
1059 const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 192 * sizeof(int8_t))));
1060 const __m512i vi6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i6 + 16)));
1061 const __m512i vk6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 208 * sizeof(int8_t))));
1062 i6 += 32;
1063
1064 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
1065 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi6xGHIJKLMNOPQRSTUV, vk6xGHIJKLMNOPQRSTUV));
1066
1067 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
1068 const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 224 * sizeof(int8_t))));
1069 const __m512i vi7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i7 + 16)));
1070 const __m512i vk7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 240 * sizeof(int8_t))));
1071 i7 += 32;
1072
1073 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
1074 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi7xGHIJKLMNOPQRSTUV, vk7xGHIJKLMNOPQRSTUV));
1075
1076 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
1077 const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 256 * sizeof(int8_t))));
1078 const __m512i vi8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i8 + 16)));
1079 const __m512i vk8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 272 * sizeof(int8_t))));
1080 i8 += 32;
1081
1082 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
1083 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi8xGHIJKLMNOPQRSTUV, vk8xGHIJKLMNOPQRSTUV));
1084
1085 w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(int8_t));
1086
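      // Requantization epilogue (fp32 variant). A scalar sketch of what the
      // vector code below computes per channel, assuming the default
      // round-to-nearest-even mode used by _mm512_cvtps_epi32:
      //   float v = (float) acc * scale[c];
      //   v = min(v, (float) (qmax - zero_point));  // clamp above, in float
      //   int32_t q = lrintf(v) + zero_point;       // saturating adds below
      //   q = max(q, qmin);                         // clamp below, on int8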
1087 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
1088 __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);
1089
1090 const __m512 vscale0123456789ABCDEF = _mm512_loadu_ps(w);
1091 const __m512 vscaleGHIJKLMNOPQRSTUV = _mm512_loadu_ps((const void*) ((uintptr_t) w + 16 * sizeof(float)));
1092 w = (const void*) ((uintptr_t) w + 32 * sizeof(float));
1093 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale0123456789ABCDEF);
1094 vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscaleGHIJKLMNOPQRSTUV);
1095
1096 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
1097 vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);
1098
1099 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
1100 vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);
1101
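      // Narrow 2x16 int32 accumulators to 32 int8s: the packs saturate within
      // 128-bit lanes, so the channel groups come out interleaved and are put
      // back in linear order by vpermute_mask.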
1102 __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);
1104
1105 const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
1106 const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
1107 const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packs_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
1108 __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);
1112
1113 vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epi8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);
1115
1116 _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
1118 output += 32;
1119 }
1120 if XNN_UNLIKELY(c != 0) {
1121 // Prepare mask for valid 8-bit elements (depends on c).
1122 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
1123 const int8_t* k = (const int8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
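      // The remainder loop still steps 16 channels at a time; vmask only
      // matters on the final sub-16 group, where the masked store writes just
      // the low (c & 15) bytes. k walks the int8 kernel weights, which start
      // right after the 32 int32 bias slots.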
1124 do {
1125 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
1126
1127
1128 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
1129 const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) k));
1130 i0 += 16;
1131
1132 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
1133
1134 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
1135 const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 32)));
1136 i1 += 16;
1137
1138 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
1139
1140 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
1141 const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 64)));
1142 i2 += 16;
1143
1144 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
1145
1146 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
1147 const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 96)));
1148 i3 += 16;
1149
1150 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
1151
1152 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
1153 const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 128)));
1154 i4 += 16;
1155
1156 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
1157
1158 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
1159 const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 160)));
1160 i5 += 16;
1161
1162 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
1163
1164 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
1165 const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 192)));
1166 i6 += 16;
1167
1168 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
1169
1170 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
1171 const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 224)));
1172 i7 += 16;
1173
1174 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
1175
1176 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
1177 const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 256)));
1178 i8 += 16;
1179
1180 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
1181
1182 k += 16;
1183
1184 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
1185 const __m512 vscale0123456789ABCDEF = _mm512_loadu_ps((const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(int8_t)));
1186 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale0123456789ABCDEF);
1187 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
1188 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
1189
1190 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
1191
1192 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));
1193
1194 const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
1195 const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
1196 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
1197 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));
1198
1199 if XNN_LIKELY(c >= 16) {
1200 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
1201 output += 16;
1202 c -= 16;
1203 } else {
1204 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
1205 output = (int8_t*) ((uintptr_t) output + c);
1206 c = 0;
1207 }
1208 } while (c != 0);
1209 }
1210
1211 output = (int8_t*) ((uintptr_t) output + output_increment);
1212 } while (--output_width != 0);
1213 }
1214
1215 void xnn_qc8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx(
1216 size_t mr,
1217 size_t nc,
1218 size_t kc,
1219 const int8_t* restrict a,
1220 size_t a_stride,
1221 const void* restrict w,
1222 int8_t* restrict c,
1223 size_t cm_stride,
1224 size_t cn_stride,
1225 const union xnn_qs8_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1226 {
1227 assert(mr != 0);
1228 assert(mr <= 1);
1229 assert(nc != 0);
1230 assert(kc != 0);
1231 assert(kc % sizeof(int8_t) == 0);
1232 assert(a != NULL);
1233 assert(w != NULL);
1234 assert(c != NULL);
1235
1236 kc = round_up_po2(kc, 8);
1237 const int8_t* a0 = a;
1238 int8_t* c0 = c;
1239
1240 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
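  // Mask 0x1111 sets dwords 0, 4, 8 and 12: each expand-load below scatters 4
  // consecutive int32 biases so that every 128-bit lane (one output channel's
  // group of partial sums) starts from its own bias.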
1241 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->avx512.output_max_less_zero_point);
1242 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx512.output_zero_point);
1243 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx512.output_min);
1244 do {
1245 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
1246 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
1247 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
1248 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
1249 w = (const void*) ((const int32_t*) w + 16);
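    // c8 layout: each iteration consumes 8 int8 inputs, sign-extended to
    // int16 and broadcast to all four 128-bit lanes, against one zmm of
    // weights holding 8 k-values for each of 4 output channels.
    // _mm512_madd_epi16 leaves four int32 partial sums per channel, reduced
    // after the k loop.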
1250
1251 size_t k = 0;
1252 while (k < kc) {
1253 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
1254 a0 += 8;
1255
1256 const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
1257
1258 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
1259 const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
1260
1261 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
1262 const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
1263
1264 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
1265 const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
1266
1267 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
1268
1269 w = (const void*) ((const int8_t*) w + 128);
1270 k += 8 * sizeof(int8_t);
1271 }
1272
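    // 4:1 horizontal reduction: the unpacklo/unpackhi + add pairs sum each
    // channel's four partial dwords to a single int32, leaving the channels
    // in the interleaved order 0 8 4 C 1 9 5 D 2 A 6 E 3 B 7 F encoded in
    // the variable name.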
1273 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
1274 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
1275
1276 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
1277
1278 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
1279
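    // Per-channel scales are stored in linear order in the packed weights;
    // permute them into the accumulator's 084C...3B7F order before the
    // multiply.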
1280     const __m512 vscale0123456789ABCDEF = _mm512_load_ps(w);
1281     w = (const void*) ((const float*) w + 16);
1282     const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale0123456789ABCDEF);
1283 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
1284
1285 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
1286
1287 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
1288
1289 const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);
1290
1291 const __m128i vout0x084C2A6E195D3B7F = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
1292 __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
1293 vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min);
1294
1295 if (nc >= 16) {
1296 _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
1297
1298 a0 = (const int8_t*) ((uintptr_t) a0 - k);
1299
1300 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
1301
1302 nc -= 16;
1303 } else {
1304 // Prepare mask for valid 8-bit elements (depends on nc).
1305 const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
1306
1307 _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);
1308
1309 nc = 0;
1310 }
1311 } while (nc != 0);
1312 }
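// A minimal, hypothetical invocation sketch (not part of this file), for a
// single-row by 16-column tile. packed_w is assumed to hold the bias /
// weight / per-channel-scale layout consumed above, and params to have been
// filled in by the matching XNNPACK parameter-initialization routine:
//
//   xnn_qc8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx(
//       /*mr=*/1, /*nc=*/16, /*kc=*/kc,
//       a, /*a_stride=*/kc,
//       packed_w,
//       c, /*cm_stride=*/16 * sizeof(int8_t), /*cn_stride=*/16 * sizeof(int8_t),
//       &params);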
1313
1314 void xnn_qc8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx(
1315 size_t mr,
1316 size_t nc,
1317 size_t kc,
1318 const int8_t* restrict a,
1319 size_t a_stride,
1320 const void* restrict w,
1321 int8_t* restrict c,
1322 size_t cm_stride,
1323 size_t cn_stride,
1324 const union xnn_qs8_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1325 {
1326 assert(mr != 0);
1327 assert(mr <= 4);
1328 assert(nc != 0);
1329 assert(kc != 0);
1330 assert(kc % sizeof(int8_t) == 0);
1331 assert(a != NULL);
1332 assert(w != NULL);
1333 assert(c != NULL);
1334
1335 kc = round_up_po2(kc, 8);
1336 const int8_t* a0 = a;
1337 int8_t* c0 = c;
1338 const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
1339 int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
1340 if XNN_UNPREDICTABLE(mr < 2) {
1341 a1 = a0;
1342 c1 = c0;
1343 }
1344 const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
1345 int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
1346 if XNN_UNPREDICTABLE(mr <= 2) {
1347 a2 = a1;
1348 c2 = c1;
1349 }
1350 const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride);
1351 int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
1352 if XNN_UNPREDICTABLE(mr != 4) {
1353 a3 = a2;
1354 c3 = c2;
1355 }
1356
1357 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
1358 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->avx512.output_max_less_zero_point);
1359 const __m512i voutput_zero_point = _mm512_load_si512(params->avx512.output_zero_point);
1360 const __m512i voutput_min = _mm512_load_si512(params->avx512.output_min);
1361 do {
1362 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
1363 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
1364 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
1365 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
1366 __m512i vacc1x0123 = vacc0x0123;
1367 __m512i vacc1x4567 = vacc0x4567;
1368 __m512i vacc1x89AB = vacc0x89AB;
1369 __m512i vacc1xCDEF = vacc0xCDEF;
1370 __m512i vacc2x0123 = vacc0x0123;
1371 __m512i vacc2x4567 = vacc0x4567;
1372 __m512i vacc2x89AB = vacc0x89AB;
1373 __m512i vacc2xCDEF = vacc0xCDEF;
1374 __m512i vacc3x0123 = vacc0x0123;
1375 __m512i vacc3x4567 = vacc0x4567;
1376 __m512i vacc3x89AB = vacc0x89AB;
1377 __m512i vacc3xCDEF = vacc0xCDEF;
1378 w = (const void*) ((const int32_t*) w + 16);
1379
1380 size_t k = 0;
1381 while (k < kc) {
1382 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
1383 a0 += 8;
1384 const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
1385 a1 += 8;
1386 const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
1387 a2 += 8;
1388 const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
1389 a3 += 8;
1390
1391 const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
1392
1393 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
1394 vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
1395 vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
1396 vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
1397 const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
1398
1399 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
1400 vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
1401 vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
1402 vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
1403 const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
1404
1405 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
1406 vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
1407 vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
1408 vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
1409 const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
1410
1411 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
1412 vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
1413 vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
1414 vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
1415
1416 w = (const void*) ((const int8_t*) w + 128);
1417 k += 8 * sizeof(int8_t);
1418 }
1419
1420 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
1421 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
1422 const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
1423 const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
1424 const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
1425 const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
1426 const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
1427 const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));
1428
1429 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
1430 __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
1431 __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
1432 __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));
1433
1434 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
1435 __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
1436 __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
1437 __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
1438
1439     const __m512 vscale0123456789ABCDEF = _mm512_load_ps(w);
1440     w = (const void*) ((const float*) w + 16);
1441     const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale0123456789ABCDEF);
1442 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
1443 vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
1444 vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
1445 vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
1446
1447 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
1448 vscaled1x084C195D2A6E3B7F = _mm512_min_ps(vscaled1x084C195D2A6E3B7F, voutput_max_less_zero_point);
1449 vscaled2x084C195D2A6E3B7F = _mm512_min_ps(vscaled2x084C195D2A6E3B7F, voutput_max_less_zero_point);
1450 vscaled3x084C195D2A6E3B7F = _mm512_min_ps(vscaled3x084C195D2A6E3B7F, voutput_max_less_zero_point);
1451
1452 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
1453 vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
1454 vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
1455 vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);
1456
1457 const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
1458 const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);
1459
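    // Unscramble 4 rows at once: the packs interleave the rows within 128-bit
    // lanes, the dword permute gathers each row's 16 bytes into its own lane,
    // and the in-lane byte shuffle restores column order 0..F per row.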
1460 __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packs_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
1461 vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
1462 __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
1463 vout0123x0123456789ABCDEF = _mm512_max_epi8(vout0123x0123456789ABCDEF, voutput_min);
1464
1465 if (nc >= 16) {
1466 _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));
1467 _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
1468 _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
1469 _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
1470
1471 a0 = (const int8_t*) ((uintptr_t) a0 - k);
1472 a1 = (const int8_t*) ((uintptr_t) a1 - k);
1473 a2 = (const int8_t*) ((uintptr_t) a2 - k);
1474 a3 = (const int8_t*) ((uintptr_t) a3 - k);
1475
1476 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
1477 c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
1478 c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
1479 c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
1480
1481 nc -= 16;
1482 } else {
1483 // Prepare mask for valid 8-bit elements (depends on nc).
1484 __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
1485
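      // One 64-bit byte mask serves all four rows: it starts over row 0's
      // 16-byte slot and is shifted left by 16 for each next row, while the
      // store address is biased by -16 bytes per row to compensate.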
1486 _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
1487 vmask = _kshiftli_mask64(vmask, 16);
1488 _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
1489 vmask = _kshiftli_mask64(vmask, 16);
1490 _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
1491 vmask = _kshiftli_mask64(vmask, 16);
1492 _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
1493
1494 nc = 0;
1495 }
1496 } while (nc != 0);
1497 }
1498
1499 void xnn_qc8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx(
1500 size_t mr,
1501 size_t nc,
1502 size_t kc,
1503 size_t ks,
1504 const int8_t** restrict a,
1505 const void* restrict w,
1506 int8_t* restrict c,
1507 size_t cm_stride,
1508 size_t cn_stride,
1509 size_t a_offset,
1510 const int8_t* zero,
1511 const union xnn_qs8_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1512 {
1513 assert(mr != 0);
1514 assert(mr <= 1);
1515 assert(nc != 0);
1516 assert(kc != 0);
1517 assert(kc % sizeof(int8_t) == 0);
1518 assert(a != NULL);
1519 assert(w != NULL);
1520 assert(c != NULL);
1521
1522 kc = round_up_po2(kc, 8);
1523 int8_t* c0 = c;
1524
1525 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
1526 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->avx512.output_max_less_zero_point);
1527 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx512.output_zero_point);
1528 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx512.output_min);
1529 do {
1530 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
1531 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
1532 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
1533 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
1534 w = (const void*) ((const int32_t*) w + 16);
1535
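    // a is an indirection buffer holding ks input-row pointers per output
    // pixel; entries equal to `zero` select the padding buffer and skip the
    // a_offset rebase. p counts the buffer down in units of sizeof(void*).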
1536 size_t p = ks;
1537 do {
1538 const int8_t* restrict a0 = a[0];
1539 if XNN_UNPREDICTABLE(a0 != zero) {
1540 a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
1541 }
1542 a += 1;
1543
1544 size_t k = 0;
1545 while (k < kc) {
1546 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
1547 a0 += 8;
1548
1549 const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
1550
1551 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
1552 const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
1553
1554 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
1555 const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
1556
1557 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
1558 const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
1559
1560 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
1561
1562 w = (const void*) ((const int8_t*) w + 128);
1563 k += 8 * sizeof(int8_t);
1564 }
1565 p -= 1 * sizeof(void*);
1566 } while (p != 0);
1567
1568 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
1569 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
1570
1571 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
1572
1573 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
1574
1575     const __m512 vscale0123456789ABCDEF = _mm512_load_ps(w);
1576     w = (const void*) ((const float*) w + 16);
1577     const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale0123456789ABCDEF);
1578 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
1579
1580 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
1581
1582 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
1583
1584 const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);
1585
1586 const __m128i vout0x084C2A6E195D3B7F = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
1587 __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
1588 vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min);
1589
1590 if (nc >= 16) {
1591 _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
1592
1593 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
1594
1595 a = (const int8_t**restrict) ((uintptr_t) a - ks);
1596
1597 nc -= 16;
1598 } else {
1599 // Prepare mask for valid 8-bit elements (depends on nc).
1600 const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
1601
1602 _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);
1603
1604 nc = 0;
1605 }
1606 } while (nc != 0);
1607 }
1608
1609 void xnn_qc8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx(
1610 size_t mr,
1611 size_t nc,
1612 size_t kc,
1613 size_t ks,
1614 const int8_t** restrict a,
1615 const void* restrict w,
1616 int8_t* restrict c,
1617 size_t cm_stride,
1618 size_t cn_stride,
1619 size_t a_offset,
1620 const int8_t* zero,
1621 const union xnn_qs8_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1622 {
1623 assert(mr != 0);
1624 assert(mr <= 4);
1625 assert(nc != 0);
1626 assert(kc != 0);
1627 assert(kc % sizeof(int8_t) == 0);
1628 assert(a != NULL);
1629 assert(w != NULL);
1630 assert(c != NULL);
1631
1632 kc = round_up_po2(kc, 8);
1633 int8_t* c0 = c;
1634 int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
1635 if XNN_UNPREDICTABLE(mr < 2) {
1636 c1 = c0;
1637 }
1638 int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
1639 if XNN_UNPREDICTABLE(mr <= 2) {
1640 c2 = c1;
1641 }
1642 int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
1643 if XNN_UNPREDICTABLE(mr != 4) {
1644 c3 = c2;
1645 }
1646
1647 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
1648 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->avx512.output_max_less_zero_point);
1649 const __m512i voutput_zero_point = _mm512_load_si512(params->avx512.output_zero_point);
1650 const __m512i voutput_min = _mm512_load_si512(params->avx512.output_min);
1651 do {
1652 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
1653 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
1654 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
1655 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
1656 __m512i vacc1x0123 = vacc0x0123;
1657 __m512i vacc1x4567 = vacc0x4567;
1658 __m512i vacc1x89AB = vacc0x89AB;
1659 __m512i vacc1xCDEF = vacc0xCDEF;
1660 __m512i vacc2x0123 = vacc0x0123;
1661 __m512i vacc2x4567 = vacc0x4567;
1662 __m512i vacc2x89AB = vacc0x89AB;
1663 __m512i vacc2xCDEF = vacc0xCDEF;
1664 __m512i vacc3x0123 = vacc0x0123;
1665 __m512i vacc3x4567 = vacc0x4567;
1666 __m512i vacc3x89AB = vacc0x89AB;
1667 __m512i vacc3xCDEF = vacc0xCDEF;
1668 w = (const void*) ((const int32_t*) w + 16);
1669
1670 size_t p = ks;
1671 do {
1672 const int8_t* restrict a0 = a[0];
1673 if XNN_UNPREDICTABLE(a0 != zero) {
1674 a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
1675 }
1676 const int8_t* restrict a1 = a[1];
1677 if XNN_UNPREDICTABLE(a1 != zero) {
1678 a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
1679 }
1680 const int8_t* restrict a2 = a[2];
1681 if XNN_UNPREDICTABLE(a2 != zero) {
1682 a2 = (const int8_t*) ((uintptr_t) a2 + a_offset);
1683 }
1684 const int8_t* restrict a3 = a[3];
1685 if XNN_UNPREDICTABLE(a3 != zero) {
1686 a3 = (const int8_t*) ((uintptr_t) a3 + a_offset);
1687 }
1688 a += 4;
1689
1690 size_t k = 0;
1691 while (k < kc) {
1692 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
1693 a0 += 8;
1694 const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
1695 a1 += 8;
1696 const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
1697 a2 += 8;
1698 const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
1699 a3 += 8;
1700
1701 const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
1702
1703 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
1704 vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
1705 vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
1706 vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
1707 const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
1708
1709 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
1710 vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
1711 vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
1712 vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
1713 const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
1714
1715 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
1716 vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
1717 vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
1718 vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
1719 const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
1720
1721 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
1722 vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
1723 vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
1724 vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
1725
1726 w = (const void*) ((const int8_t*) w + 128);
1727 k += 8 * sizeof(int8_t);
1728 }
1729 p -= 4 * sizeof(void*);
1730 } while (p != 0);
1731
1732 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
1733 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
1734 const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
1735 const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
1736 const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
1737 const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
1738 const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
1739 const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));
1740
1741 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
1742 __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
1743 __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
1744 __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));
1745
1746 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
1747 __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
1748 __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
1749 __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
1750
1751     const __m512 vscale0123456789ABCDEF = _mm512_load_ps(w);
1752     w = (const void*) ((const float*) w + 16);
1753     const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale0123456789ABCDEF);
1754 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
1755 vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
1756 vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
1757 vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
1758
1759 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
1760 vscaled1x084C195D2A6E3B7F = _mm512_min_ps(vscaled1x084C195D2A6E3B7F, voutput_max_less_zero_point);
1761 vscaled2x084C195D2A6E3B7F = _mm512_min_ps(vscaled2x084C195D2A6E3B7F, voutput_max_less_zero_point);
1762 vscaled3x084C195D2A6E3B7F = _mm512_min_ps(vscaled3x084C195D2A6E3B7F, voutput_max_less_zero_point);
1763
1764 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
1765 vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
1766 vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
1767 vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);
1768
1769 const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
1770 const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);
1771
1772 __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packs_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
1773 vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
1774 __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
1775 vout0123x0123456789ABCDEF = _mm512_max_epi8(vout0123x0123456789ABCDEF, voutput_min);
1776
1777 if (nc >= 16) {
1778 _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
1779 _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
1780 _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
1781 _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));
1782
1783 c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
1784 c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
1785 c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
1786 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
1787
1788 a = (const int8_t**restrict) ((uintptr_t) a - ks);
1789
1790 nc -= 16;
1791 } else {
1792 // Prepare mask for valid 8-bit elements (depends on nc).
1793 __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT64_C(1) << (nc + 48)) - (UINT64_C(1) << 48)));
1794
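      // The tail mask is built already positioned over row 3's bytes
      // (bits 48..48+nc-1) and shifted right by 16 as the stores walk from c3
      // down to c0; the -48/-32/-16 pointer bias compensates for the mask's
      // position within the zmm.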
1795 _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
1796 vmask = _kshiftri_mask64(vmask, 16);
1797 _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
1798 vmask = _kshiftri_mask64(vmask, 16);
1799 _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
1800 vmask = _kshiftri_mask64(vmask, 16);
1801 _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
1802
1803 nc = 0;
1804 }
1805 } while (nc != 0);
1806 }
1807
1808 void xnn_qs8_dwconv_minmax_fp32_ukernel_up32x25__avx512skx_mul32(
1809 size_t channels,
1810 size_t output_width,
1811 const int8_t** input,
1812 const void* weights,
1813 int8_t* output,
1814 size_t input_stride,
1815 size_t output_increment,
1816 size_t input_offset,
1817 const int8_t* zero,
1818 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
1819 {
1820 assert(channels != 0);
1821 assert(output_width != 0);
1822
1823 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
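  // Unlike the qc8 kernels above, qs8 quantization is per-tensor: a single
  // scale vector comes from the params struct instead of per-channel scales
  // embedded in the packed weights.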
1824 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
1825 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
1826 const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
1827 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);
1828
1829 do {
1830 const int8_t* i0 = input[0];
1831 assert(i0 != NULL);
1832 if XNN_UNPREDICTABLE(i0 != zero) {
1833 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
1834 }
1835 const int8_t* i1 = input[1];
1836 assert(i1 != NULL);
1837 if XNN_UNPREDICTABLE(i1 != zero) {
1838 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
1839 }
1840 const int8_t* i2 = input[2];
1841 assert(i2 != NULL);
1842 if XNN_UNPREDICTABLE(i2 != zero) {
1843 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
1844 }
1845 const int8_t* i3 = input[3];
1846 assert(i3 != NULL);
1847 if XNN_UNPREDICTABLE(i3 != zero) {
1848 i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
1849 }
1850 const int8_t* i4 = input[4];
1851 assert(i4 != NULL);
1852 if XNN_UNPREDICTABLE(i4 != zero) {
1853 i4 = (const int8_t*) ((uintptr_t) i4 + input_offset);
1854 }
1855 const int8_t* i5 = input[5];
1856 assert(i5 != NULL);
1857 if XNN_UNPREDICTABLE(i5 != zero) {
1858 i5 = (const int8_t*) ((uintptr_t) i5 + input_offset);
1859 }
1860 const int8_t* i6 = input[6];
1861 assert(i6 != NULL);
1862 if XNN_UNPREDICTABLE(i6 != zero) {
1863 i6 = (const int8_t*) ((uintptr_t) i6 + input_offset);
1864 }
1865 const int8_t* i7 = input[7];
1866 assert(i7 != NULL);
1867 if XNN_UNPREDICTABLE(i7 != zero) {
1868 i7 = (const int8_t*) ((uintptr_t) i7 + input_offset);
1869 }
1870 const int8_t* i8 = input[8];
1871 assert(i8 != NULL);
1872 if XNN_UNPREDICTABLE(i8 != zero) {
1873 i8 = (const int8_t*) ((uintptr_t) i8 + input_offset);
1874 }
1875 const int8_t* i9 = input[9];
1876 assert(i9 != NULL);
1877 if XNN_UNPREDICTABLE(i9 != zero) {
1878 i9 = (const int8_t*) ((uintptr_t) i9 + input_offset);
1879 }
1880 const int8_t* i10 = input[10];
1881 assert(i10 != NULL);
1882 if XNN_UNPREDICTABLE(i10 != zero) {
1883 i10 = (const int8_t*) ((uintptr_t) i10 + input_offset);
1884 }
1885 const int8_t* i11 = input[11];
1886 assert(i11 != NULL);
1887 if XNN_UNPREDICTABLE(i11 != zero) {
1888 i11 = (const int8_t*) ((uintptr_t) i11 + input_offset);
1889 }
1890 const int8_t* i12 = input[12];
1891 assert(i12 != NULL);
1892 if XNN_UNPREDICTABLE(i12 != zero) {
1893 i12 = (const int8_t*) ((uintptr_t) i12 + input_offset);
1894 }
1895 const int8_t* i13 = input[13];
1896 assert(i13 != NULL);
1897 if XNN_UNPREDICTABLE(i13 != zero) {
1898 i13 = (const int8_t*) ((uintptr_t) i13 + input_offset);
1899 }
1900 const int8_t* i14 = input[14];
1901 assert(i14 != NULL);
1902 if XNN_UNPREDICTABLE(i14 != zero) {
1903 i14 = (const int8_t*) ((uintptr_t) i14 + input_offset);
1904 }
1905 const int8_t* i15 = input[15];
1906 assert(i15 != NULL);
1907 if XNN_UNPREDICTABLE(i15 != zero) {
1908 i15 = (const int8_t*) ((uintptr_t) i15 + input_offset);
1909 }
1910 const int8_t* i16 = input[16];
1911 assert(i16 != NULL);
1912 if XNN_UNPREDICTABLE(i16 != zero) {
1913 i16 = (const int8_t*) ((uintptr_t) i16 + input_offset);
1914 }
1915 const int8_t* i17 = input[17];
1916 assert(i17 != NULL);
1917 if XNN_UNPREDICTABLE(i17 != zero) {
1918 i17 = (const int8_t*) ((uintptr_t) i17 + input_offset);
1919 }
1920 const int8_t* i18 = input[18];
1921 assert(i18 != NULL);
1922 if XNN_UNPREDICTABLE(i18 != zero) {
1923 i18 = (const int8_t*) ((uintptr_t) i18 + input_offset);
1924 }
1925 const int8_t* i19 = input[19];
1926 assert(i19 != NULL);
1927 if XNN_UNPREDICTABLE(i19 != zero) {
1928 i19 = (const int8_t*) ((uintptr_t) i19 + input_offset);
1929 }
1930 const int8_t* i20 = input[20];
1931 assert(i20 != NULL);
1932 if XNN_UNPREDICTABLE(i20 != zero) {
1933 i20 = (const int8_t*) ((uintptr_t) i20 + input_offset);
1934 }
1935 const int8_t* i21 = input[21];
1936 assert(i21 != NULL);
1937 if XNN_UNPREDICTABLE(i21 != zero) {
1938 i21 = (const int8_t*) ((uintptr_t) i21 + input_offset);
1939 }
1940 const int8_t* i22 = input[22];
1941 assert(i22 != NULL);
1942 if XNN_UNPREDICTABLE(i22 != zero) {
1943 i22 = (const int8_t*) ((uintptr_t) i22 + input_offset);
1944 }
1945 const int8_t* i23 = input[23];
1946 assert(i23 != NULL);
1947 if XNN_UNPREDICTABLE(i23 != zero) {
1948 i23 = (const int8_t*) ((uintptr_t) i23 + input_offset);
1949 }
1950 const int8_t* i24 = input[24];
1951 assert(i24 != NULL);
1952 if XNN_UNPREDICTABLE(i24 != zero) {
1953 i24 = (const int8_t*) ((uintptr_t) i24 + input_offset);
1954 }
1955 input = (const int8_t**) ((uintptr_t) input + input_stride);
1956
1957 size_t c = channels;
1958 const void* w = weights;
1959 for (; c >= 32; c -= 32) {
1960 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
1961 __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));
1962
1963
1964 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
1965 const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(int8_t))));
1966 const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
1967 const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(int8_t))));
1968 i0 += 32;
1969
1970 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
1971 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));
1972
1973 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
1974 const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(int8_t))));
1975 const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
1976 const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(int8_t))));
1977 i1 += 32;
1978
1979 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
1980 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));
1981
1982 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
1983 const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(int8_t))));
1984 const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
1985 const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(int8_t))));
1986 i2 += 32;
1987
1988 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
1989 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));
1990
1991 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
1992 const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(int8_t))));
1993 const __m512i vi3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i3 + 16)));
1994 const __m512i vk3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 112 * sizeof(int8_t))));
1995 i3 += 32;
1996
1997 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
1998 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi3xGHIJKLMNOPQRSTUV, vk3xGHIJKLMNOPQRSTUV));
1999
2000 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
2001 const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 128 * sizeof(int8_t))));
2002 const __m512i vi4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i4 + 16)));
2003 const __m512i vk4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 144 * sizeof(int8_t))));
2004 i4 += 32;
2005
2006 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
2007 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi4xGHIJKLMNOPQRSTUV, vk4xGHIJKLMNOPQRSTUV));
2008
2009 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
2010 const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 160 * sizeof(int8_t))));
2011 const __m512i vi5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i5 + 16)));
2012 const __m512i vk5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 176 * sizeof(int8_t))));
2013 i5 += 32;
2014
2015 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
2016 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi5xGHIJKLMNOPQRSTUV, vk5xGHIJKLMNOPQRSTUV));
2017
2018 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
2019 const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 192 * sizeof(int8_t))));
2020 const __m512i vi6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i6 + 16)));
2021 const __m512i vk6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 208 * sizeof(int8_t))));
2022 i6 += 32;
2023
2024 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
2025 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi6xGHIJKLMNOPQRSTUV, vk6xGHIJKLMNOPQRSTUV));
2026
2027 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
2028 const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 224 * sizeof(int8_t))));
2029 const __m512i vi7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i7 + 16)));
2030 const __m512i vk7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 240 * sizeof(int8_t))));
2031 i7 += 32;
2032
2033 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
2034 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi7xGHIJKLMNOPQRSTUV, vk7xGHIJKLMNOPQRSTUV));
2035
2036 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
2037 const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 256 * sizeof(int8_t))));
2038 const __m512i vi8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i8 + 16)));
2039 const __m512i vk8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 272 * sizeof(int8_t))));
2040 i8 += 32;
2041
2042 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
2043 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi8xGHIJKLMNOPQRSTUV, vk8xGHIJKLMNOPQRSTUV));
2044
2045 const __m512i vi9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i9));
2046 const __m512i vk9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(int8_t))));
2047 const __m512i vi9xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i9 + 16)));
2048 const __m512i vk9xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 304 * sizeof(int8_t))));
2049 i9 += 32;
2050
2051 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF));
2052 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi9xGHIJKLMNOPQRSTUV, vk9xGHIJKLMNOPQRSTUV));
2053
2054 const __m512i vi10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i10));
2055 const __m512i vk10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 320 * sizeof(int8_t))));
2056 const __m512i vi10xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i10 + 16)));
2057 const __m512i vk10xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 336 * sizeof(int8_t))));
2058 i10 += 32;
2059
2060 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF));
2061 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi10xGHIJKLMNOPQRSTUV, vk10xGHIJKLMNOPQRSTUV));
2062
2063 const __m512i vi11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i11));
2064 const __m512i vk11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 352 * sizeof(int8_t))));
2065 const __m512i vi11xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i11 + 16)));
2066 const __m512i vk11xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 368 * sizeof(int8_t))));
2067 i11 += 32;
2068
2069 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF));
2070 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi11xGHIJKLMNOPQRSTUV, vk11xGHIJKLMNOPQRSTUV));
2071
2072 const __m512i vi12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i12));
2073 const __m512i vk12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 384 * sizeof(int8_t))));
2074 const __m512i vi12xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i12 + 16)));
2075 const __m512i vk12xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 400 * sizeof(int8_t))));
2076 i12 += 32;
2077
2078 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF));
2079 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi12xGHIJKLMNOPQRSTUV, vk12xGHIJKLMNOPQRSTUV));
2080
2081 const __m512i vi13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i13));
2082 const __m512i vk13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 416 * sizeof(int8_t))));
2083 const __m512i vi13xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i13 + 16)));
2084 const __m512i vk13xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 432 * sizeof(int8_t))));
2085 i13 += 32;
2086
2087 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF));
2088 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi13xGHIJKLMNOPQRSTUV, vk13xGHIJKLMNOPQRSTUV));
2089
2090 const __m512i vi14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i14));
2091 const __m512i vk14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 448 * sizeof(int8_t))));
2092 const __m512i vi14xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i14 + 16)));
2093 const __m512i vk14xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 464 * sizeof(int8_t))));
2094 i14 += 32;
2095
2096 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF));
2097 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi14xGHIJKLMNOPQRSTUV, vk14xGHIJKLMNOPQRSTUV));
2098
2099 const __m512i vi15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i15));
2100 const __m512i vk15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 480 * sizeof(int8_t))));
2101 const __m512i vi15xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i15 + 16)));
2102 const __m512i vk15xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 496 * sizeof(int8_t))));
2103 i15 += 32;
2104
2105 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF));
2106 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi15xGHIJKLMNOPQRSTUV, vk15xGHIJKLMNOPQRSTUV));
2107
2108 const __m512i vi16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i16));
2109 const __m512i vk16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 512 * sizeof(int8_t))));
2110 const __m512i vi16xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i16 + 16)));
2111 const __m512i vk16xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 528 * sizeof(int8_t))));
2112 i16 += 32;
2113
2114 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF));
2115 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi16xGHIJKLMNOPQRSTUV, vk16xGHIJKLMNOPQRSTUV));
2116
2117 const __m512i vi17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i17));
2118 const __m512i vk17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 544 * sizeof(int8_t))));
2119 const __m512i vi17xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i17 + 16)));
2120 const __m512i vk17xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 560 * sizeof(int8_t))));
2121 i17 += 32;
2122
2123 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF));
2124 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi17xGHIJKLMNOPQRSTUV, vk17xGHIJKLMNOPQRSTUV));
2125
2126 const __m512i vi18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i18));
2127 const __m512i vk18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 576 * sizeof(int8_t))));
2128 const __m512i vi18xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i18 + 16)));
2129 const __m512i vk18xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 592 * sizeof(int8_t))));
2130 i18 += 32;
2131
2132 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF));
2133 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi18xGHIJKLMNOPQRSTUV, vk18xGHIJKLMNOPQRSTUV));
2134
2135 const __m512i vi19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i19));
2136 const __m512i vk19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 608 * sizeof(int8_t))));
2137 const __m512i vi19xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i19 + 16)));
2138 const __m512i vk19xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 624 * sizeof(int8_t))));
2139 i19 += 32;
2140
2141 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF));
2142 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi19xGHIJKLMNOPQRSTUV, vk19xGHIJKLMNOPQRSTUV));
2143
2144 const __m512i vi20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i20));
2145 const __m512i vk20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 640 * sizeof(int8_t))));
2146 const __m512i vi20xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i20 + 16)));
2147 const __m512i vk20xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 656 * sizeof(int8_t))));
2148 i20 += 32;
2149
2150 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF));
2151 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi20xGHIJKLMNOPQRSTUV, vk20xGHIJKLMNOPQRSTUV));
2152
2153 const __m512i vi21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i21));
2154 const __m512i vk21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 672 * sizeof(int8_t))));
2155 const __m512i vi21xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i21 + 16)));
2156 const __m512i vk21xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 688 * sizeof(int8_t))));
2157 i21 += 32;
2158
2159 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF));
2160 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi21xGHIJKLMNOPQRSTUV, vk21xGHIJKLMNOPQRSTUV));
2161
2162 const __m512i vi22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i22));
2163 const __m512i vk22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 704 * sizeof(int8_t))));
2164 const __m512i vi22xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i22 + 16)));
2165 const __m512i vk22xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 720 * sizeof(int8_t))));
2166 i22 += 32;
2167
2168 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF));
2169 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi22xGHIJKLMNOPQRSTUV, vk22xGHIJKLMNOPQRSTUV));
2170
2171 const __m512i vi23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i23));
2172 const __m512i vk23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 736 * sizeof(int8_t))));
2173 const __m512i vi23xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i23 + 16)));
2174 const __m512i vk23xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 752 * sizeof(int8_t))));
2175 i23 += 32;
2176
2177 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF));
2178 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi23xGHIJKLMNOPQRSTUV, vk23xGHIJKLMNOPQRSTUV));
2179
2180 const __m512i vi24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i24));
2181 const __m512i vk24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 768 * sizeof(int8_t))));
2182 const __m512i vi24xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i24 + 16)));
2183 const __m512i vk24xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 784 * sizeof(int8_t))));
2184 i24 += 32;
2185
2186 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF));
2187 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi24xGHIJKLMNOPQRSTUV, vk24xGHIJKLMNOPQRSTUV));
2188
2189 w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 800 * sizeof(int8_t));
2190
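      // Editorial note: requantization happens in fp32. The int32 accumulators are
      // converted to float, multiplied by the per-tensor scale, clamped from above by
      // (output_max - output_zero_point), and converted back with round-to-nearest-even;
      // the lower bound is applied later on the packed int8 values.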
2191 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
2192 __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);
2193
2194 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
2195 vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscale);
2196
2197 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
2198 vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);
2199
2200 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
2201 vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);
2202
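      // Editorial note: _mm512_packs_epi32 packs within 128-bit lanes, so the 16-bit
      // results come out in the interleaved order named below; vpermute_mask (loaded in
      // the kernel preamble) restores the natural 0..9A..V channel order after the
      // 8-bit pack.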
2203 __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);
2205
2206 const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
2207 const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
2208 const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packs_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
2209 __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);
2213
2214 vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epi8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);
2216
2217 _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
2219 output += 32;
2220 }
2221 if XNN_UNLIKELY(c != 0) {
2222 // Prepare mask for valid 8-bit elements (depends on c, the count of remaining channels).
2223 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
2224 const int8_t* k = (const int8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
2225 do {
2226 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
2227
2228
2229 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
2230 const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) k));
2231 i0 += 16;
2232
2233 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
2234
2235 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
2236 const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 32)));
2237 i1 += 16;
2238
2239 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
2240
2241 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
2242 const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 64)));
2243 i2 += 16;
2244
2245 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
2246
2247 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
2248 const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 96)));
2249 i3 += 16;
2250
2251 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
2252
2253 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
2254 const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 128)));
2255 i4 += 16;
2256
2257 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
2258
2259 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
2260 const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 160)));
2261 i5 += 16;
2262
2263 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
2264
2265 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
2266 const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 192)));
2267 i6 += 16;
2268
2269 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
2270
2271 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
2272 const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 224)));
2273 i7 += 16;
2274
2275 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
2276
2277 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
2278 const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 256)));
2279 i8 += 16;
2280
2281 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
2282
2283 const __m512i vi9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i9));
2284 const __m512i vk9x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 288)));
2285 i9 += 16;
2286
2287 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF));
2288
2289 const __m512i vi10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i10));
2290 const __m512i vk10x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 320)));
2291 i10 += 16;
2292
2293 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF));
2294
2295 const __m512i vi11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i11));
2296 const __m512i vk11x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 352)));
2297 i11 += 16;
2298
2299 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF));
2300
2301 const __m512i vi12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i12));
2302 const __m512i vk12x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 384)));
2303 i12 += 16;
2304
2305 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF));
2306
2307 const __m512i vi13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i13));
2308 const __m512i vk13x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 416)));
2309 i13 += 16;
2310
2311 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF));
2312
2313 const __m512i vi14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i14));
2314 const __m512i vk14x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 448)));
2315 i14 += 16;
2316
2317 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF));
2318
2319 const __m512i vi15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i15));
2320 const __m512i vk15x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 480)));
2321 i15 += 16;
2322
2323 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF));
2324
2325 const __m512i vi16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i16));
2326 const __m512i vk16x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 512)));
2327 i16 += 16;
2328
2329 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF));
2330
2331 const __m512i vi17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i17));
2332 const __m512i vk17x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 544)));
2333 i17 += 16;
2334
2335 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF));
2336
2337 const __m512i vi18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i18));
2338 const __m512i vk18x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 576)));
2339 i18 += 16;
2340
2341 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF));
2342
2343 const __m512i vi19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i19));
2344 const __m512i vk19x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 608)));
2345 i19 += 16;
2346
2347 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF));
2348
2349 const __m512i vi20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i20));
2350 const __m512i vk20x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 640)));
2351 i20 += 16;
2352
2353 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF));
2354
2355 const __m512i vi21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i21));
2356 const __m512i vk21x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 672)));
2357 i21 += 16;
2358
2359 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF));
2360
2361 const __m512i vi22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i22));
2362 const __m512i vk22x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 704)));
2363 i22 += 16;
2364
2365 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF));
2366
2367 const __m512i vi23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i23));
2368 const __m512i vk23x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 736)));
2369 i23 += 16;
2370
2371 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF));
2372
2373 const __m512i vi24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i24));
2374 const __m512i vk24x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 768)));
2375 i24 += 16;
2376
2377 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF));
2378
2379 k += 16;
2380
2381 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
2382 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
2383 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
2384 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
2385
2386 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
2387
2388 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));
2389
2390 const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
2391 const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
2392 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
2393 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));
2394
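        // Editorial note: full 16-channel groups use an unmasked store; the final group
        // of fewer than 16 channels uses a masked byte store so no bytes past
        // output + c are written.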
2395 if XNN_LIKELY(c >= 16) {
2396 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
2397 output += 16;
2398 c -= 16;
2399 } else {
2400 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
2401 output = (int8_t*) ((uintptr_t) output + c);
2402 c = 0;
2403 }
2404 } while (c != 0);
2405 }
2406
2407 output = (int8_t*) ((uintptr_t) output + output_increment);
2408 } while (--output_width != 0);
2409 }
2410
2411 void xnn_qs8_dwconv_minmax_fp32_ukernel_up32x9__avx512skx_mul32(
2412 size_t channels,
2413 size_t output_width,
2414 const int8_t** input,
2415 const void* weights,
2416 int8_t* output,
2417 size_t input_stride,
2418 size_t output_increment,
2419 size_t input_offset,
2420 const int8_t* zero,
2421 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
2422 {
2423 assert(channels != 0);
2424 assert(output_width != 0);
2425
2426 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
2427 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
2428 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
2429 const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
2430 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);
2431
2432 do {
2433 const int8_t* i0 = input[0];
2434 assert(i0 != NULL);
2435 if XNN_UNPREDICTABLE(i0 != zero) {
2436 i0 = (const int8_t*) ((uintptr_t) i0 + input_offset);
2437 }
2438 const int8_t* i1 = input[1];
2439 assert(i1 != NULL);
2440 if XNN_UNPREDICTABLE(i1 != zero) {
2441 i1 = (const int8_t*) ((uintptr_t) i1 + input_offset);
2442 }
2443 const int8_t* i2 = input[2];
2444 assert(i2 != NULL);
2445 if XNN_UNPREDICTABLE(i2 != zero) {
2446 i2 = (const int8_t*) ((uintptr_t) i2 + input_offset);
2447 }
2448 const int8_t* i3 = input[3];
2449 assert(i3 != NULL);
2450 if XNN_UNPREDICTABLE(i3 != zero) {
2451 i3 = (const int8_t*) ((uintptr_t) i3 + input_offset);
2452 }
2453 const int8_t* i4 = input[4];
2454 assert(i4 != NULL);
2455 if XNN_UNPREDICTABLE(i4 != zero) {
2456 i4 = (const int8_t*) ((uintptr_t) i4 + input_offset);
2457 }
2458 const int8_t* i5 = input[5];
2459 assert(i5 != NULL);
2460 if XNN_UNPREDICTABLE(i5 != zero) {
2461 i5 = (const int8_t*) ((uintptr_t) i5 + input_offset);
2462 }
2463 const int8_t* i6 = input[6];
2464 assert(i6 != NULL);
2465 if XNN_UNPREDICTABLE(i6 != zero) {
2466 i6 = (const int8_t*) ((uintptr_t) i6 + input_offset);
2467 }
2468 const int8_t* i7 = input[7];
2469 assert(i7 != NULL);
2470 if XNN_UNPREDICTABLE(i7 != zero) {
2471 i7 = (const int8_t*) ((uintptr_t) i7 + input_offset);
2472 }
2473 const int8_t* i8 = input[8];
2474 assert(i8 != NULL);
2475 if XNN_UNPREDICTABLE(i8 != zero) {
2476 i8 = (const int8_t*) ((uintptr_t) i8 + input_offset);
2477 }
2478 input = (const int8_t**) ((uintptr_t) input + input_stride);
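    // Editorial note: the 9 row pointers come from an indirection buffer. Entries equal
    // to `zero` point at a shared zero vector used for padding, so input_offset is
    // applied only to real input rows.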
2479
2480 size_t c = channels;
2481 const void* w = weights;
2482 for (; c >= 32; c -= 32) {
2483 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
2484 __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));
2485
2486
2487 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
2488 const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(int8_t))));
2489 const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
2490 const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(int8_t))));
2491 i0 += 32;
2492
2493 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
2494 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));
2495
2496 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
2497 const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(int8_t))));
2498 const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
2499 const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(int8_t))));
2500 i1 += 32;
2501
2502 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
2503 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));
2504
2505 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
2506 const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(int8_t))));
2507 const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
2508 const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(int8_t))));
2509 i2 += 32;
2510
2511 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
2512 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));
2513
2514 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
2515 const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(int8_t))));
2516 const __m512i vi3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i3 + 16)));
2517 const __m512i vk3xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 112 * sizeof(int8_t))));
2518 i3 += 32;
2519
2520 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
2521 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi3xGHIJKLMNOPQRSTUV, vk3xGHIJKLMNOPQRSTUV));
2522
2523 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
2524 const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 128 * sizeof(int8_t))));
2525 const __m512i vi4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i4 + 16)));
2526 const __m512i vk4xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 144 * sizeof(int8_t))));
2527 i4 += 32;
2528
2529 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
2530 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi4xGHIJKLMNOPQRSTUV, vk4xGHIJKLMNOPQRSTUV));
2531
2532 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
2533 const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 160 * sizeof(int8_t))));
2534 const __m512i vi5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i5 + 16)));
2535 const __m512i vk5xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 176 * sizeof(int8_t))));
2536 i5 += 32;
2537
2538 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
2539 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi5xGHIJKLMNOPQRSTUV, vk5xGHIJKLMNOPQRSTUV));
2540
2541 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
2542 const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 192 * sizeof(int8_t))));
2543 const __m512i vi6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i6 + 16)));
2544 const __m512i vk6xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 208 * sizeof(int8_t))));
2545 i6 += 32;
2546
2547 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
2548 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi6xGHIJKLMNOPQRSTUV, vk6xGHIJKLMNOPQRSTUV));
2549
2550 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
2551 const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 224 * sizeof(int8_t))));
2552 const __m512i vi7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i7 + 16)));
2553 const __m512i vk7xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 240 * sizeof(int8_t))));
2554 i7 += 32;
2555
2556 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
2557 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi7xGHIJKLMNOPQRSTUV, vk7xGHIJKLMNOPQRSTUV));
2558
2559 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
2560 const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 256 * sizeof(int8_t))));
2561 const __m512i vi8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (i8 + 16)));
2562 const __m512i vk8xGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 272 * sizeof(int8_t))));
2563 i8 += 32;
2564
2565 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
2566 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi8xGHIJKLMNOPQRSTUV, vk8xGHIJKLMNOPQRSTUV));
2567
2568 w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(int8_t));
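      // Editorial note: per 32-channel tile the weights span 32 int32 biases plus
      // 9 * 32 = 288 int8 taps, matching the increment above.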
2569
2570 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
2571 __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);
2572
2573 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
2574 vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscale);
2575
2576 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
2577 vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);
2578
2579 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
2580 vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);
2581
2582 __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);
2584
2585 const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
2586 const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
2587 const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packs_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
2588 __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);
2592
2593 vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epi8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);
2595
2596 _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
2598 output += 32;
2599 }
2600 if XNN_UNLIKELY(c != 0) {
2601 // Prepare mask for valid 8-bit elements (depends on c, the count of remaining channels).
2602 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
2603 const int8_t* k = (const int8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
2604 do {
2605 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
2606
2607
2608 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i0));
2609 const __m512i vk0x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) k));
2610 i0 += 16;
2611
2612 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
2613
2614 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i1));
2615 const __m512i vk1x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 32)));
2616 i1 += 16;
2617
2618 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
2619
2620 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i2));
2621 const __m512i vk2x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 64)));
2622 i2 += 16;
2623
2624 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
2625
2626 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i3));
2627 const __m512i vk3x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 96)));
2628 i3 += 16;
2629
2630 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
2631
2632 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i4));
2633 const __m512i vk4x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 128)));
2634 i4 += 16;
2635
2636 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
2637
2638 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i5));
2639 const __m512i vk5x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 160)));
2640 i5 += 16;
2641
2642 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
2643
2644 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i6));
2645 const __m512i vk6x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 192)));
2646 i6 += 16;
2647
2648 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
2649
2650 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i7));
2651 const __m512i vk7x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 224)));
2652 i7 += 16;
2653
2654 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
2655
2656 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) i8));
2657 const __m512i vk8x0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (k + 256)));
2658 i8 += 16;
2659
2660 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
2661
2662 k += 16;
2663
2664 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
2665 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
2666 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
2667 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
2668
2669 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
2670
2671 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));
2672
2673 const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
2674 const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
2675 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
2676 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));
2677
2678 if XNN_LIKELY(c >= 16) {
2679 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
2680 output += 16;
2681 c -= 16;
2682 } else {
2683 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
2684 output = (int8_t*) ((uintptr_t) output + c);
2685 c = 0;
2686 }
2687 } while (c != 0);
2688 }
2689
2690 output = (int8_t*) ((uintptr_t) output + output_increment);
2691 } while (--output_width != 0);
2692 }
2693
2694 void xnn_qs8_f32_vcvt_ukernel__avx512skx_x32(
2695 size_t n,
2696 const int8_t* x,
2697 float* y,
2698 const union xnn_qs8_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
2699 {
2700 assert(n != 0);
2701 assert(n % sizeof(int8_t) == 0);
2702 assert(x != NULL);
2703 assert(y != NULL);
2704
2705 const __m512i vminus_zero_point = _mm512_load_si512(params->avx512.minus_zero_point);
2706 const __m512 vscale = _mm512_load_ps(params->avx512.scale);
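  // Editorial note: per element this computes
  //   y[i] = (float) ((int32_t) x[i] + minus_zero_point) * scale,
  // i.e. the usual dequantization scale * (x[i] - zero_point) with the zero point
  // stored pre-negated in the params.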
2707 for (; n >= 32 * sizeof(int8_t); n -= 32 * sizeof(int8_t)) {
2708 __m512i vx0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) x));
2709 __m512i vxGHIJKLMNOPQRSTUV = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) (x + 16)));
2710 x += 32;
2711
2712 vx0123456789ABCDEF = _mm512_add_epi32(vx0123456789ABCDEF, vminus_zero_point);
2713 vxGHIJKLMNOPQRSTUV = _mm512_add_epi32(vxGHIJKLMNOPQRSTUV, vminus_zero_point);
2714
2715 __m512 vy0123456789ABCDEF = _mm512_cvtepi32_ps(vx0123456789ABCDEF);
2716 __m512 vyGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vxGHIJKLMNOPQRSTUV);
2717
2718 vy0123456789ABCDEF = _mm512_mul_ps(vy0123456789ABCDEF, vscale);
2719 vyGHIJKLMNOPQRSTUV = _mm512_mul_ps(vyGHIJKLMNOPQRSTUV, vscale);
2720
2721 _mm512_storeu_ps(y, vy0123456789ABCDEF);
2722 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
2723 y += 32;
2724 }
2725 for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
2726 __m512i vx = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) x));
2727 vx = _mm512_add_epi32(vx, vminus_zero_point);
2728 x += 16;
2729
2730 __m512 vy = _mm512_cvtepi32_ps(vx);
2731 vy = _mm512_mul_ps(vy, vscale);
2732
2733 _mm512_storeu_ps(y, vy);
2734 y += 16;
2735 }
2736 if XNN_UNLIKELY(n != 0) {
2737 assert(n >= 1 * sizeof(int8_t));
2738 assert(n <= 15 * sizeof(int8_t));
2739
2740 // Prepare mask for valid elements (depends on n).
2741 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
2742
2743 __m512i vx = _mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(vmask, x));
2744 vx = _mm512_add_epi32(vx, vminus_zero_point);
2745
2746 __m512 vy = _mm512_cvtepi32_ps(vx);
2747 vy = _mm512_mul_ps(vy, vscale);
2748
2749 _mm512_mask_storeu_ps(y, vmask, vy);
2750 }
2751 }
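// Editorial sketch (not part of XNNPack): a scalar reference for the vector conversion
// above, useful for testing. The quantization parameters are passed as plain values
// here rather than through the packed params struct.
static inline void qs8_f32_cvt_scalar_ref(
    size_t n,                // number of int8 elements
    const int8_t* x,
    float* y,
    int32_t zero_point,      // assumed quantization zero point
    float scale)             // assumed quantization scale
{
  for (size_t i = 0; i < n; i++) {
    // Subtract the zero point in int32, then scale in float.
    y[i] = (float) ((int32_t) x[i] - zero_point) * scale;
  }
}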
2752
2753 void xnn_qs8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx(
2754 size_t mr,
2755 size_t nc,
2756 size_t kc,
2757 const int8_t* restrict a,
2758 size_t a_stride,
2759 const void* restrict w,
2760 int8_t* restrict c,
2761 size_t cm_stride,
2762 size_t cn_stride,
2763 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
2764 {
2765 assert(mr != 0);
2766 assert(mr <= 1);
2767 assert(nc != 0);
2768 assert(kc != 0);
2769 assert(kc % sizeof(int8_t) == 0);
2770 assert(a != NULL);
2771 assert(w != NULL);
2772 assert(c != NULL);
2773
2774 kc = round_up_po2(kc, 8);
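  // Editorial note: kc is rounded up to the 8-element groups the packed weights use;
  // reading a0 slightly past the true kc is tolerated, which is what the XNN_OOB_READS
  // annotation documents.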
2775 const int8_t* a0 = a;
2776 int8_t* c0 = c;
2777
2778 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
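  // Editorial note: mask 0x1111 makes each expand-load drop one int32 bias into lane 0
  // of each 128-bit group, so after _mm512_madd_epi16 every group holds bias plus the
  // partial dot products for a single output column.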
2779 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
2780 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
2781 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point);
2782 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min);
2783 do {
2784 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
2785 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
2786 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
2787 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
2788 w = (const void*) ((const int32_t*) w + 16);
2789
2790 size_t k = 0;
2791 while (k < kc) {
2792 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
2793 a0 += 8;
2794
2795 const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
2796
2797 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
2798 const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
2799
2800 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
2801 const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
2802
2803 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
2804 const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
2805
2806 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
2807
2808 w = (const void*) ((const int8_t*) w + 128);
2809 k += 8 * sizeof(int8_t);
2810 }
2811
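    // Editorial note: the unpack/add tree below reduces the four partial sums per
    // column; the totals emerge in the scrambled lane order
    // 0 8 4 C 1 9 5 D 2 A 6 E 3 B 7 F, which the byte shuffle further down restores
    // to ascending order.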
2812 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
2813 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
2814
2815 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
2816
2817 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
2818
2819 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
2820
2821 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
2822
2823 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
2824
2825 const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);
2826
2827 const __m128i vout0x084C2A6E195D3B7F = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
2828 __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
2829 vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min);
2830
2831 if (nc >= 16) {
2832 _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
2833
2834 a0 = (const int8_t*) ((uintptr_t) a0 - k);
2835
2836 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
2837
2838 nc -= 16;
2839 } else {
2840 // Prepare mask for valid 8-bit elements (depends on nc).
2841 const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
2842
2843 _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);
2844
2845 nc = 0;
2846 }
2847 } while (nc != 0);
2848 }
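// Editorial note (scalar model of the requantization above, for reference):
//   acc[n]  = bias[n] + sum_{k < kc} (int32_t) a0[k] * (int32_t) b[k][n]
//   out[n] ~= clamp(rne(acc[n] * scale) + output_zero_point, output_min, output_max)
// with the upper clamp applied in float via output_max_less_zero_point and the lower
// clamp applied after packing to int8.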
2849
2850 void xnn_qs8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx(
2851 size_t mr,
2852 size_t nc,
2853 size_t kc,
2854 const int8_t* restrict a,
2855 size_t a_stride,
2856 const void* restrict w,
2857 int8_t* restrict c,
2858 size_t cm_stride,
2859 size_t cn_stride,
2860 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
2861 {
2862 assert(mr != 0);
2863 assert(mr <= 4);
2864 assert(nc != 0);
2865 assert(kc != 0);
2866 assert(kc % sizeof(int8_t) == 0);
2867 assert(a != NULL);
2868 assert(w != NULL);
2869 assert(c != NULL);
2870
2871 kc = round_up_po2(kc, 8);
2872 const int8_t* a0 = a;
2873 int8_t* c0 = c;
2874 const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
2875 int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
2876 if XNN_UNPREDICTABLE(mr < 2) {
2877 a1 = a0;
2878 c1 = c0;
2879 }
2880 const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
2881 int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
2882 if XNN_UNPREDICTABLE(mr <= 2) {
2883 a2 = a1;
2884 c2 = c1;
2885 }
2886 const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride);
2887 int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
2888 if XNN_UNPREDICTABLE(mr != 4) {
2889 a3 = a2;
2890 c3 = c2;
2891 }
2892
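  // Mask 0x1111 expand-loads one int32 bias into the low dword of each 128-bit
  // lane, matching the per-lane channel layout produced by the MADD loop below.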
2893 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
2894 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
2895 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
2896 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
2897 const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
2898 do {
2899 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
2900 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
2901 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
2902 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
2903 __m512i vacc1x0123 = vacc0x0123;
2904 __m512i vacc1x4567 = vacc0x4567;
2905 __m512i vacc1x89AB = vacc0x89AB;
2906 __m512i vacc1xCDEF = vacc0xCDEF;
2907 __m512i vacc2x0123 = vacc0x0123;
2908 __m512i vacc2x4567 = vacc0x4567;
2909 __m512i vacc2x89AB = vacc0x89AB;
2910 __m512i vacc2xCDEF = vacc0xCDEF;
2911 __m512i vacc3x0123 = vacc0x0123;
2912 __m512i vacc3x4567 = vacc0x4567;
2913 __m512i vacc3x89AB = vacc0x89AB;
2914 __m512i vacc3xCDEF = vacc0xCDEF;
2915 w = (const void*) ((const int32_t*) w + 16);
2916
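    // Each iteration consumes 8 int8 inputs per row: sign-extend them to int16,
    // broadcast to all four 128-bit lanes, and multiply-accumulate against four
    // output channels' worth of weights per 512-bit register.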
2917 size_t k = 0;
2918 while (k < kc) {
2919 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
2920 a0 += 8;
2921 const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
2922 a1 += 8;
2923 const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
2924 a2 += 8;
2925 const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
2926 a3 += 8;
2927
2928 const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
2929
2930 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
2931 vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
2932 vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
2933 vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
2934 const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
2935
2936 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
2937 vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
2938 vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
2939 vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
2940 const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
2941
2942 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
2943 vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
2944 vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
2945 vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
2946 const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
2947
2948 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
2949 vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
2950 vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
2951 vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
2952
2953 w = (const void*) ((const int8_t*) w + 128);
2954 k += 8 * sizeof(int8_t);
2955 }
2956
2957 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
2958 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
2959 const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
2960 const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
2961 const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
2962 const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
2963 const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
2964 const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));
2965
2966 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
2967 __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
2968 __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
2969 __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));
2970
2971 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
2972 __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
2973 __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
2974 __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
2975
2976 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
2977 vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
2978 vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale);
2979 vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale);
2980
2981 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
2982 vscaled1x084C195D2A6E3B7F = _mm512_min_ps(vscaled1x084C195D2A6E3B7F, voutput_max_less_zero_point);
2983 vscaled2x084C195D2A6E3B7F = _mm512_min_ps(vscaled2x084C195D2A6E3B7F, voutput_max_less_zero_point);
2984 vscaled3x084C195D2A6E3B7F = _mm512_min_ps(vscaled3x084C195D2A6E3B7F, voutput_max_less_zero_point);
2985
2986 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
2987 vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
2988 vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
2989 vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);
2990
2991 const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
2992 const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);
2993
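    // packs interleaves the two rows within each 128-bit lane, so a dword
    // permute plus an in-lane byte shuffle is needed to restore rows 0-3, each
    // in channel order 0..F, before the final clamp.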
2994 __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packs_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
2995 vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
2996 __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
2997 vout0123x0123456789ABCDEF = _mm512_max_epi8(vout0123x0123456789ABCDEF, voutput_min);
2998
2999 if (nc >= 16) {
3000 _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));
3001 _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
3002 _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
3003 _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
3004
3005 a0 = (const int8_t*) ((uintptr_t) a0 - k);
3006 a1 = (const int8_t*) ((uintptr_t) a1 - k);
3007 a2 = (const int8_t*) ((uintptr_t) a2 - k);
3008 a3 = (const int8_t*) ((uintptr_t) a3 - k);
3009
3010 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
3011 c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
3012 c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
3013 c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
3014
3015 nc -= 16;
3016 } else {
3017 // Prepare mask for valid 8-bit elements (depends on nc).
3018 __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
3019
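      // Row r is stored through (cr - 16*r) with the mask shifted left by 16*r:
      // only bytes 16*r..16*r+nc-1 of the 64-byte register reach memory, i.e.
      // exactly row r's nc valid outputs, with no 128-bit lane extraction.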
3020 _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
3021 vmask = _kshiftli_mask64(vmask, 16);
3022 _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
3023 vmask = _kshiftli_mask64(vmask, 16);
3024 _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
3025 vmask = _kshiftli_mask64(vmask, 16);
3026 _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
3027
3028 nc = 0;
3029 }
3030 } while (nc != 0);
3031 }
3032
3033 void xnn_qs8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx(
3034 size_t mr,
3035 size_t nc,
3036 size_t kc,
3037 size_t ks,
3038 const int8_t** restrict a,
3039 const void* restrict w,
3040 int8_t* restrict c,
3041 size_t cm_stride,
3042 size_t cn_stride,
3043 size_t a_offset,
3044 const int8_t* zero,
3045 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3046 {
3047 assert(mr != 0);
3048 assert(mr <= 1);
3049 assert(nc != 0);
3050 assert(kc != 0);
3051 assert(kc % sizeof(int8_t) == 0);
3052 assert(a != NULL);
3053 assert(w != NULL);
3054 assert(c != NULL);
3055
3056 kc = round_up_po2(kc, 8);
3057 int8_t* c0 = c;
3058
3059 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
3060 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
3061 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
3062 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point);
3063 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min);
3064 do {
3065 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
3066 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
3067 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
3068 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
3069 w = (const void*) ((const int32_t*) w + 16);
3070
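    // Indirection loop: a[] supplies ks input-row pointers per output pixel.
    // Pointers equal to `zero` reference the padding buffer and are used as-is;
    // all other pointers are displaced by a_offset.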
3071 size_t p = ks;
3072 do {
3073 const int8_t* restrict a0 = a[0];
3074 if XNN_UNPREDICTABLE(a0 != zero) {
3075 a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
3076 }
3077 a += 1;
3078
3079 size_t k = 0;
3080 while (k < kc) {
3081 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
3082 a0 += 8;
3083
3084 const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
3085
3086 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
3087 const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
3088
3089 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
3090 const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
3091
3092 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
3093 const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
3094
3095 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
3096
3097 w = (const void*) ((const int8_t*) w + 128);
3098 k += 8 * sizeof(int8_t);
3099 }
3100 p -= 1 * sizeof(void*);
3101 } while (p != 0);
3102
3103 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
3104 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
3105
3106 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
3107
3108 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
3109
3110 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
3111
3112 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
3113
3114 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
3115
3116 const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);
3117
3118 const __m128i vout0x084C2A6E195D3B7F = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
3119 __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
3120 vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min);
3121
3122 if (nc >= 16) {
3123 _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
3124
3125 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
3126
3127       a = (const int8_t** restrict) ((uintptr_t) a - ks);
3128
3129 nc -= 16;
3130 } else {
3131 // Prepare mask for valid 8-bit elements (depends on nc).
3132 const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
3133
3134 _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);
3135
3136 nc = 0;
3137 }
3138 } while (nc != 0);
3139 }
3140
3141 void xnn_qs8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx(
3142 size_t mr,
3143 size_t nc,
3144 size_t kc,
3145 size_t ks,
3146 const int8_t** restrict a,
3147 const void* restrict w,
3148 int8_t* restrict c,
3149 size_t cm_stride,
3150 size_t cn_stride,
3151 size_t a_offset,
3152 const int8_t* zero,
3153 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3154 {
3155 assert(mr != 0);
3156 assert(mr <= 4);
3157 assert(nc != 0);
3158 assert(kc != 0);
3159 assert(kc % sizeof(int8_t) == 0);
3160 assert(a != NULL);
3161 assert(w != NULL);
3162 assert(c != NULL);
3163
3164 kc = round_up_po2(kc, 8);
3165 int8_t* c0 = c;
3166 int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride);
3167 if XNN_UNPREDICTABLE(mr < 2) {
3168 c1 = c0;
3169 }
3170 int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride);
3171 if XNN_UNPREDICTABLE(mr <= 2) {
3172 c2 = c1;
3173 }
3174 int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride);
3175 if XNN_UNPREDICTABLE(mr != 4) {
3176 c3 = c2;
3177 }
3178
3179 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
3180 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
3181 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
3182 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
3183 const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
3184 do {
3185 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
3186 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
3187 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
3188 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
3189 __m512i vacc1x0123 = vacc0x0123;
3190 __m512i vacc1x4567 = vacc0x4567;
3191 __m512i vacc1x89AB = vacc0x89AB;
3192 __m512i vacc1xCDEF = vacc0xCDEF;
3193 __m512i vacc2x0123 = vacc0x0123;
3194 __m512i vacc2x4567 = vacc0x4567;
3195 __m512i vacc2x89AB = vacc0x89AB;
3196 __m512i vacc2xCDEF = vacc0xCDEF;
3197 __m512i vacc3x0123 = vacc0x0123;
3198 __m512i vacc3x4567 = vacc0x4567;
3199 __m512i vacc3x89AB = vacc0x89AB;
3200 __m512i vacc3xCDEF = vacc0xCDEF;
3201 w = (const void*) ((const int32_t*) w + 16);
3202
3203 size_t p = ks;
3204 do {
3205 const int8_t* restrict a0 = a[0];
3206 if XNN_UNPREDICTABLE(a0 != zero) {
3207 a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
3208 }
3209 const int8_t* restrict a1 = a[1];
3210 if XNN_UNPREDICTABLE(a1 != zero) {
3211 a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
3212 }
3213 const int8_t* restrict a2 = a[2];
3214 if XNN_UNPREDICTABLE(a2 != zero) {
3215 a2 = (const int8_t*) ((uintptr_t) a2 + a_offset);
3216 }
3217 const int8_t* restrict a3 = a[3];
3218 if XNN_UNPREDICTABLE(a3 != zero) {
3219 a3 = (const int8_t*) ((uintptr_t) a3 + a_offset);
3220 }
3221 a += 4;
3222
3223 size_t k = 0;
3224 while (k < kc) {
3225 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
3226 a0 += 8;
3227 const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
3228 a1 += 8;
3229 const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
3230 a2 += 8;
3231 const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
3232 a3 += 8;
3233
3234 const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
3235
3236 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
3237 vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
3238 vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
3239 vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
3240 const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
3241
3242 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
3243 vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
3244 vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
3245 vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
3246 const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
3247
3248 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
3249 vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
3250 vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
3251 vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
3252 const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
3253
3254 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
3255 vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
3256 vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
3257 vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
3258
3259 w = (const void*) ((const int8_t*) w + 128);
3260 k += 8 * sizeof(int8_t);
3261 }
3262 p -= 4 * sizeof(void*);
3263 } while (p != 0);
3264
3265 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
3266 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
3267 const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
3268 const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
3269 const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
3270 const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
3271 const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
3272 const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));
3273
3274 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
3275 __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
3276 __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
3277 __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));
3278
3279 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
3280 __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
3281 __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
3282 __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
3283
3284 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
3285 vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
3286 vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale);
3287 vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale);
3288
3289 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
3290 vscaled1x084C195D2A6E3B7F = _mm512_min_ps(vscaled1x084C195D2A6E3B7F, voutput_max_less_zero_point);
3291 vscaled2x084C195D2A6E3B7F = _mm512_min_ps(vscaled2x084C195D2A6E3B7F, voutput_max_less_zero_point);
3292 vscaled3x084C195D2A6E3B7F = _mm512_min_ps(vscaled3x084C195D2A6E3B7F, voutput_max_less_zero_point);
3293
3294 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
3295 vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
3296 vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
3297 vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);
3298
3299 const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
3300 const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);
3301
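    // As in the GEMM kernel, undo the lane interleaving of packs with a dword
    // permute plus an in-lane byte shuffle before clamping.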
3302 __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packs_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
3303 vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
3304 __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
3305 vout0123x0123456789ABCDEF = _mm512_max_epi8(vout0123x0123456789ABCDEF, voutput_min);
3306
3307 if (nc >= 16) {
3308 _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
3309 _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
3310 _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
3311 _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));
3312
3313 c3 = (int8_t*) ((uintptr_t) c3 + cn_stride);
3314 c2 = (int8_t*) ((uintptr_t) c2 + cn_stride);
3315 c1 = (int8_t*) ((uintptr_t) c1 + cn_stride);
3316 c0 = (int8_t*) ((uintptr_t) c0 + cn_stride);
3317
3318       a = (const int8_t** restrict) ((uintptr_t) a - ks);
3319
3320 nc -= 16;
3321 } else {
3322 // Prepare mask for valid 8-bit elements (depends on nc).
3323 __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT64_C(1) << (nc + 48)) - (UINT64_C(1) << 48)));
3324
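      // Bits 48..48+nc-1 of the mask are set, so rows are written from c3 down
      // to c0, right-shifting the mask by 16 per row; each displaced base
      // pointer lines the row's 16-byte field up with its nc valid outputs.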
3325 _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
3326 vmask = _kshiftri_mask64(vmask, 16);
3327 _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
3328 vmask = _kshiftri_mask64(vmask, 16);
3329 _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
3330 vmask = _kshiftri_mask64(vmask, 16);
3331 _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
3332
3333 nc = 0;
3334 }
3335 } while (nc != 0);
3336 }
3337
3338 void xnn_qs8_vadd_minmax_ukernel__avx512skx_mul32_ld128_x16(
3339 size_t n,
3340 const int8_t* input_a,
3341 const int8_t* input_b,
3342 int8_t* output,
3343 const union xnn_qs8_addsub_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
3344 {
3345 const __m512i vbias = _mm512_load_si512(params->avx512.bias);
3346 const __m512i va_multiplier = _mm512_load_si512(params->avx512.a_multiplier);
3347 const __m512i vb_multiplier = _mm512_load_si512(params->avx512.b_multiplier);
3348 const __m128i vshift = _mm_loadu_si32(params->avx512.shift);
3349 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx512.output_zero_point);
3350 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx512.output_min);
3351 const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx512.output_max);
3352
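  // Requantized add: acc = bias + a * a_multiplier + b * b_multiplier, followed
  // by an arithmetic shift right to the output scale, a saturating pack with
  // zero-point addition, and the final min/max clamp.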
3353 for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
3354 const __m512i va0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) input_a));
3355 const __m512i vb0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) input_b));
3356 input_a += 16;
3357 input_b += 16;
3358
3359 __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));
3360
3361 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vb0123456789ABCDEF, vb_multiplier));
3362
3363 vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);
3364
3365 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
3366
3367 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
3368
3369 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
3370
3371 vout0123456789ABCDEF = _mm_min_epi8(vout0123456789ABCDEF, voutput_max);
3372
3373 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
3374 output += 16;
3375 }
3376 if XNN_UNLIKELY(n != 0) {
3377 {
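      // Remainder of 1..15 elements: the same mask covers both byte loads and
      // the byte store.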
3378 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << n) - UINT32_C(1)));
3379 const __m512i va0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(vmask, input_a));
3380 const __m512i vb0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(vmask, input_b));
3381
3382 __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));
3383
3384 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vb0123456789ABCDEF, vb_multiplier));
3385
3386 vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);
3387
3388 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
3389 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
3390 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
3391 vout0123456789ABCDEF = _mm_min_epi8(vout0123456789ABCDEF, voutput_max);
3392
3393 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
3394 }
3395 }
3396 }
3397
3398 void xnn_qs8_vaddc_minmax_ukernel__avx512skx_mul32_ld128_x16(
3399 size_t n,
3400 const int8_t* input_a,
3401 const int8_t* input_b,
3402 int8_t* output,
3403 const union xnn_qs8_addsub_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
3404 {
3405 const __m512i va_multiplier = _mm512_load_si512(params->avx512.a_multiplier);
3406 const __m128i vshift = _mm_loadu_si32(params->avx512.shift);
3407 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx512.output_zero_point);
3408 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx512.output_min);
3409 const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx512.output_max);
3410
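  // The second operand is a single scalar, so its contribution
  // (b_multiplier * input_b) is folded into the bias once, outside the loop.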
3411 const __m512i vbias = _mm512_add_epi32(
3412 _mm512_broadcastd_epi32(_mm_cvtsi32_si128(params->avx512.b_multiplier[0] * (int32_t) *input_b)),
3413 _mm512_load_si512(params->avx512.bias));
3414 for (; n >= 16 * sizeof(int8_t); n -= 16 * sizeof(int8_t)) {
3415 const __m512i va0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*) input_a));
3416 input_a += 16;
3417
3418 __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));
3419
3420 vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);
3421
3422 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
3423
3424 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
3425
3426 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
3427
3428 vout0123456789ABCDEF = _mm_min_epi8(vout0123456789ABCDEF, voutput_max);
3429
3430 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
3431 output += 16;
3432 }
3433 if XNN_UNLIKELY(n != 0) {
3434 {
3435 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << n) - UINT32_C(1)));
3436 const __m512i va0123456789ABCDEF = _mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(vmask, input_a));
3437
3438 __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));
3439
3440 vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);
3441
3442 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
3443 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
3444 vout0123456789ABCDEF = _mm_max_epi8(vout0123456789ABCDEF, voutput_min);
3445 vout0123456789ABCDEF = _mm_min_epi8(vout0123456789ABCDEF, voutput_max);
3446
3447 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
3448 }
3449 }
3450 }
3451
3452 void xnn_qu8_dwconv_minmax_fp32_ukernel_up32x25__avx512skx_mul32(
3453 size_t channels,
3454 size_t output_width,
3455 const uint8_t** input,
3456 const void* weights,
3457 uint8_t* output,
3458 size_t input_stride,
3459 size_t output_increment,
3460 size_t input_offset,
3461 const uint8_t* zero,
3462 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
3463 {
3464 assert(channels != 0);
3465 assert(output_width != 0);
3466
3467 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
3468 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
3469 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
3470 const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
3471 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);
3472
3473 const __m512i vk_zero_point = _mm512_cvtepu16_epi32(_mm256_load_si256((const __m256i*) params->fp32_avx512.kernel_zero_point));
3474 do {
3475 const uint8_t* i0 = input[0];
3476 assert(i0 != NULL);
3477 if XNN_UNPREDICTABLE(i0 != zero) {
3478 i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
3479 }
3480 const uint8_t* i1 = input[1];
3481 assert(i1 != NULL);
3482 if XNN_UNPREDICTABLE(i1 != zero) {
3483 i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
3484 }
3485 const uint8_t* i2 = input[2];
3486 assert(i2 != NULL);
3487 if XNN_UNPREDICTABLE(i2 != zero) {
3488 i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
3489 }
3490 const uint8_t* i3 = input[3];
3491 assert(i3 != NULL);
3492 if XNN_UNPREDICTABLE(i3 != zero) {
3493 i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
3494 }
3495 const uint8_t* i4 = input[4];
3496 assert(i4 != NULL);
3497 if XNN_UNPREDICTABLE(i4 != zero) {
3498 i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
3499 }
3500 const uint8_t* i5 = input[5];
3501 assert(i5 != NULL);
3502 if XNN_UNPREDICTABLE(i5 != zero) {
3503 i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
3504 }
3505 const uint8_t* i6 = input[6];
3506 assert(i6 != NULL);
3507 if XNN_UNPREDICTABLE(i6 != zero) {
3508 i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
3509 }
3510 const uint8_t* i7 = input[7];
3511 assert(i7 != NULL);
3512 if XNN_UNPREDICTABLE(i7 != zero) {
3513 i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
3514 }
3515 const uint8_t* i8 = input[8];
3516 assert(i8 != NULL);
3517 if XNN_UNPREDICTABLE(i8 != zero) {
3518 i8 = (const uint8_t*) ((uintptr_t) i8 + input_offset);
3519 }
3520 const uint8_t* i9 = input[9];
3521 assert(i9 != NULL);
3522 if XNN_UNPREDICTABLE(i9 != zero) {
3523 i9 = (const uint8_t*) ((uintptr_t) i9 + input_offset);
3524 }
3525 const uint8_t* i10 = input[10];
3526 assert(i10 != NULL);
3527 if XNN_UNPREDICTABLE(i10 != zero) {
3528 i10 = (const uint8_t*) ((uintptr_t) i10 + input_offset);
3529 }
3530 const uint8_t* i11 = input[11];
3531 assert(i11 != NULL);
3532 if XNN_UNPREDICTABLE(i11 != zero) {
3533 i11 = (const uint8_t*) ((uintptr_t) i11 + input_offset);
3534 }
3535 const uint8_t* i12 = input[12];
3536 assert(i12 != NULL);
3537 if XNN_UNPREDICTABLE(i12 != zero) {
3538 i12 = (const uint8_t*) ((uintptr_t) i12 + input_offset);
3539 }
3540 const uint8_t* i13 = input[13];
3541 assert(i13 != NULL);
3542 if XNN_UNPREDICTABLE(i13 != zero) {
3543 i13 = (const uint8_t*) ((uintptr_t) i13 + input_offset);
3544 }
3545 const uint8_t* i14 = input[14];
3546 assert(i14 != NULL);
3547 if XNN_UNPREDICTABLE(i14 != zero) {
3548 i14 = (const uint8_t*) ((uintptr_t) i14 + input_offset);
3549 }
3550 const uint8_t* i15 = input[15];
3551 assert(i15 != NULL);
3552 if XNN_UNPREDICTABLE(i15 != zero) {
3553 i15 = (const uint8_t*) ((uintptr_t) i15 + input_offset);
3554 }
3555 const uint8_t* i16 = input[16];
3556 assert(i16 != NULL);
3557 if XNN_UNPREDICTABLE(i16 != zero) {
3558 i16 = (const uint8_t*) ((uintptr_t) i16 + input_offset);
3559 }
3560 const uint8_t* i17 = input[17];
3561 assert(i17 != NULL);
3562 if XNN_UNPREDICTABLE(i17 != zero) {
3563 i17 = (const uint8_t*) ((uintptr_t) i17 + input_offset);
3564 }
3565 const uint8_t* i18 = input[18];
3566 assert(i18 != NULL);
3567 if XNN_UNPREDICTABLE(i18 != zero) {
3568 i18 = (const uint8_t*) ((uintptr_t) i18 + input_offset);
3569 }
3570 const uint8_t* i19 = input[19];
3571 assert(i19 != NULL);
3572 if XNN_UNPREDICTABLE(i19 != zero) {
3573 i19 = (const uint8_t*) ((uintptr_t) i19 + input_offset);
3574 }
3575 const uint8_t* i20 = input[20];
3576 assert(i20 != NULL);
3577 if XNN_UNPREDICTABLE(i20 != zero) {
3578 i20 = (const uint8_t*) ((uintptr_t) i20 + input_offset);
3579 }
3580 const uint8_t* i21 = input[21];
3581 assert(i21 != NULL);
3582 if XNN_UNPREDICTABLE(i21 != zero) {
3583 i21 = (const uint8_t*) ((uintptr_t) i21 + input_offset);
3584 }
3585 const uint8_t* i22 = input[22];
3586 assert(i22 != NULL);
3587 if XNN_UNPREDICTABLE(i22 != zero) {
3588 i22 = (const uint8_t*) ((uintptr_t) i22 + input_offset);
3589 }
3590 const uint8_t* i23 = input[23];
3591 assert(i23 != NULL);
3592 if XNN_UNPREDICTABLE(i23 != zero) {
3593 i23 = (const uint8_t*) ((uintptr_t) i23 + input_offset);
3594 }
3595 const uint8_t* i24 = input[24];
3596 assert(i24 != NULL);
3597 if XNN_UNPREDICTABLE(i24 != zero) {
3598 i24 = (const uint8_t*) ((uintptr_t) i24 + input_offset);
3599 }
3600 input = (const uint8_t**) ((uintptr_t) input + input_stride);
3601
3602 size_t c = channels;
3603 const void* w = weights;
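    // Packed weight layout per 32-channel group: 32 int32 biases, then 25 rows
    // of 32 uint8 kernel taps. Each tap is zero-extended to int32 and has the
    // kernel zero point subtracted before the 32-bit multiply-accumulate.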
3604 for (; c >= 32; c -= 32) {
3605 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
3606 __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));
3607
3608
3609 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i0));
3610 const __m512i vk0x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(uint8_t)))), vk_zero_point);
3611 const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
3612 const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(uint8_t)))), vk_zero_point);
3613 i0 += 32;
3614
3615 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
3616 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));
3617
3618 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i1));
3619 const __m512i vk1x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(uint8_t)))), vk_zero_point);
3620 const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
3621 const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(uint8_t)))), vk_zero_point);
3622 i1 += 32;
3623
3624 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
3625 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));
3626
3627 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i2));
3628 const __m512i vk2x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(uint8_t)))), vk_zero_point);
3629 const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
3630 const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(uint8_t)))), vk_zero_point);
3631 i2 += 32;
3632
3633 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
3634 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));
3635
3636 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i3));
3637 const __m512i vk3x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(uint8_t)))), vk_zero_point);
3638 const __m512i vi3xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i3 + 16)));
3639 const __m512i vk3xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 112 * sizeof(uint8_t)))), vk_zero_point);
3640 i3 += 32;
3641
3642 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
3643 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi3xGHIJKLMNOPQRSTUV, vk3xGHIJKLMNOPQRSTUV));
3644
3645 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i4));
3646 const __m512i vk4x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 128 * sizeof(uint8_t)))), vk_zero_point);
3647 const __m512i vi4xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i4 + 16)));
3648 const __m512i vk4xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 144 * sizeof(uint8_t)))), vk_zero_point);
3649 i4 += 32;
3650
3651 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
3652 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi4xGHIJKLMNOPQRSTUV, vk4xGHIJKLMNOPQRSTUV));
3653
3654 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i5));
3655 const __m512i vk5x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 160 * sizeof(uint8_t)))), vk_zero_point);
3656 const __m512i vi5xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i5 + 16)));
3657 const __m512i vk5xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 176 * sizeof(uint8_t)))), vk_zero_point);
3658 i5 += 32;
3659
3660 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
3661 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi5xGHIJKLMNOPQRSTUV, vk5xGHIJKLMNOPQRSTUV));
3662
3663 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i6));
3664 const __m512i vk6x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 192 * sizeof(uint8_t)))), vk_zero_point);
3665 const __m512i vi6xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i6 + 16)));
3666 const __m512i vk6xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 208 * sizeof(uint8_t)))), vk_zero_point);
3667 i6 += 32;
3668
3669 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
3670 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi6xGHIJKLMNOPQRSTUV, vk6xGHIJKLMNOPQRSTUV));
3671
3672 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i7));
3673 const __m512i vk7x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 224 * sizeof(uint8_t)))), vk_zero_point);
3674 const __m512i vi7xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i7 + 16)));
3675 const __m512i vk7xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 240 * sizeof(uint8_t)))), vk_zero_point);
3676 i7 += 32;
3677
3678 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
3679 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi7xGHIJKLMNOPQRSTUV, vk7xGHIJKLMNOPQRSTUV));
3680
3681 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i8));
3682 const __m512i vk8x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 256 * sizeof(uint8_t)))), vk_zero_point);
3683 const __m512i vi8xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i8 + 16)));
3684 const __m512i vk8xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 272 * sizeof(uint8_t)))), vk_zero_point);
3685 i8 += 32;
3686
3687 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
3688 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi8xGHIJKLMNOPQRSTUV, vk8xGHIJKLMNOPQRSTUV));
3689
3690 const __m512i vi9x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i9));
3691 const __m512i vk9x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(uint8_t)))), vk_zero_point);
3692 const __m512i vi9xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i9 + 16)));
3693 const __m512i vk9xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 304 * sizeof(uint8_t)))), vk_zero_point);
3694 i9 += 32;
3695
3696 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF));
3697 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi9xGHIJKLMNOPQRSTUV, vk9xGHIJKLMNOPQRSTUV));
3698
3699 const __m512i vi10x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i10));
3700 const __m512i vk10x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 320 * sizeof(uint8_t)))), vk_zero_point);
3701 const __m512i vi10xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i10 + 16)));
3702 const __m512i vk10xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 336 * sizeof(uint8_t)))), vk_zero_point);
3703 i10 += 32;
3704
3705 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF));
3706 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi10xGHIJKLMNOPQRSTUV, vk10xGHIJKLMNOPQRSTUV));
3707
3708 const __m512i vi11x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i11));
3709 const __m512i vk11x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 352 * sizeof(uint8_t)))), vk_zero_point);
3710 const __m512i vi11xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i11 + 16)));
3711 const __m512i vk11xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 368 * sizeof(uint8_t)))), vk_zero_point);
3712 i11 += 32;
3713
3714 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF));
3715 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi11xGHIJKLMNOPQRSTUV, vk11xGHIJKLMNOPQRSTUV));
3716
3717 const __m512i vi12x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i12));
3718 const __m512i vk12x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 384 * sizeof(uint8_t)))), vk_zero_point);
3719 const __m512i vi12xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i12 + 16)));
3720 const __m512i vk12xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 400 * sizeof(uint8_t)))), vk_zero_point);
3721 i12 += 32;
3722
3723 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF));
3724 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi12xGHIJKLMNOPQRSTUV, vk12xGHIJKLMNOPQRSTUV));
3725
3726 const __m512i vi13x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i13));
3727 const __m512i vk13x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 416 * sizeof(uint8_t)))), vk_zero_point);
3728 const __m512i vi13xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i13 + 16)));
3729 const __m512i vk13xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 432 * sizeof(uint8_t)))), vk_zero_point);
3730 i13 += 32;
3731
3732 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF));
3733 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi13xGHIJKLMNOPQRSTUV, vk13xGHIJKLMNOPQRSTUV));
3734
3735 const __m512i vi14x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i14));
3736 const __m512i vk14x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 448 * sizeof(uint8_t)))), vk_zero_point);
3737 const __m512i vi14xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i14 + 16)));
3738 const __m512i vk14xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 464 * sizeof(uint8_t)))), vk_zero_point);
3739 i14 += 32;
3740
3741 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF));
3742 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi14xGHIJKLMNOPQRSTUV, vk14xGHIJKLMNOPQRSTUV));
3743
3744 const __m512i vi15x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i15));
3745 const __m512i vk15x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 480 * sizeof(uint8_t)))), vk_zero_point);
3746 const __m512i vi15xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i15 + 16)));
3747 const __m512i vk15xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 496 * sizeof(uint8_t)))), vk_zero_point);
3748 i15 += 32;
3749
3750 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF));
3751 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi15xGHIJKLMNOPQRSTUV, vk15xGHIJKLMNOPQRSTUV));
3752
3753 const __m512i vi16x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i16));
3754 const __m512i vk16x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 512 * sizeof(uint8_t)))), vk_zero_point);
3755 const __m512i vi16xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i16 + 16)));
3756 const __m512i vk16xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 528 * sizeof(uint8_t)))), vk_zero_point);
3757 i16 += 32;
3758
3759 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF));
3760 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi16xGHIJKLMNOPQRSTUV, vk16xGHIJKLMNOPQRSTUV));
3761
3762 const __m512i vi17x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i17));
3763 const __m512i vk17x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 544 * sizeof(uint8_t)))), vk_zero_point);
3764 const __m512i vi17xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i17 + 16)));
3765 const __m512i vk17xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 560 * sizeof(uint8_t)))), vk_zero_point);
3766 i17 += 32;
3767
3768 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF));
3769 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi17xGHIJKLMNOPQRSTUV, vk17xGHIJKLMNOPQRSTUV));
3770
3771 const __m512i vi18x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i18));
3772 const __m512i vk18x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 576 * sizeof(uint8_t)))), vk_zero_point);
3773 const __m512i vi18xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i18 + 16)));
3774 const __m512i vk18xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 592 * sizeof(uint8_t)))), vk_zero_point);
3775 i18 += 32;
3776
3777 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF));
3778 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi18xGHIJKLMNOPQRSTUV, vk18xGHIJKLMNOPQRSTUV));
3779
3780 const __m512i vi19x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i19));
3781 const __m512i vk19x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 608 * sizeof(uint8_t)))), vk_zero_point);
3782 const __m512i vi19xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i19 + 16)));
3783 const __m512i vk19xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 624 * sizeof(uint8_t)))), vk_zero_point);
3784 i19 += 32;
3785
3786 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF));
3787 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi19xGHIJKLMNOPQRSTUV, vk19xGHIJKLMNOPQRSTUV));
3788
3789 const __m512i vi20x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i20));
3790 const __m512i vk20x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 640 * sizeof(uint8_t)))), vk_zero_point);
3791 const __m512i vi20xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i20 + 16)));
3792 const __m512i vk20xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 656 * sizeof(uint8_t)))), vk_zero_point);
3793 i20 += 32;
3794
3795 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF));
3796 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi20xGHIJKLMNOPQRSTUV, vk20xGHIJKLMNOPQRSTUV));
3797
3798 const __m512i vi21x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i21));
3799 const __m512i vk21x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 672 * sizeof(uint8_t)))), vk_zero_point);
3800 const __m512i vi21xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i21 + 16)));
3801 const __m512i vk21xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 688 * sizeof(uint8_t)))), vk_zero_point);
3802 i21 += 32;
3803
3804 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF));
3805 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi21xGHIJKLMNOPQRSTUV, vk21xGHIJKLMNOPQRSTUV));
3806
3807 const __m512i vi22x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i22));
3808 const __m512i vk22x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 704 * sizeof(uint8_t)))), vk_zero_point);
3809 const __m512i vi22xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i22 + 16)));
3810 const __m512i vk22xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 720 * sizeof(uint8_t)))), vk_zero_point);
3811 i22 += 32;
3812
3813 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF));
3814 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi22xGHIJKLMNOPQRSTUV, vk22xGHIJKLMNOPQRSTUV));
3815
3816 const __m512i vi23x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i23));
3817 const __m512i vk23x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 736 * sizeof(uint8_t)))), vk_zero_point);
3818 const __m512i vi23xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i23 + 16)));
3819 const __m512i vk23xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 752 * sizeof(uint8_t)))), vk_zero_point);
3820 i23 += 32;
3821
3822 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF));
3823 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi23xGHIJKLMNOPQRSTUV, vk23xGHIJKLMNOPQRSTUV));
3824
3825 const __m512i vi24x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i24));
3826 const __m512i vk24x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 768 * sizeof(uint8_t)))), vk_zero_point);
3827 const __m512i vi24xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i24 + 16)));
3828 const __m512i vk24xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 784 * sizeof(uint8_t)))), vk_zero_point);
3829 i24 += 32;
3830
3831 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF));
3832 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi24xGHIJKLMNOPQRSTUV, vk24xGHIJKLMNOPQRSTUV));
3833
3834 w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 800 * sizeof(uint8_t));
3835
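// Requantize: widen the int32 accumulators to fp32, apply the per-tensor scale,
// clamp from above at (output_max - output_zero_point), and round back to int32
// with _mm512_cvtps_epi32 (round-to-nearest-even under the default MXCSR mode);
// the zero point itself is added during the int16 pack below.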
3836 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
3837 __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);
3838
3839 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
3840 vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscale);
3841
3842 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
3843 vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);
3844
3845 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
3846 vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);
3847
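// Pack int32 -> int16 -> uint8. The 512-bit and 256-bit packs interleave their
// operands per 128-bit lane, which scrambles the element order (as the variable
// names record); the permute/shuffle steps below restore ascending channel
// order before the stores.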
3848 __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);
3849 __m256i voutGHIJOPQRKLMNSTUV = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vaccGHIJKLMNOPQRSTUV), _mm512_extracti32x8_epi32(vaccGHIJKLMNOPQRSTUV, 1)), _mm512_castsi512_si256(voutput_zero_point));
3850
3851 const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
3852 const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
3853 const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packus_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
3854 __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);
3855 const __m128i voutGHIJOPQR = _mm256_castsi256_si128(voutGHIJOPQRKLMNSTUV);
3856 const __m128i voutKLMNSTUV = _mm256_extracti128_si256(voutGHIJOPQRKLMNSTUV, 1);
3857 __m128i voutGHIJKLMNOPQRSTUV = _mm_shuffle_epi32(_mm_packus_epi16(voutGHIJOPQR, voutKLMNSTUV), _MM_SHUFFLE(3, 1, 2, 0));
3858
3859 vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epu8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);
3860 voutGHIJKLMNOPQRSTUV = _mm_max_epu8(voutGHIJKLMNOPQRSTUV, _mm256_castsi256_si128(voutput_min));
3861
3862 _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
3863 _mm_storeu_si128((__m128i*) (output + 16), voutGHIJKLMNOPQRSTUV);
3864 output += 32;
3865 }
3866 if XNN_UNLIKELY(c != 0) {
3867 // Prepare mask for valid 8-bit elements (depends on c).
3868 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
3869 const uint8_t* k = (const uint8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
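// The remainder loop handles the last channels 16 at a time: w advances through
// the int32 bias values while k advances through the uint8 kernel weights,
// whose per-tap offsets stride by 32 bytes because the packed tile is
// 32 channels wide.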
3870 do {
3871 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
3872
3873
3874 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i0));
3875 const __m512i vk0x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) k)), vk_zero_point);
3876 i0 += 16;
3877
3878 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
3879
3880 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i1));
3881 const __m512i vk1x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 32))), vk_zero_point);
3882 i1 += 16;
3883
3884 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
3885
3886 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i2));
3887 const __m512i vk2x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 64))), vk_zero_point);
3888 i2 += 16;
3889
3890 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
3891
3892 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i3));
3893 const __m512i vk3x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 96))), vk_zero_point);
3894 i3 += 16;
3895
3896 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
3897
3898 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i4));
3899 const __m512i vk4x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 128))), vk_zero_point);
3900 i4 += 16;
3901
3902 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
3903
3904 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i5));
3905 const __m512i vk5x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 160))), vk_zero_point);
3906 i5 += 16;
3907
3908 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
3909
3910 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i6));
3911 const __m512i vk6x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 192))), vk_zero_point);
3912 i6 += 16;
3913
3914 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
3915
3916 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i7));
3917 const __m512i vk7x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 224))), vk_zero_point);
3918 i7 += 16;
3919
3920 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
3921
3922 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i8));
3923 const __m512i vk8x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 256))), vk_zero_point);
3924 i8 += 16;
3925
3926 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
3927
3928 const __m512i vi9x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i9));
3929 const __m512i vk9x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 288))), vk_zero_point);
3930 i9 += 16;
3931
3932 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi9x0123456789ABCDEF, vk9x0123456789ABCDEF));
3933
3934 const __m512i vi10x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i10));
3935 const __m512i vk10x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 320))), vk_zero_point);
3936 i10 += 16;
3937
3938 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi10x0123456789ABCDEF, vk10x0123456789ABCDEF));
3939
3940 const __m512i vi11x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i11));
3941 const __m512i vk11x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 352))), vk_zero_point);
3942 i11 += 16;
3943
3944 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi11x0123456789ABCDEF, vk11x0123456789ABCDEF));
3945
3946 const __m512i vi12x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i12));
3947 const __m512i vk12x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 384))), vk_zero_point);
3948 i12 += 16;
3949
3950 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi12x0123456789ABCDEF, vk12x0123456789ABCDEF));
3951
3952 const __m512i vi13x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i13));
3953 const __m512i vk13x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 416))), vk_zero_point);
3954 i13 += 16;
3955
3956 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi13x0123456789ABCDEF, vk13x0123456789ABCDEF));
3957
3958 const __m512i vi14x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i14));
3959 const __m512i vk14x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 448))), vk_zero_point);
3960 i14 += 16;
3961
3962 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi14x0123456789ABCDEF, vk14x0123456789ABCDEF));
3963
3964 const __m512i vi15x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i15));
3965 const __m512i vk15x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 480))), vk_zero_point);
3966 i15 += 16;
3967
3968 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi15x0123456789ABCDEF, vk15x0123456789ABCDEF));
3969
3970 const __m512i vi16x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i16));
3971 const __m512i vk16x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 512))), vk_zero_point);
3972 i16 += 16;
3973
3974 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi16x0123456789ABCDEF, vk16x0123456789ABCDEF));
3975
3976 const __m512i vi17x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i17));
3977 const __m512i vk17x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 544))), vk_zero_point);
3978 i17 += 16;
3979
3980 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi17x0123456789ABCDEF, vk17x0123456789ABCDEF));
3981
3982 const __m512i vi18x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i18));
3983 const __m512i vk18x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 576))), vk_zero_point);
3984 i18 += 16;
3985
3986 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi18x0123456789ABCDEF, vk18x0123456789ABCDEF));
3987
3988 const __m512i vi19x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i19));
3989 const __m512i vk19x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 608))), vk_zero_point);
3990 i19 += 16;
3991
3992 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi19x0123456789ABCDEF, vk19x0123456789ABCDEF));
3993
3994 const __m512i vi20x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i20));
3995 const __m512i vk20x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 640))), vk_zero_point);
3996 i20 += 16;
3997
3998 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi20x0123456789ABCDEF, vk20x0123456789ABCDEF));
3999
4000 const __m512i vi21x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i21));
4001 const __m512i vk21x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 672))), vk_zero_point);
4002 i21 += 16;
4003
4004 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi21x0123456789ABCDEF, vk21x0123456789ABCDEF));
4005
4006 const __m512i vi22x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i22));
4007 const __m512i vk22x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 704))), vk_zero_point);
4008 i22 += 16;
4009
4010 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi22x0123456789ABCDEF, vk22x0123456789ABCDEF));
4011
4012 const __m512i vi23x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i23));
4013 const __m512i vk23x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 736))), vk_zero_point);
4014 i23 += 16;
4015
4016 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi23x0123456789ABCDEF, vk23x0123456789ABCDEF));
4017
4018 const __m512i vi24x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i24));
4019 const __m512i vk24x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 768))), vk_zero_point);
4020 i24 += 16;
4021
4022 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi24x0123456789ABCDEF, vk24x0123456789ABCDEF));
4023
4024 k += 16;
4025
4026 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
4027 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
4028 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
4029 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
4030
4031 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
4032
4033 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));
4034
4035 const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
4036 const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
4037 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
4038 vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));
4039
4040 if XNN_LIKELY(c >= 16) {
4041 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
4042 output += 16;
4043 c -= 16;
4044 } else {
4045 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
4046 output = (uint8_t*) ((uintptr_t) output + c);
4047 c = 0;
4048 }
4049 } while (c != 0);
4050 }
4051
4052 output = (uint8_t*) ((uintptr_t) output + output_increment);
4053 } while (--output_width != 0);
4054 }
4055
4056 void xnn_qu8_dwconv_minmax_fp32_ukernel_up32x9__avx512skx_mul32(
4057 size_t channels,
4058 size_t output_width,
4059 const uint8_t** input,
4060 const void* weights,
4061 uint8_t* output,
4062 size_t input_stride,
4063 size_t output_increment,
4064 size_t input_offset,
4065 const uint8_t* zero,
4066 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_MSAN
4067 {
4068 assert(channels != 0);
4069 assert(output_width != 0);
4070
4071 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
4072 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
4073 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
4074 const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
4075 const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0);
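// vpermute_mask undoes the per-128-bit-lane interleaving of
// _mm256_packus_epi16 by reordering the eight 32-bit blocks into ascending
// channel order.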
4076
4077 const __m512i vk_zero_point = _mm512_cvtepu16_epi32(_mm256_load_si256((const __m256i*) params->fp32_avx512.kernel_zero_point));
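// The kernel zero point is stored as 16-bit values and widened to 32 bits here
// so it can be subtracted from the zero-extended kernel bytes ahead of the
// 32-bit multiply-accumulate.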
4078 do {
4079 const uint8_t* i0 = input[0];
4080 assert(i0 != NULL);
4081 if XNN_UNPREDICTABLE(i0 != zero) {
4082 i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
4083 }
4084 const uint8_t* i1 = input[1];
4085 assert(i1 != NULL);
4086 if XNN_UNPREDICTABLE(i1 != zero) {
4087 i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
4088 }
4089 const uint8_t* i2 = input[2];
4090 assert(i2 != NULL);
4091 if XNN_UNPREDICTABLE(i2 != zero) {
4092 i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
4093 }
4094 const uint8_t* i3 = input[3];
4095 assert(i3 != NULL);
4096 if XNN_UNPREDICTABLE(i3 != zero) {
4097 i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
4098 }
4099 const uint8_t* i4 = input[4];
4100 assert(i4 != NULL);
4101 if XNN_UNPREDICTABLE(i4 != zero) {
4102 i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
4103 }
4104 const uint8_t* i5 = input[5];
4105 assert(i5 != NULL);
4106 if XNN_UNPREDICTABLE(i5 != zero) {
4107 i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
4108 }
4109 const uint8_t* i6 = input[6];
4110 assert(i6 != NULL);
4111 if XNN_UNPREDICTABLE(i6 != zero) {
4112 i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
4113 }
4114 const uint8_t* i7 = input[7];
4115 assert(i7 != NULL);
4116 if XNN_UNPREDICTABLE(i7 != zero) {
4117 i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
4118 }
4119 const uint8_t* i8 = input[8];
4120 assert(i8 != NULL);
4121 if XNN_UNPREDICTABLE(i8 != zero) {
4122 i8 = (const uint8_t*) ((uintptr_t) i8 + input_offset);
4123 }
4124 input = (const uint8_t**) ((uintptr_t) input + input_stride);
4125
4126 size_t c = channels;
4127 const void* w = weights;
4128 for (; c >= 32; c -= 32) {
4129 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
4130 __m512i vaccGHIJKLMNOPQRSTUV = _mm512_loadu_si512((const void*) ((uintptr_t) w + 16 * sizeof(int32_t)));
4131
4132
4133 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i0));
4134 const __m512i vk0x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 0 * sizeof(uint8_t)))), vk_zero_point);
4135 const __m512i vi0xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i0 + 16)));
4136 const __m512i vk0xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 16 * sizeof(uint8_t)))), vk_zero_point);
4137 i0 += 32;
4138
4139 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
4140 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi0xGHIJKLMNOPQRSTUV, vk0xGHIJKLMNOPQRSTUV));
4141
4142 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i1));
4143 const __m512i vk1x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 32 * sizeof(uint8_t)))), vk_zero_point);
4144 const __m512i vi1xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i1 + 16)));
4145 const __m512i vk1xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 48 * sizeof(uint8_t)))), vk_zero_point);
4146 i1 += 32;
4147
4148 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
4149 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi1xGHIJKLMNOPQRSTUV, vk1xGHIJKLMNOPQRSTUV));
4150
4151 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i2));
4152 const __m512i vk2x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 64 * sizeof(uint8_t)))), vk_zero_point);
4153 const __m512i vi2xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i2 + 16)));
4154 const __m512i vk2xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 80 * sizeof(uint8_t)))), vk_zero_point);
4155 i2 += 32;
4156
4157 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
4158 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi2xGHIJKLMNOPQRSTUV, vk2xGHIJKLMNOPQRSTUV));
4159
4160 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i3));
4161 const __m512i vk3x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 96 * sizeof(uint8_t)))), vk_zero_point);
4162 const __m512i vi3xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i3 + 16)));
4163 const __m512i vk3xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 112 * sizeof(uint8_t)))), vk_zero_point);
4164 i3 += 32;
4165
4166 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
4167 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi3xGHIJKLMNOPQRSTUV, vk3xGHIJKLMNOPQRSTUV));
4168
4169 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i4));
4170 const __m512i vk4x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 128 * sizeof(uint8_t)))), vk_zero_point);
4171 const __m512i vi4xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i4 + 16)));
4172 const __m512i vk4xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 144 * sizeof(uint8_t)))), vk_zero_point);
4173 i4 += 32;
4174
4175 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
4176 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi4xGHIJKLMNOPQRSTUV, vk4xGHIJKLMNOPQRSTUV));
4177
4178 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i5));
4179 const __m512i vk5x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 160 * sizeof(uint8_t)))), vk_zero_point);
4180 const __m512i vi5xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i5 + 16)));
4181 const __m512i vk5xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 176 * sizeof(uint8_t)))), vk_zero_point);
4182 i5 += 32;
4183
4184 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
4185 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi5xGHIJKLMNOPQRSTUV, vk5xGHIJKLMNOPQRSTUV));
4186
4187 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i6));
4188 const __m512i vk6x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 192 * sizeof(uint8_t)))), vk_zero_point);
4189 const __m512i vi6xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i6 + 16)));
4190 const __m512i vk6xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 208 * sizeof(uint8_t)))), vk_zero_point);
4191 i6 += 32;
4192
4193 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
4194 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi6xGHIJKLMNOPQRSTUV, vk6xGHIJKLMNOPQRSTUV));
4195
4196 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i7));
4197 const __m512i vk7x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 224 * sizeof(uint8_t)))), vk_zero_point);
4198 const __m512i vi7xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i7 + 16)));
4199 const __m512i vk7xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 240 * sizeof(uint8_t)))), vk_zero_point);
4200 i7 += 32;
4201
4202 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
4203 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi7xGHIJKLMNOPQRSTUV, vk7xGHIJKLMNOPQRSTUV));
4204
4205 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i8));
4206 const __m512i vk8x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 256 * sizeof(uint8_t)))), vk_zero_point);
4207 const __m512i vi8xGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (i8 + 16)));
4208 const __m512i vk8xGHIJKLMNOPQRSTUV = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_load_si128((const __m128i*) ((uintptr_t) w + 32 * sizeof(int32_t) + 272 * sizeof(uint8_t)))), vk_zero_point);
4209 i8 += 32;
4210
4211 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
4212 vaccGHIJKLMNOPQRSTUV = _mm512_add_epi32(vaccGHIJKLMNOPQRSTUV, _mm512_mullo_epi32(vi8xGHIJKLMNOPQRSTUV, vk8xGHIJKLMNOPQRSTUV));
4213
4214 w = (const void*) ((uintptr_t) w + 32 * sizeof(int32_t) + 288 * sizeof(uint8_t));
4215
4216 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
4217 __m512 vscaledGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vaccGHIJKLMNOPQRSTUV);
4218
4219 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
4220 vscaledGHIJKLMNOPQRSTUV = _mm512_mul_ps(vscaledGHIJKLMNOPQRSTUV, vscale);
4221
4222 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
4223 vscaledGHIJKLMNOPQRSTUV = _mm512_min_ps(vscaledGHIJKLMNOPQRSTUV, voutput_max_less_zero_point);
4224
4225 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
4226 vaccGHIJKLMNOPQRSTUV = _mm512_cvtps_epi32(vscaledGHIJKLMNOPQRSTUV);
4227
4228 __m512i vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV = _mm512_adds_epi16(_mm512_packs_epi32(vacc0123456789ABCDEF, vaccGHIJKLMNOPQRSTUV), voutput_zero_point);
4229 __m256i voutGHIJOPQRKLMNSTUV = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vaccGHIJKLMNOPQRSTUV), _mm512_extracti32x8_epi32(vaccGHIJKLMNOPQRSTUV, 1)), _mm512_castsi512_si256(voutput_zero_point));
4230
4231 const __m256i vout0123GHIJ4567KLMN = _mm512_castsi512_si256(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV);
4232 const __m256i vout89ABOPQRCDEFSTUV = _mm512_extracti32x8_epi32(vout0123GHIJ4567KLMN89ABOPQRCDEFSTUV, 1);
4233 const __m256i vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV = _mm256_packus_epi16(vout0123GHIJ4567KLMN, vout89ABOPQRCDEFSTUV);
4234 __m256i vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_permutevar8x32_epi32(vout0123GHIJ89ABOPQR4567KLMNCDEFSTUV, vpermute_mask);
4235 const __m128i voutGHIJOPQR = _mm256_castsi256_si128(voutGHIJOPQRKLMNSTUV);
4236 const __m128i voutKLMNSTUV = _mm256_extracti128_si256(voutGHIJOPQRKLMNSTUV, 1);
4237 __m128i voutGHIJKLMNOPQRSTUV = _mm_shuffle_epi32(_mm_packus_epi16(voutGHIJOPQR, voutKLMNSTUV), _MM_SHUFFLE(3, 1, 2, 0));
4238
4239 vout0123456789ABCDEFGHIJKLMNOPQRSTUV = _mm256_max_epu8(vout0123456789ABCDEFGHIJKLMNOPQRSTUV, voutput_min);
4240 voutGHIJKLMNOPQRSTUV = _mm_max_epu8(voutGHIJKLMNOPQRSTUV, _mm256_castsi256_si128(voutput_min));
4241
4242 _mm256_storeu_si256((__m256i*) output, vout0123456789ABCDEFGHIJKLMNOPQRSTUV);
4243 _mm_storeu_si128((__m128i*) (output + 16), voutGHIJKLMNOPQRSTUV);
4244 output += 32;
4245 }
4246 if XNN_UNLIKELY(c != 0) {
4247 // Prepare mask for valid 8-bit elements (depends on c).
4248 const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << (c & 15)) - UINT32_C(1)));
4249 const uint8_t* k = (const uint8_t*) ((uintptr_t) w + 32 * sizeof(int32_t));
4250 do {
4251 __m512i vacc0123456789ABCDEF = _mm512_loadu_si512(w);
4252
4253
4254 const __m512i vi0x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i0));
4255 const __m512i vk0x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) k)), vk_zero_point);
4256 i0 += 16;
4257
4258 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi0x0123456789ABCDEF, vk0x0123456789ABCDEF));
4259
4260 const __m512i vi1x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i1));
4261 const __m512i vk1x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 32))), vk_zero_point);
4262 i1 += 16;
4263
4264 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi1x0123456789ABCDEF, vk1x0123456789ABCDEF));
4265
4266 const __m512i vi2x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i2));
4267 const __m512i vk2x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 64))), vk_zero_point);
4268 i2 += 16;
4269
4270 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi2x0123456789ABCDEF, vk2x0123456789ABCDEF));
4271
4272 const __m512i vi3x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i3));
4273 const __m512i vk3x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 96))), vk_zero_point);
4274 i3 += 16;
4275
4276 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi3x0123456789ABCDEF, vk3x0123456789ABCDEF));
4277
4278 const __m512i vi4x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i4));
4279 const __m512i vk4x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 128))), vk_zero_point);
4280 i4 += 16;
4281
4282 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi4x0123456789ABCDEF, vk4x0123456789ABCDEF));
4283
4284 const __m512i vi5x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i5));
4285 const __m512i vk5x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 160))), vk_zero_point);
4286 i5 += 16;
4287
4288 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi5x0123456789ABCDEF, vk5x0123456789ABCDEF));
4289
4290 const __m512i vi6x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i6));
4291 const __m512i vk6x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 192))), vk_zero_point);
4292 i6 += 16;
4293
4294 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi6x0123456789ABCDEF, vk6x0123456789ABCDEF));
4295
4296 const __m512i vi7x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i7));
4297 const __m512i vk7x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 224))), vk_zero_point);
4298 i7 += 16;
4299
4300 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi7x0123456789ABCDEF, vk7x0123456789ABCDEF));
4301
4302 const __m512i vi8x0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) i8));
4303 const __m512i vk8x0123456789ABCDEF = _mm512_sub_epi32(_mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (k + 256))), vk_zero_point);
4304 i8 += 16;
4305
4306 vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vi8x0123456789ABCDEF, vk8x0123456789ABCDEF));
4307
4308 k += 16;
4309
4310 __m512 vscaled0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0123456789ABCDEF);
4311 vscaled0123456789ABCDEF = _mm512_mul_ps(vscaled0123456789ABCDEF, vscale);
4312 vscaled0123456789ABCDEF = _mm512_min_ps(vscaled0123456789ABCDEF, voutput_max_less_zero_point);
4313 vacc0123456789ABCDEF = _mm512_cvtps_epi32(vscaled0123456789ABCDEF);
4314
4315 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
4316
4317 __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), _mm512_castsi512_si256(voutput_zero_point));
4318
4319 const __m128i vout012389AB = _mm256_castsi256_si128(vout012389AB4567CDEF);
4320 const __m128i vout4567CDEF = _mm256_extracti128_si256(vout012389AB4567CDEF, 1);
4321 __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(vout012389AB, vout4567CDEF), _MM_SHUFFLE(3, 1, 2, 0));
4322 vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, _mm256_castsi256_si128(voutput_min));
4323
4324 if XNN_LIKELY(c >= 16) {
4325 _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
4326 output += 16;
4327 c -= 16;
4328 } else {
4329 _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
4330 output = (uint8_t*) ((uintptr_t) output + c);
4331 c = 0;
4332 }
4333 } while (c != 0);
4334 }
4335
4336 output = (uint8_t*) ((uintptr_t) output + output_increment);
4337 } while (--output_width != 0);
4338 }
4339
4340 void xnn_qu8_f32_vcvt_ukernel__avx512skx_x32(
4341 size_t n,
4342 const uint8_t* x,
4343 float* y,
4344 const union xnn_qu8_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4345 {
4346 assert(n != 0);
4347 assert(n % sizeof(uint8_t) == 0);
4348 assert(x != NULL);
4349 assert(y != NULL);
4350
4351 const __m512i vminus_zero_point = _mm512_load_si512(params->avx512.minus_zero_point);
4352 const __m512 vscale = _mm512_load_ps(params->avx512.scale);
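// Dequantize y = ((int32_t) x - zero_point) * scale. The params carry the
// pre-negated zero point (minus_zero_point), so each block needs only an add
// and a multiply.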
4353 for (; n >= 32 * sizeof(uint8_t); n -= 32 * sizeof(uint8_t)) {
4354 __m512i vx0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) x));
4355 __m512i vxGHIJKLMNOPQRSTUV = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) (x + 16)));
4356 x += 32;
4357
4358 vx0123456789ABCDEF = _mm512_add_epi32(vx0123456789ABCDEF, vminus_zero_point);
4359 vxGHIJKLMNOPQRSTUV = _mm512_add_epi32(vxGHIJKLMNOPQRSTUV, vminus_zero_point);
4360
4361 __m512 vy0123456789ABCDEF = _mm512_cvtepi32_ps(vx0123456789ABCDEF);
4362 __m512 vyGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vxGHIJKLMNOPQRSTUV);
4363
4364 vy0123456789ABCDEF = _mm512_mul_ps(vy0123456789ABCDEF, vscale);
4365 vyGHIJKLMNOPQRSTUV = _mm512_mul_ps(vyGHIJKLMNOPQRSTUV, vscale);
4366
4367 _mm512_storeu_ps(y, vy0123456789ABCDEF);
4368 _mm512_storeu_ps(y + 16, vyGHIJKLMNOPQRSTUV);
4369 y += 32;
4370 }
4371 for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
4372 __m512i vx = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) x));
4373 vx = _mm512_add_epi32(vx, vminus_zero_point);
4374 x += 16;
4375
4376 __m512 vy = _mm512_cvtepi32_ps(vx);
4377 vy = _mm512_mul_ps(vy, vscale);
4378
4379 _mm512_storeu_ps(y, vy);
4380 y += 16;
4381 }
4382 if XNN_UNLIKELY(n != 0) {
4383 assert(n >= 1 * sizeof(uint8_t));
4384 assert(n <= 15 * sizeof(uint8_t));
4385
4386 // Prepare mask for valid elements (depends on n).
4387 const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << n) - UINT32_C(1)));
4388
4389 __m512i vx = _mm512_cvtepu8_epi32(_mm_maskz_loadu_epi8(vmask, x));
4390 vx = _mm512_add_epi32(vx, vminus_zero_point);
4391
4392 __m512 vy = _mm512_cvtepi32_ps(vx);
4393 vy = _mm512_mul_ps(vy, vscale);
4394
4395 _mm512_mask_storeu_ps(y, vmask, vy);
4396 }
4397 }
4398
4399 void xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx(
4400 size_t mr,
4401 size_t nc,
4402 size_t kc,
4403 const uint8_t* restrict a,
4404 size_t a_stride,
4405 const void* restrict w,
4406 uint8_t* restrict c,
4407 size_t cm_stride,
4408 size_t cn_stride,
4409 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4410 {
4411 assert(mr != 0);
4412 assert(mr <= 1);
4413 assert(nc != 0);
4414 assert(kc != 0);
4415 assert(kc % sizeof(uint8_t) == 0);
4416 assert(a != NULL);
4417 assert(w != NULL);
4418 assert(c != NULL);
4419
4420 kc = round_up_po2(kc, 8);
4421 const uint8_t* a0 = a;
4422 uint8_t* c0 = c;
4423
4424 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
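// Mask 0x1111 sets bits 0, 4, 8 and 12, so each _mm512_maskz_expandloadu_epi32
// below reads 4 consecutive int32 bias values and deposits one into element 0
// of each 128-bit quarter of the accumulator.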
4425 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
4426 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
4427 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point);
4428 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min);
4429 do {
4430 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
4431 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
4432 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
4433 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
4434 w = (const void*) ((const int32_t*) w + 16);
4435
4436 size_t k = 0;
4437 const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
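// Each step consumes 8 input bytes: they are zero-extended to 16 bits and
// broadcast to all four 128-bit lanes, 32 kernel bytes are widened and shifted
// by the kernel zero point, and _mm512_madd_epi16 accumulates pairwise 16-bit
// products, leaving four int32 partial sums per output column (one column per
// 128-bit lane of each vacc register).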
4438 while (k < kc) {
4439 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
4440 a0 += 8;
4441
4442 const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
4443
4444 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
4445 const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
4446
4447 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
4448 const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
4449
4450 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
4451 const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
4452
4453 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
4454
4455 w = (const void*) ((const uint8_t*) w + 128);
4456 k += 8 * sizeof(uint8_t);
4457 }
4458
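// Reduction: two rounds of unpacklo/unpackhi + add collapse the four int32
// partial sums per column into one, leaving the 16 column totals in the
// scrambled 084C195D2A6E3B7F order recorded in the variable name.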
4459 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
4460 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
4461
4462 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
4463
4464 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
4465
4466 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
4467
4468 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
4469
4470 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
4471
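// Narrow to uint8: saturating-pack to int16 while adding the output zero
// point, pack to uint8, then a byte shuffle maps the scrambled column order
// back to 0..F before the lower clamp at output_min.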
4472 const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);
4473
4474 const __m128i vout0x084C2A6E195D3B7F = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
4475 __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
4476 vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min);
4477
4478 if (nc >= 16) {
4479 _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
4480
4481 a0 = (const uint8_t*) ((uintptr_t) a0 - k);
4482
4483 c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
4484
4485 nc -= 16;
4486 } else {
4487 // Prepare mask for valid 8-bit elements (depends on nc).
4488 const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
4489
4490 _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);
4491
4492 nc = 0;
4493 }
4494 } while (nc != 0);
4495 }
4496
4497 void xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx(
4498 size_t mr,
4499 size_t nc,
4500 size_t kc,
4501 const uint8_t* restrict a,
4502 size_t a_stride,
4503 const void* restrict w,
4504 uint8_t* restrict c,
4505 size_t cm_stride,
4506 size_t cn_stride,
4507 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4508 {
4509 assert(mr != 0);
4510 assert(mr <= 4);
4511 assert(nc != 0);
4512 assert(kc != 0);
4513 assert(kc % sizeof(uint8_t) == 0);
4514 assert(a != NULL);
4515 assert(w != NULL);
4516 assert(c != NULL);
4517
4518 kc = round_up_po2(kc, 8);
4519 const uint8_t* a0 = a;
4520 uint8_t* c0 = c;
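// When mr < 4, the surplus row pointers alias the row above, so those rows are
// computed redundantly (and stored to the same addresses) instead of branching
// inside the hot loop.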
4521 const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
4522 uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
4523 if XNN_UNPREDICTABLE(mr < 2) {
4524 a1 = a0;
4525 c1 = c0;
4526 }
4527 const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride);
4528 uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
4529 if XNN_UNPREDICTABLE(mr <= 2) {
4530 a2 = a1;
4531 c2 = c1;
4532 }
4533 const uint8_t* a3 = (const uint8_t*) ((uintptr_t) a2 + a_stride);
4534 uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
4535 if XNN_UNPREDICTABLE(mr != 4) {
4536 a3 = a2;
4537 c3 = c2;
4538 }
4539
4540 const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
4541 const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
4542 const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
4543 const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
4544 const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
4545 do {
4546 __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
4547 __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
4548 __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
4549 __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
4550 __m512i vacc1x0123 = vacc0x0123;
4551 __m512i vacc1x4567 = vacc0x4567;
4552 __m512i vacc1x89AB = vacc0x89AB;
4553 __m512i vacc1xCDEF = vacc0xCDEF;
4554 __m512i vacc2x0123 = vacc0x0123;
4555 __m512i vacc2x4567 = vacc0x4567;
4556 __m512i vacc2x89AB = vacc0x89AB;
4557 __m512i vacc2xCDEF = vacc0xCDEF;
4558 __m512i vacc3x0123 = vacc0x0123;
4559 __m512i vacc3x4567 = vacc0x4567;
4560 __m512i vacc3x89AB = vacc0x89AB;
4561 __m512i vacc3xCDEF = vacc0xCDEF;
4562 w = (const void*) ((const int32_t*) w + 16);
4563
4564 size_t k = 0;
4565 const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
4566 while (k < kc) {
4567 const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
4568 a0 += 8;
4569 const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
4570 a1 += 8;
4571 const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
4572 a2 += 8;
4573 const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
4574 a3 += 8;
4575
4576 const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
4577
4578 vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
4579 vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
4580 vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
4581 vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
4582 const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
4583
4584 vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
4585 vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
4586 vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
4587 vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
4588 const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
4589
4590 vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
4591 vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
4592 vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
4593 vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
4594 const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
4595
4596 vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
4597 vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
4598 vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
4599 vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
4600
4601 w = (const void*) ((const uint8_t*) w + 128);
4602 k += 8 * sizeof(uint8_t);
4603 }
4604
4605 const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
4606 const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
4607 const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
4608 const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
4609 const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
4610 const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
4611 const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
4612 const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));
4613
4614 __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
4615 __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
4616 __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
4617 __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));
4618
4619 __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
4620 __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
4621 __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
4622 __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
4623
4624 vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
4625 vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
4626 vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale);
4627 vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale);
4628
4629 vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
4630 vscaled1x084C195D2A6E3B7F = _mm512_min_ps(vscaled1x084C195D2A6E3B7F, voutput_max_less_zero_point);
4631 vscaled2x084C195D2A6E3B7F = _mm512_min_ps(vscaled2x084C195D2A6E3B7F, voutput_max_less_zero_point);
4632 vscaled3x084C195D2A6E3B7F = _mm512_min_ps(vscaled3x084C195D2A6E3B7F, voutput_max_less_zero_point);
4633
4634 vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
4635 vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
4636 vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
4637 vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);
4638
4639 const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
4640 const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);
4641
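// The 512-bit packs interleave the four rows across 128-bit lanes; the
// permutexvar regroups 32-bit blocks so that each 128-bit lane holds one
// output row, and the byte shuffle restores column order 0..F within every
// lane, ready for the per-row 128-bit stores below.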
4642 __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packus_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
4643 vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
4644 __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
4645 vout0123x0123456789ABCDEF = _mm512_max_epu8(vout0123x0123456789ABCDEF, voutput_min);
4646
4647 if (nc >= 16) {
4648 _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));
4649 _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
4650 _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
4651 _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
4652
4653 a0 = (const uint8_t*) ((uintptr_t) a0 - k);
4654 a1 = (const uint8_t*) ((uintptr_t) a1 - k);
4655 a2 = (const uint8_t*) ((uintptr_t) a2 - k);
4656 a3 = (const uint8_t*) ((uintptr_t) a3 - k);
4657
4658 c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
4659 c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
4660 c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
4661 c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride);
4662
4663 nc -= 16;
4664 } else {
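// Store the nc remaining columns of every row with masked byte stores:
// shifting the 16-bit column mask left by 16 selects the next 128-bit lane of
// the zmm, while the row pointer is rebased by -16/-32/-48 so that the
// selected lane lands at the correct address.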
4665 // Prepare mask for valid 8-bit elements (depends on nc).
4666 __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
4667
4668 _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
4669 vmask = _kshiftli_mask64(vmask, 16);
4670 _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
4671 vmask = _kshiftli_mask64(vmask, 16);
4672 _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
4673 vmask = _kshiftli_mask64(vmask, 16);
4674 _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
4675
4676 nc = 0;
4677 }
4678 } while (nc != 0);
4679 }
4680
4681 void xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx(
4682 size_t mr,
4683 size_t nc,
4684 size_t kc,
4685 size_t ks,
4686 const uint8_t** restrict a,
4687 const void* restrict w,
4688 uint8_t* restrict c,
4689 size_t cm_stride,
4690 size_t cn_stride,
4691 size_t a_offset,
4692 const uint8_t* zero,
4693 const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
4694 {
4695 assert(mr != 0);
4696 assert(mr <= 1);
4697 assert(nc != 0);
4698 assert(kc != 0);
4699 assert(kc % sizeof(uint8_t) == 0);
4700 assert(a != NULL);
4701 assert(w != NULL);
4702 assert(c != NULL);
4703
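  // Each inner-loop iteration consumes 8 input bytes per row, so round kc up to a multiple of 8;
  // the packed weights are assumed to be padded to full 8-byte groups to match.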
  kc = round_up_po2(kc, 8);
  uint8_t* c0 = c;

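  // Mask 0x1111 selects every 4th 32-bit lane: each expand-load scatters 4 consecutive bias
  // values into the low dword of the four 128-bit lanes, one lane per output channel.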
  const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
  const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
  const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
  const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point);
  const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min);
  do {
    __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
    __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
    __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
    __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
    w = (const void*) ((const int32_t*) w + 16);

    size_t p = ks;
    do {
      const uint8_t* restrict a0 = a[0];
      if XNN_UNPREDICTABLE(a0 != zero) {
        a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
      }
      a += 1;

      size_t k = 0;
      const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
      while (k < kc) {
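        // Broadcast 8 zero-extended activation bytes to all four 128-bit lanes; each lane pairs
        // them with the 8 weights of a different output channel, so one vpmaddwd accumulates 4
        // partial dword sums per channel.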
        const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
        a0 += 8;

        const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);

        vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
        const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);

        vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
        const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);

        vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
        const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);

        vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));

        w = (const void*) ((const uint8_t*) w + 128);
        k += 8 * sizeof(uint8_t);
      }
      p -= 1 * sizeof(void*);
    } while (p != 0);

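    // Horizontal reduction: two rounds of unpack+add collapse the four partial dwords per channel
    // and merge the four channel groups, leaving the 16 channel sums in the interleaved order the
    // suffix names (0,8,4,C,1,9,5,D,2,A,6,E,3,B,7,F); the final byte shuffle undoes it.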
    const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
    const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));

    __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));

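    // FP32 requantization: convert to float, multiply by the scale, clamp from above against
    // (output_max - zero point), and round back to int32 (round-to-nearest-even).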
    __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);

    vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);

    vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);

    vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);

    const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);

    const __m128i vout0x084C2A6E195D3B7F = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
    __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
    vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min);

    if (nc >= 16) {
      _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);

      c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);

      a = (const uint8_t**restrict) ((uintptr_t) a - ks);

      nc -= 16;
    } else {
      // Prepare mask for valid 8-bit elements (depends on nc).
      const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));

      _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);

      nc = 0;
    }
  } while (nc != 0);
}

void xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx(
    size_t mr,
    size_t nc,
    size_t kc,
    size_t ks,
    const uint8_t** restrict a,
    const void* restrict w,
    uint8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const uint8_t* zero,
    const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 4);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(uint8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  kc = round_up_po2(kc, 8);
  uint8_t* c0 = c;
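  // When mr < 4, alias the out-of-range row pointers to the previous row; those rows then compute
  // and store harmless duplicates over valid memory instead of writing out of bounds.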
  uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 2) {
    c1 = c0;
  }
  uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    c2 = c1;
  }
  uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
  if XNN_UNPREDICTABLE(mr != 4) {
    c3 = c2;
  }

  const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
  const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
  const __m512 voutput_max_less_zero_point = _mm512_load_ps(params->fp32_avx512.output_max_less_zero_point);
  const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
  const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
  do {
    __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
    __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
    __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
    __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
    __m512i vacc1x0123 = vacc0x0123;
    __m512i vacc1x4567 = vacc0x4567;
    __m512i vacc1x89AB = vacc0x89AB;
    __m512i vacc1xCDEF = vacc0xCDEF;
    __m512i vacc2x0123 = vacc0x0123;
    __m512i vacc2x4567 = vacc0x4567;
    __m512i vacc2x89AB = vacc0x89AB;
    __m512i vacc2xCDEF = vacc0xCDEF;
    __m512i vacc3x0123 = vacc0x0123;
    __m512i vacc3x4567 = vacc0x4567;
    __m512i vacc3x89AB = vacc0x89AB;
    __m512i vacc3xCDEF = vacc0xCDEF;
    w = (const void*) ((const int32_t*) w + 16);

    size_t p = ks;
    do {
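      // Fetch this indirection step's 4 row pointers; rows that point at the zero buffer skip the
      // a_offset adjustment so padding reads stay inside the zero buffer.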
      const uint8_t* restrict a0 = a[0];
      if XNN_UNPREDICTABLE(a0 != zero) {
        a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
      }
      const uint8_t* restrict a1 = a[1];
      if XNN_UNPREDICTABLE(a1 != zero) {
        a1 = (const uint8_t*) ((uintptr_t) a1 + a_offset);
      }
      const uint8_t* restrict a2 = a[2];
      if XNN_UNPREDICTABLE(a2 != zero) {
        a2 = (const uint8_t*) ((uintptr_t) a2 + a_offset);
      }
      const uint8_t* restrict a3 = a[3];
      if XNN_UNPREDICTABLE(a3 != zero) {
        a3 = (const uint8_t*) ((uintptr_t) a3 + a_offset);
      }
      a += 4;

      size_t k = 0;
      const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
      while (k < kc) {
        const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
        a0 += 8;
        const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
        a1 += 8;
        const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
        a2 += 8;
        const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
        a3 += 8;

        const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);

        vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
        vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
        vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
        vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
        const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);

        vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
        vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
        vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
        vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
        const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);

        vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
        vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
        vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
        vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
        const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);

        vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
        vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
        vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
        vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));

        w = (const void*) ((const uint8_t*) w + 128);
        k += 8 * sizeof(uint8_t);
      }
      p -= 4 * sizeof(void*);
    } while (p != 0);

    const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
    const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
    const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
    const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
    const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
    const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
    const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
    const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));

    __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
    __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
    __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
    __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));

    __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
    __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
    __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
    __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);

    vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
    vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
    vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale);
    vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale);

    vscaled0x084C195D2A6E3B7F = _mm512_min_ps(vscaled0x084C195D2A6E3B7F, voutput_max_less_zero_point);
    vscaled1x084C195D2A6E3B7F = _mm512_min_ps(vscaled1x084C195D2A6E3B7F, voutput_max_less_zero_point);
    vscaled2x084C195D2A6E3B7F = _mm512_min_ps(vscaled2x084C195D2A6E3B7F, voutput_max_less_zero_point);
    vscaled3x084C195D2A6E3B7F = _mm512_min_ps(vscaled3x084C195D2A6E3B7F, voutput_max_less_zero_point);

    vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
    vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
    vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
    vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);

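    // Pack rows pairwise to 16 bits (adding the output zero point with saturation) and then to
    // 8 bits; packing interleaves the rows across 128-bit lanes, so the dword permute plus byte
    // shuffle below restore one row per lane with channels in ascending order.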
    const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
    const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);

    __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packus_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
    vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
    __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
    vout0123x0123456789ABCDEF = _mm512_max_epu8(vout0123x0123456789ABCDEF, voutput_min);

    if (nc >= 16) {
      _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
      _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
      _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
      _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));

      c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride);
      c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
      c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
      c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);

      a = (const uint8_t**restrict) ((uintptr_t) a - ks);

      nc -= 16;
    } else {
      // Prepare mask for valid 8-bit elements (depends on nc).
      __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT64_C(1) << (nc + 48)) - (UINT64_C(1) << 48)));

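      // The initial mask covers bytes 48..48+nc-1 (row 3 in the top lane); igemm stores rows in
      // reverse order, shifting the mask right by 16 and reducing the pointer bias for each
      // earlier row.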
      _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
      vmask = _kshiftri_mask64(vmask, 16);
      _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
      vmask = _kshiftri_mask64(vmask, 16);
      _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
      vmask = _kshiftri_mask64(vmask, 16);
      _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);

      nc = 0;
    }
  } while (nc != 0);
}

void xnn_qu8_vadd_minmax_ukernel__avx512skx_mul32_ld128_x16(
    size_t n,
    const uint8_t* input_a,
    const uint8_t* input_b,
    uint8_t* output,
    const union xnn_qu8_addsub_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  const __m512i vbias = _mm512_load_si512(params->avx512.bias);
  const __m512i va_multiplier = _mm512_load_si512(params->avx512.a_multiplier);
  const __m512i vb_multiplier = _mm512_load_si512(params->avx512.b_multiplier);
  const __m128i vshift = _mm_loadu_si32(params->avx512.shift);
  const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx512.output_zero_point);
  const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx512.output_min);
  const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx512.output_max);

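  // Main loop: widen 16 uint8 inputs of each operand to int32, evaluate the linear requantization
  // bias + a*a_multiplier + b*b_multiplier, arithmetic-shift right, then repack to uint8 with
  // zero-point addition and min/max clamping.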
  for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
    const __m512i va0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) input_a));
    const __m512i vb0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) input_b));
    input_a += 16;
    input_b += 16;

    __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));

    vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vb0123456789ABCDEF, vb_multiplier));

    vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);

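    // The 256-bit pack interleaves the two 128-bit halves, producing the 012389AB4567CDEF word
    // order; the dword shuffle after the 8-bit pack restores ascending element order.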
    __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);

    __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));

    vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, voutput_min);

    vout0123456789ABCDEF = _mm_min_epu8(vout0123456789ABCDEF, voutput_max);

    _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
    output += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    {
      const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << n) - UINT32_C(1)));
      const __m512i va0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_maskz_loadu_epi8(vmask, input_a));
      const __m512i vb0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_maskz_loadu_epi8(vmask, input_b));

      __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));

      vacc0123456789ABCDEF = _mm512_add_epi32(vacc0123456789ABCDEF, _mm512_mullo_epi32(vb0123456789ABCDEF, vb_multiplier));

      vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);

      __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
      __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
      vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, voutput_min);
      vout0123456789ABCDEF = _mm_min_epu8(vout0123456789ABCDEF, voutput_max);

      _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
    }
  }
}

void xnn_qu8_vaddc_minmax_ukernel__avx512skx_mul32_ld128_x16(
    size_t n,
    const uint8_t* input_a,
    const uint8_t* input_b,
    uint8_t* output,
    const union xnn_qu8_addsub_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  const __m512i va_multiplier = _mm512_load_si512(params->avx512.a_multiplier);
  const __m128i vshift = _mm_loadu_si32(params->avx512.shift);
  const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx512.output_zero_point);
  const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx512.output_min);
  const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx512.output_max);

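  // The second operand is a single broadcast byte: fold b_multiplier * (*input_b) into the bias
  // once, outside the loop, so the loop only multiplies and accumulates the a input.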
  const __m512i vbias = _mm512_add_epi32(
    _mm512_broadcastd_epi32(_mm_cvtsi32_si128(params->avx512.b_multiplier[0] * (int32_t) *input_b)),
    _mm512_load_si512(params->avx512.bias));
  for (; n >= 16 * sizeof(uint8_t); n -= 16 * sizeof(uint8_t)) {
    const __m512i va0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i*) input_a));
    input_a += 16;

    __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));

    vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);

    __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);

    __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));

    vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, voutput_min);

    vout0123456789ABCDEF = _mm_min_epu8(vout0123456789ABCDEF, voutput_max);

    _mm_storeu_si128((__m128i*) output, vout0123456789ABCDEF);
    output += 16;
  }
  if XNN_UNLIKELY(n != 0) {
    {
      const __mmask16 vmask = _cvtu32_mask16((uint32_t) ((UINT32_C(1) << n) - UINT32_C(1)));
      const __m512i va0123456789ABCDEF = _mm512_cvtepu8_epi32(_mm_maskz_loadu_epi8(vmask, input_a));

      __m512i vacc0123456789ABCDEF = _mm512_add_epi32(vbias, _mm512_mullo_epi32(va0123456789ABCDEF, va_multiplier));

      vacc0123456789ABCDEF = _mm512_sra_epi32(vacc0123456789ABCDEF, vshift);

      __m256i vout012389AB4567CDEF = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0123456789ABCDEF), _mm512_extracti32x8_epi32(vacc0123456789ABCDEF, 1)), voutput_zero_point);
      __m128i vout0123456789ABCDEF = _mm_shuffle_epi32(_mm_packus_epi16(_mm256_castsi256_si128(vout012389AB4567CDEF), _mm256_extracti128_si256(vout012389AB4567CDEF, 1)), _MM_SHUFFLE(3, 1, 2, 0));
      vout0123456789ABCDEF = _mm_max_epu8(vout0123456789ABCDEF, voutput_min);
      vout0123456789ABCDEF = _mm_min_epu8(vout0123456789ABCDEF, voutput_max);

      _mm_mask_storeu_epi8(output, vmask, vout0123456789ABCDEF);
    }
  }
}

void xnn_x8_lut_ukernel__avx512skx_vpshufb_x64(
    size_t n,
    const uint8_t* x,
    uint8_t* y,
    const uint8_t t[restrict XNN_MIN_ELEMENTS(256)])
{
  assert(n != 0);
  assert(x != NULL);
  assert(y != NULL);

  const __m512i vt0 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) t));
  const __m512i vt1 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 16)));
  const __m512i vt2 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 32)));
  const __m512i vt3 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 48)));
  const __m512i vt4 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 64)));
  const __m512i vt5 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 80)));
  const __m512i vt6 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 96)));
  const __m512i vt7 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 112)));
  const __m512i vt8 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 128)));
  const __m512i vt9 = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 144)));
  const __m512i vtA = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 160)));
  const __m512i vtB = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 176)));
  const __m512i vtC = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 192)));
  const __m512i vtD = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 208)));
  const __m512i vtE = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 224)));
  const __m512i vtF = _mm512_broadcast_i32x4(_mm_load_si128((const __m128i*) (t + 240)));

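  // Decompose the 256-entry table into 16 shuffle tables: vtableK = vtK ^ vt(K-1), with the upper
  // eight also folding in vtable(K-8). XOR-accumulating the 16 per-range lookups telescopes to
  // t[x]: once the shifted index goes negative, VPSHUFB's high-bit zeroing stops further
  // contributions, and the extra fold compensates for the first lookup returning 0 when x >= 128.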
  const __m512i vtable0 = vt0;
  const __m512i vtable1 = _mm512_xor_si512(vt0, vt1);
  const __m512i vtable2 = _mm512_xor_si512(vt1, vt2);
  const __m512i vtable3 = _mm512_xor_si512(vt2, vt3);
  const __m512i vtable4 = _mm512_xor_si512(vt3, vt4);
  const __m512i vtable5 = _mm512_xor_si512(vt4, vt5);
  const __m512i vtable6 = _mm512_xor_si512(vt5, vt6);
  const __m512i vtable7 = _mm512_xor_si512(vt6, vt7);
  const __m512i vtable8 = _mm512_xor_si512(_mm512_xor_si512(vt7, vt8), vtable0);
  const __m512i vtable9 = _mm512_xor_si512(_mm512_xor_si512(vt8, vt9), vtable1);
  const __m512i vtableA = _mm512_xor_si512(_mm512_xor_si512(vt9, vtA), vtable2);
  const __m512i vtableB = _mm512_xor_si512(_mm512_xor_si512(vtA, vtB), vtable3);
  const __m512i vtableC = _mm512_xor_si512(_mm512_xor_si512(vtB, vtC), vtable4);
  const __m512i vtableD = _mm512_xor_si512(_mm512_xor_si512(vtC, vtD), vtable5);
  const __m512i vtableE = _mm512_xor_si512(_mm512_xor_si512(vtD, vtE), vtable6);
  const __m512i vtableF = _mm512_xor_si512(_mm512_xor_si512(vtE, vtF), vtable7);

  const __m512i voffset = _mm512_set1_epi8(16);
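  // After each lookup the index is decremented by 16 so the next 16-entry range maps to 0..15.
  // Plain subtraction is safe for the first 8 steps; beyond that, saturating subtraction pins
  // exhausted indices at -128 so they cannot wrap back into the valid 0..15 range.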
  for (; n >= 64 * sizeof(uint8_t); n -= 64 * sizeof(uint8_t)) {
    __m512i vx = _mm512_loadu_si512(x);
    x += 64;

    __m512i vy = _mm512_shuffle_epi8(vtable0, vx);

    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable1, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable2, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable3, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable4, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable5, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable6, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable7, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable8, vx));

    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable9, vx));
    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableA, vx));
    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableB, vx));
    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableC, vx));
    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableD, vx));
    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableE, vx));
    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableF, vx));

    _mm512_storeu_si512(y, vy);
    y += 64;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n < 64);
    const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT64_C(1) << n) - UINT64_C(1)));

    __m512i vx = _mm512_maskz_loadu_epi8(vmask, x);

    __m512i vy = _mm512_shuffle_epi8(vtable0, vx);

    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable1, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable2, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable3, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable4, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable5, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable6, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable7, vx));
    vx = _mm512_sub_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable8, vx));

    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtable9, vx));
    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableA, vx));
    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableB, vx));
    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableC, vx));
    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableD, vx));
    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableE, vx));
    vx = _mm512_subs_epi8(vx, voffset);
    vy = _mm512_xor_si512(vy, _mm512_shuffle_epi8(vtableF, vx));

    _mm512_mask_storeu_epi8(y, vmask, vy);
  }
}
