1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #pragma once
10
11 #include <stdbool.h>
12 #include <stddef.h>
13 #include <stdint.h>
14
15 #include <xnnpack.h>
16 #include <xnnpack/common.h>
17
18
// Params type for f16 micro-kernels that need no extra constants.
// It carries no data; its only role is to give these kernels a distinct params pointer type.
19 union xnn_f16_default_params {
20 // Empty; serves to differentiate pointer types for micro-kernels without fused activation.
21 char _; // Dummy member variable to comply with the C standard (a union must have at least one member)
22 };
23
24 // scaleminmax is used for gemm/igemm ukernels.
// Params for f16 kernels that apply a scale and then clamp to [min, max].
// Each member is the layout for one architecture and exists only when that architecture is compiled.
25 union xnn_f16_scaleminmax_params {
26 // Always-present dummy member: keeps the union non-empty on architectures that have no specialized layout below.
27 char _; // Dummy member variable to comply with the C standard
28 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
// ARM: three constants stored as raw 16-bit values (NOTE(review): presumably fp16 bit patterns — confirm against init code).
29 struct {
30 uint16_t scale;
31 uint16_t min;
32 uint16_t max;
33 uint16_t pad; // pad to 8 bytes for neonfp16arith assembly.
34 } neon;
35 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
36 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
// x86 AVX: constants stored as fp32, 8 entries each, 32-byte aligned.
37 struct {
38 XNN_ALIGN(32) float scale[8];
39 XNN_ALIGN(32) float min[8];
40 XNN_ALIGN(32) float max[8];
41 } avx;
42 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
43 };
44
// Params for f16 kernels that clamp their output to [min, max].
45 union xnn_f16_minmax_params {
46 // Always-present dummy member: keeps the union non-empty on architectures that have no specialized layout below.
47 char _; // Dummy member variable to comply with the C standard
48 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
// ARM: bounds stored as raw 16-bit values (NOTE(review): presumably fp16 bit patterns — confirm against init code).
49 struct {
50 uint16_t min;
51 uint16_t max;
52 } neon;
53 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
54 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
// x86 AVX: bounds stored as fp32, 8 entries each, 32-byte aligned.
55 struct {
56 XNN_ALIGN(32) float min[8];
57 XNN_ALIGN(32) float max[8];
58 } avx;
59 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
60 };
61
// Params for f32 kernels without fused activation.
// Only the x86 AVX layout carries data.
62 union xnn_f32_default_params {
63 // Always-present dummy member: keeps the union non-empty on architectures that have no specialized layout below.
64 char _; // Dummy member variable to comply with the C standard
65 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
// NOTE(review): mask_table presumably supplies partial-vector masks for remainder elements — confirm against kernels.
66 struct {
67 int32_t mask_table[14];
68 } avx;
69 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
70 };
71
// Params type for f32 ReLU kernels.
// It carries no data; it exists only to give ReLU kernels their own params pointer type.
72 union xnn_f32_relu_params {
73 // Empty; serves to differentiate pointer types for micro-kernels with different fused activations.
74 char _; // Dummy member variable to comply with the C standard (a union must have at least one member)
75 };
76
// Params for f32 kernels that clamp their output to [min, max].
// The scalar layout is always present; SIMD layouts exist only on their architecture.
77 union xnn_f32_minmax_params {
// Portable layout: one float per bound.
78 struct {
79 float min;
80 float max;
81 } scalar;
82 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
// SSE: 4 entries per bound, 16-byte aligned.
83 struct {
84 XNN_ALIGN(16) float min[4];
85 XNN_ALIGN(16) float max[4];
86 } sse;
// AVX: 8 entries per bound, 32-byte aligned, plus mask_table
// (NOTE(review): presumably remainder-element masks — confirm against kernels).
87 struct {
88 XNN_ALIGN(32) float min[8];
89 XNN_ALIGN(32) float max[8];
90 int32_t mask_table[14];
91 } avx;
92 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
93 #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
// WAsm SIMD: 2 entries per bound, 8-byte aligned.
94 struct {
95 XNN_ALIGN(8) float min[2];
96 XNN_ALIGN(8) float max[2];
97 } wasmsimd;
98 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
99 };
100
// Params for f32 absolute-value kernels.
// The SIMD layouts hold a `nonsign_mask` constant. NOTE(review): presumably the bit pattern that clears the
// sign bit; its value is set by init code not shown here.
101 union xnn_f32_abs_params {
102 char _; // Dummy member variable to comply with the C standard
103 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
104 struct {
105 XNN_ALIGN(16) float nonsign_mask[4];
106 } sse;
107 struct {
108 XNN_ALIGN(32) float nonsign_mask[8];
109 int32_t mask_table[14];
110 } avx;
// AVX-512 stores a single 32-bit mask instead of a per-entry array.
111 struct {
112 uint32_t nonsign_mask;
113 } avx512;
114 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
115 #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
116 struct {
117 XNN_ALIGN(8) float nonsign_mask[2];
118 } wasmsimd;
119 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
120 };
121
// Params for f32 negation kernels.
// The SIMD layouts hold a `sign_mask` constant. NOTE(review): presumably the bit pattern of the sign bit;
// its value is set by init code not shown here.
122 union xnn_f32_neg_params {
123 char _; // Dummy member variable to comply with the C standard
124 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
125 struct {
126 XNN_ALIGN(16) float sign_mask[4];
127 } sse;
128 struct {
129 XNN_ALIGN(32) float sign_mask[8];
130 int32_t mask_table[14];
131 } avx;
// AVX-512 stores a single 32-bit mask instead of a per-entry array.
132 struct {
133 uint32_t sign_mask;
134 } avx512;
135 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
136 #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
137 struct {
138 XNN_ALIGN(8) float sign_mask[2];
139 } wasmsimd;
140 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
141 };
142
// Params for f32 rounding kernels.
// Each member holds the constants that one kernel family reads.
143 union xnn_f32_rnd_params {
144 char _; // Dummy member variable to comply with the C standard
145 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
146 struct {
147 XNN_ALIGN(16) float sign_mask[4];
148 XNN_ALIGN(16) float one[4];
149 } sse2;
// AVX needs only the mask table (NOTE(review): presumably for remainder elements).
150 struct {
151 int32_t mask_table[14];
152 } avx;
153 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
154 #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
155 struct {
156 XNN_ALIGN(8) float sign_mask[2];
157 XNN_ALIGN(8) float magic_bias[2];
158 XNN_ALIGN(8) float one[2];
159 } wasmsimd;
160 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
161 };
162
// Params for f32 ELU kernels. Each member is the constant set for one kernel variant.
// The member names follow a visible pattern:
//  - rr1 variants carry a single `minus_ln2`; rr2 variants split it into `minus_ln2_hi`/`minus_ln2_lo`.
//  - lutN variants add an `index_mask` and/or `table`.
//  - The pN suffix matches the highest cN polynomial coefficient present.
// NOTE(review): the algorithmic meaning of these names is inferred — confirm against the kernel sources.
// Field order inside each struct is a layout contract with the kernels; do not reorder.
163 union xnn_f32_elu_params {
// Portable scalar layouts: one float per constant.
164 struct {
165 float prescale;
166 float alpha;
167 float beta;
168 float sat_cutoff;
169 float magic_bias;
170 float log2e;
171 float minus_ln2_hi;
172 float minus_ln2_lo;
173 float c3;
174 float c2;
175 float one;
176 } scalar_rr2_lut16_p3;
177 struct {
178 float prescale;
179 float alpha;
180 float beta;
181 float sat_cutoff;
182 float magic_bias;
183 float log2e;
184 float minus_ln2_hi;
185 float minus_ln2_lo;
186 float c6;
187 float c5;
188 float c4;
189 float c3;
190 float c2;
191 float one;
192 } scalar_rr2_p6;
193 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
// ARM NEON / NEON-FMA layouts: scalar floats, with no `one` field.
194 struct {
195 float prescale;
196 float alpha;
197 float beta;
198 float sat_cutoff;
199 float magic_bias;
200 float log2e;
201 float minus_ln2_hi;
202 float minus_ln2_lo;
203 float c6;
204 float c5;
205 float c4;
206 float c3;
207 float c2;
208 } neon_rr2_p6;
209 struct {
210 float prescale;
211 float alpha;
212 float beta;
213 float sat_cutoff;
214 float magic_bias;
215 float log2e;
216 float minus_ln2_hi;
217 float minus_ln2_lo;
218 float c3;
219 float c2;
220 } neon_rr2_lut16_p3;
221 struct {
222 float prescale;
223 float alpha;
224 float beta;
225 float sat_cutoff;
226 float magic_bias;
227 float log2e;
228 float minus_ln2;
229 float c6;
230 float c5;
231 float c4;
232 float c3;
233 float c2;
234 } neonfma_rr1_p6;
235 struct {
236 float prescale;
237 float alpha;
238 float beta;
239 float sat_cutoff;
240 float magic_bias;
241 float log2e;
242 float minus_ln2;
243 float c3;
244 float c2;
245 } neonfma_rr1_lut16_p3;
246 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
247 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
// SSE2 layouts: 4 entries per constant, 16-byte aligned.
248 struct {
249 XNN_ALIGN(16) float prescale[4];
250 XNN_ALIGN(16) float alpha[4];
251 XNN_ALIGN(16) float beta[4];
252 XNN_ALIGN(16) float sat_cutoff[4];
253 XNN_ALIGN(16) float magic_bias[4];
254 XNN_ALIGN(16) float log2e[4];
255 XNN_ALIGN(16) uint32_t index_mask[4];
256 XNN_ALIGN(16) float minus_ln2_hi[4];
257 XNN_ALIGN(16) float minus_ln2_lo[4];
258 XNN_ALIGN(16) float c3[4];
259 XNN_ALIGN(16) float c2[4];
260 XNN_ALIGN(16) float one[4];
261 } sse2_rr2_lut16_p3;
262 struct {
263 XNN_ALIGN(16) float prescale[4];
264 XNN_ALIGN(16) float alpha[4];
265 XNN_ALIGN(16) float beta[4];
266 XNN_ALIGN(16) float sat_cutoff[4];
267 XNN_ALIGN(16) float magic_bias[4];
268 XNN_ALIGN(16) float log2e[4];
269 XNN_ALIGN(16) float minus_ln2_hi[4];
270 XNN_ALIGN(16) float minus_ln2_lo[4];
271 XNN_ALIGN(16) float c6[4];
272 XNN_ALIGN(16) float c5[4];
273 XNN_ALIGN(16) float c4[4];
274 XNN_ALIGN(16) float c3[4];
275 XNN_ALIGN(16) float c2[4];
276 XNN_ALIGN(16) float one[4];
277 } sse2_rr2_p6;
// AVX / AVX2 layouts: 8 entries per constant, 32-byte aligned, each ending with mask_table[14].
278 struct {
279 XNN_ALIGN(32) float prescale[8];
280 XNN_ALIGN(32) float alpha[8];
281 XNN_ALIGN(32) float beta[8];
282 XNN_ALIGN(32) float sat_cutoff[8];
283 XNN_ALIGN(32) float magic_bias[8];
284 XNN_ALIGN(32) float log2e[8];
285 XNN_ALIGN(32) uint32_t index_mask[8];
286 XNN_ALIGN(32) float minus_ln2_hi[8];
287 XNN_ALIGN(32) float minus_ln2_lo[8];
288 XNN_ALIGN(32) float c3[8];
289 XNN_ALIGN(32) float c2[8];
290 XNN_ALIGN(32) float one[8];
291 int32_t mask_table[14];
292 } avx_rr2_lut16_p3;
293 struct {
294 XNN_ALIGN(32) float prescale[8];
295 XNN_ALIGN(32) float alpha[8];
296 XNN_ALIGN(32) float beta[8];
297 XNN_ALIGN(32) float sat_cutoff[8];
298 XNN_ALIGN(32) float magic_bias[8];
299 XNN_ALIGN(32) float log2e[8];
300 XNN_ALIGN(32) uint32_t index_mask[8];
301 XNN_ALIGN(32) float table[8];
302 XNN_ALIGN(32) float minus_ln2_hi[8];
303 XNN_ALIGN(32) float minus_ln2_lo[8];
304 XNN_ALIGN(32) float c4[8];
305 XNN_ALIGN(32) float c3[8];
306 XNN_ALIGN(32) float c2[8];
307 XNN_ALIGN(32) float one[8];
308 int32_t mask_table[14];
309 } avx_rr2_lut4_p4;
310 struct {
311 XNN_ALIGN(32) float prescale[8];
312 XNN_ALIGN(32) float alpha[8];
313 XNN_ALIGN(32) float beta[8];
314 XNN_ALIGN(32) float sat_cutoff[8];
315 XNN_ALIGN(32) float magic_bias[8];
316 XNN_ALIGN(32) float log2e[8];
317 XNN_ALIGN(32) float minus_ln2_hi[8];
318 XNN_ALIGN(32) float minus_ln2_lo[8];
319 XNN_ALIGN(32) float c6[8];
320 XNN_ALIGN(32) float c5[8];
321 XNN_ALIGN(32) float c4[8];
322 XNN_ALIGN(32) float c3[8];
323 XNN_ALIGN(32) float c2[8];
324 XNN_ALIGN(32) float one[8];
325 int32_t mask_table[14];
326 } avx_rr2_p6;
327 struct {
328 XNN_ALIGN(32) float prescale[8];
329 XNN_ALIGN(32) float alpha[8];
330 XNN_ALIGN(32) float beta[8];
331 XNN_ALIGN(32) float sat_cutoff[8];
332 XNN_ALIGN(32) float magic_bias[8];
333 XNN_ALIGN(32) float log2e[8];
334 XNN_ALIGN(32) uint32_t index_mask[8];
335 XNN_ALIGN(32) float minus_ln2[8];
336 XNN_ALIGN(32) float c3[8];
337 XNN_ALIGN(32) float c2[8];
338 int32_t mask_table[14];
339 } avx2_rr1_lut16_p3;
// NOTE(review): `table` is uint32_t here but float in avx2_rr1_lut4_p4 below — presumably intentional
// (it holds raw bit patterns in one variant); confirm before unifying.
340 struct {
341 XNN_ALIGN(32) float prescale[8];
342 XNN_ALIGN(32) float alpha[8];
343 XNN_ALIGN(32) float beta[8];
344 XNN_ALIGN(32) float sat_cutoff[8];
345 XNN_ALIGN(32) float magic_bias[8];
346 XNN_ALIGN(32) float log2e[8];
347 XNN_ALIGN(32) uint32_t table[8];
348 XNN_ALIGN(32) float minus_ln2[8];
349 XNN_ALIGN(32) float c4[8];
350 XNN_ALIGN(32) float c3[8];
351 XNN_ALIGN(32) float c2[8];
352 int32_t mask_table[14];
353 } avx2_rr1_lut8_p4;
354 struct {
355 XNN_ALIGN(32) float prescale[8];
356 XNN_ALIGN(32) float alpha[8];
357 XNN_ALIGN(32) float beta[8];
358 XNN_ALIGN(32) float sat_cutoff[8];
359 XNN_ALIGN(32) float magic_bias[8];
360 XNN_ALIGN(32) float log2e[8];
361 XNN_ALIGN(32) float table[8];
362 XNN_ALIGN(32) float minus_ln2[8];
363 XNN_ALIGN(32) float c4[8];
364 XNN_ALIGN(32) float c3[8];
365 XNN_ALIGN(32) float c2[8];
366 int32_t mask_table[14];
367 } avx2_rr1_lut4_p4;
368 struct {
369 XNN_ALIGN(32) float prescale[8];
370 XNN_ALIGN(32) float alpha[8];
371 XNN_ALIGN(32) float beta[8];
372 XNN_ALIGN(32) float sat_cutoff[8];
373 XNN_ALIGN(32) float magic_bias[8];
374 XNN_ALIGN(32) float log2e[8];
375 XNN_ALIGN(32) float minus_ln2[8];
376 XNN_ALIGN(32) float c6[8];
377 XNN_ALIGN(32) float c5[8];
378 XNN_ALIGN(32) float c4[8];
379 XNN_ALIGN(32) float c3[8];
380 XNN_ALIGN(32) float c2[8];
381 int32_t mask_table[14];
382 } avx2_rr1_p6;
// AVX-512 layouts: scalar constants; the lookup table (when present) is 64-byte aligned.
383 struct {
384 float prescale;
385 float alpha;
386 float beta;
387 float sat_cutoff;
388 float magic_bias;
389 float log2e;
390 float minus_ln2;
391 float c3;
392 float c2;
393 XNN_ALIGN(64) uint32_t table[16];
394 } avx512_rr1_lut16_p3;
395 struct {
396 float prescale;
397 float alpha;
398 float beta;
399 float sat_cutoff;
400 float magic_bias;
401 float log2e;
402 float minus_ln2;
403 float c6;
404 float c5;
405 float c4;
406 float c3;
407 float c2;
408 } avx512_rr1_p6;
409 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
410 #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
// WAsm SIMD layouts: 2 entries per constant, 8-byte aligned.
411 struct {
412 XNN_ALIGN(8) float prescale[2];
413 XNN_ALIGN(8) float alpha[2];
414 XNN_ALIGN(8) float beta[2];
415 XNN_ALIGN(8) float sat_cutoff[2];
416 XNN_ALIGN(8) float magic_bias[2];
417 XNN_ALIGN(8) float log2e[2];
418 XNN_ALIGN(8) uint32_t index_mask[2];
419 XNN_ALIGN(8) float minus_ln2_hi[2];
420 XNN_ALIGN(8) float minus_ln2_lo[2];
421 XNN_ALIGN(8) float c3[2];
422 XNN_ALIGN(8) float c2[2];
423 XNN_ALIGN(8) float one[2];
424 } wasmsimd_rr2_lut16_p3;
425 struct {
426 XNN_ALIGN(8) float prescale[2];
427 XNN_ALIGN(8) float alpha[2];
428 XNN_ALIGN(8) float beta[2];
429 XNN_ALIGN(8) float sat_cutoff[2];
430 XNN_ALIGN(8) float magic_bias[2];
431 XNN_ALIGN(8) float log2e[2];
432 XNN_ALIGN(8) float minus_ln2_hi[2];
433 XNN_ALIGN(8) float minus_ln2_lo[2];
434 XNN_ALIGN(8) float c6[2];
435 XNN_ALIGN(8) float c5[2];
436 XNN_ALIGN(8) float c4[2];
437 XNN_ALIGN(8) float c3[2];
438 XNN_ALIGN(8) float c2[2];
439 XNN_ALIGN(8) float one[2];
440 } wasmsimd_rr2_p6;
441 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
442 };
443
// Params for f32 "expminus" kernels. Each member is the constant set for one kernel variant.
// rr1 variants carry a single `minus_ln2`; rr2 variants split it into `minus_ln2_hi`/`minus_ln2_lo`.
// NOTE(review): `denorm_cutoff` is presumably the input threshold beyond which the result is flushed — confirm in kernels.
// Field order inside each struct is a layout contract with the kernels; do not reorder.
444 union xnn_f32_expminus_params {
// Portable scalar layouts.
445 struct {
446 float log2e;
447 float magic_bias;
448 float minus_ln2_hi;
449 float minus_ln2_lo;
450 float c5;
451 float c4;
452 float c3;
453 float c2;
454 float c1;
455 float denorm_cutoff;
456 } scalar_rr2_p5;
457 struct {
458 float log2e;
459 float magic_bias;
460 float minus_ln2_hi;
461 float minus_ln2_lo;
462 float c2;
463 float denorm_cutoff;
464 } scalar_rr2_lut64_p2;
465 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
// ARM NEON / NEON-FMA layouts: scalar floats.
466 struct {
467 float log2e;
468 float magic_bias;
469 float minus_ln2_hi;
470 float minus_ln2_lo;
471 float c5;
472 float c4;
473 float c3;
474 float c2;
475 float c1;
476 float denorm_cutoff;
477 } neon_rr2_p5;
478 struct {
479 float log2e;
480 float magic_bias;
481 float minus_ln2_hi;
482 float minus_ln2_lo;
483 float c2;
484 float denorm_cutoff;
485 } neon_rr2_lut64_p2;
486 struct {
487 float log2e;
488 float magic_bias;
489 float minus_ln2;
490 float c5;
491 float c4;
492 float c3;
493 float c2;
494 float c1;
495 float denorm_cutoff;
496 } neonfma_rr1_p5;
497 struct {
498 float log2e;
499 float magic_bias;
500 float minus_ln2;
501 float c2;
502 float denorm_cutoff;
503 } neonfma_rr1_lut64_p2;
504 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
505 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
// SSE2: 4 entries per constant, 16-byte aligned.
506 struct {
507 XNN_ALIGN(16) float log2e[4];
508 XNN_ALIGN(16) float magic_bias[4];
509 XNN_ALIGN(16) float minus_ln2_hi[4];
510 XNN_ALIGN(16) float minus_ln2_lo[4];
511 XNN_ALIGN(16) float c5[4];
512 XNN_ALIGN(16) float c4[4];
513 XNN_ALIGN(16) float c3[4];
514 XNN_ALIGN(16) float c2[4];
515 XNN_ALIGN(16) float c1[4];
516 XNN_ALIGN(16) float denorm_cutoff[4];
517 } sse2_rr2_p5;
// AVX2: 8 entries per constant, 32-byte aligned, plus mask_table.
518 struct {
519 XNN_ALIGN(32) float log2e[8];
520 XNN_ALIGN(32) float magic_bias[8];
521 XNN_ALIGN(32) float minus_ln2[8];
522 XNN_ALIGN(32) float c5[8];
523 XNN_ALIGN(32) float c4[8];
524 XNN_ALIGN(32) float c3[8];
525 XNN_ALIGN(32) float c2[8];
526 XNN_ALIGN(32) float c1[8];
527 XNN_ALIGN(32) float denorm_cutoff[8];
528 int32_t mask_table[14];
529 } avx2_rr1_p5;
// AVX-512: scalar constants; it has a c0 field and no magic_bias/denorm_cutoff fields.
530 struct {
531 float log2e;
532 float minus_ln2;
533 float c5;
534 float c4;
535 float c3;
536 float c2;
537 float c1;
538 float c0;
539 } avx512_rr1_p5;
540 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
541 #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
// WAsm SIMD: 2 entries per constant, 8-byte aligned.
542 struct {
543 XNN_ALIGN(8) float log2e[2];
544 XNN_ALIGN(8) float magic_bias[2];
545 XNN_ALIGN(8) float minus_ln2_hi[2];
546 XNN_ALIGN(8) float minus_ln2_lo[2];
547 XNN_ALIGN(8) float c5[2];
548 XNN_ALIGN(8) float c4[2];
549 XNN_ALIGN(8) float c3[2];
550 XNN_ALIGN(8) float c2[2];
551 XNN_ALIGN(8) float c1[2];
552 XNN_ALIGN(8) float denorm_cutoff[2];
553 } wasmsimd_rr2_p5;
554 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
555 };
556
// Params for f32 leaky-ReLU kernels. Every layout carries the negative-side `slope`.
// The SIMD layouts hold multiple entries of it (NOTE(review): presumably the same value repeated; filled by init code).
557 union xnn_f32_lrelu_params {
558 struct {
559 float slope;
560 } scalar;
561 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
562 struct {
563 XNN_ALIGN(16) float slope[4];
564 } sse;
565 struct {
566 XNN_ALIGN(32) float slope[8];
567 int32_t mask_table[14];
568 } avx;
569 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
570 #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
571 struct {
572 XNN_ALIGN(8) float slope[2];
573 } wasmsimd;
574 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
575 };
576
// Params for f32 sigmoid kernels. Each member is the constant set for one kernel variant.
// rr1 variants carry a single ln2/minus_ln2; rr2 variants carry a hi/lo split.
// lutN variants add an `index_mask` or `table`.
// NOTE(review): variant-name semantics are inferred from the fields — confirm against the kernel sources.
// Field order inside each struct is a layout contract with the kernels; do not reorder.
577 union xnn_f32_sigmoid_params {
// Portable scalar layouts.
578 struct {
579 float magic_bias;
580 float minus_log2e;
581 float ln2_hi;
582 float ln2_lo;
583 float c1;
584 float one;
585 float denorm_cutoff;
586 } scalar_rr2_lut2048_p1;
587 struct {
588 float magic_bias;
589 float minus_log2e;
590 float ln2_hi;
591 float ln2_lo;
592 float c2;
593 float one;
594 float denorm_cutoff;
595 } scalar_rr2_lut64_p2;
596 struct {
597 float magic_bias;
598 float minus_log2e;
599 float ln2_hi;
600 float ln2_lo;
601 float c5;
602 float c4;
603 float c3;
604 float c2;
605 float c1;
606 float one;
607 float denorm_cutoff;
608 } scalar_rr2_p5;
609 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
// ARM NEON / NEON-FMA layouts: scalar floats, with no `one` field.
610 struct {
611 float magic_bias;
612 float minus_log2e;
613 float ln2_hi;
614 float ln2_lo;
615 float c1;
616 float denorm_cutoff;
617 } neon_rr2_lut2048_p1;
618 struct {
619 float magic_bias;
620 float minus_log2e;
621 float ln2_hi;
622 float ln2_lo;
623 float c2;
624 float denorm_cutoff;
625 } neon_rr2_lut64_p2;
626 struct {
627 float magic_bias;
628 float minus_log2e;
629 float ln2_hi;
630 float ln2_lo;
631 float c5;
632 float c4;
633 float c3;
634 float c2;
635 float c1;
636 float denorm_cutoff;
637 } neon_rr2_p5;
638 struct {
639 float magic_bias;
640 float minus_log2e;
641 float ln2;
642 float c1;
643 float denorm_cutoff;
644 } neonfma_rr1_lut2048_p1;
645 struct {
646 float magic_bias;
647 float minus_log2e;
648 float ln2;
649 float c2;
650 float denorm_cutoff;
651 } neonfma_rr1_lut64_p2;
652 struct {
653 float magic_bias;
654 float minus_log2e;
655 float ln2;
656 float c5;
657 float c4;
658 float c3;
659 float c2;
660 float c1;
661 float denorm_cutoff;
662 } neonfma_rr1_p5;
663 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
664 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
// SSE2 layouts: 4 entries per constant, 16-byte aligned.
665 struct {
666 XNN_ALIGN(16) float sign_mask[4];
667 XNN_ALIGN(16) float magic_bias[4];
668 XNN_ALIGN(16) float log2e[4];
669 XNN_ALIGN(16) uint32_t index_mask[4];
670 XNN_ALIGN(16) float minus_ln2_hi[4];
671 XNN_ALIGN(16) float minus_ln2_lo[4];
672 XNN_ALIGN(16) float c2[4];
673 XNN_ALIGN(16) float one[4];
674 XNN_ALIGN(16) float denorm_cutoff[4];
675 } sse2_rr2_lut64_p2;
676 struct {
677 XNN_ALIGN(16) float sign_mask[4];
678 XNN_ALIGN(16) float magic_bias[4];
679 XNN_ALIGN(16) float log2e[4];
680 XNN_ALIGN(16) float minus_ln2_hi[4];
681 XNN_ALIGN(16) float minus_ln2_lo[4];
682 XNN_ALIGN(16) float c5[4];
683 XNN_ALIGN(16) float c4[4];
684 XNN_ALIGN(16) float c3[4];
685 XNN_ALIGN(16) float c2[4];
686 XNN_ALIGN(16) float c1[4];
687 XNN_ALIGN(16) float one[4];
688 XNN_ALIGN(16) float denorm_cutoff[4];
689 } sse2_rr2_p5;
// AVX / AVX2 layouts: 8 entries per constant, 32-byte aligned, plus mask_table.
690 struct {
691 XNN_ALIGN(32) float sign_mask[8];
692 XNN_ALIGN(32) float magic_bias[8];
693 XNN_ALIGN(32) float log2e[8];
694 XNN_ALIGN(32) float minus_ln2_hi[8];
695 XNN_ALIGN(32) float minus_ln2_lo[8];
696 XNN_ALIGN(32) float c5[8];
697 XNN_ALIGN(32) float c4[8];
698 XNN_ALIGN(32) float c3[8];
699 XNN_ALIGN(32) float c2[8];
700 XNN_ALIGN(32) float c1[8];
701 XNN_ALIGN(32) float one[8];
702 XNN_ALIGN(32) float two[8];
703 XNN_ALIGN(32) float denorm_cutoff[8];
704 int32_t mask_table[14];
705 } avx_rr2_p5;
706 struct {
707 XNN_ALIGN(32) float sign_mask[8];
708 XNN_ALIGN(32) float magic_bias[8];
709 XNN_ALIGN(32) float log2e[8];
710 XNN_ALIGN(32) float minus_ln2[8];
711 XNN_ALIGN(32) float c5[8];
712 XNN_ALIGN(32) float c4[8];
713 XNN_ALIGN(32) float c3[8];
714 XNN_ALIGN(32) float c2[8];
715 XNN_ALIGN(32) float c1[8];
716 XNN_ALIGN(32) float one[8];
717 XNN_ALIGN(32) float denorm_cutoff[8];
718 int32_t mask_table[14];
719 } avx2_rr1_p5;
// AVX-512 layouts: scalar constants (with a 32-bit sign_mask) followed by 64-byte-aligned tables where used.
720 struct {
721 uint32_t sign_mask;
722 float magic_bias;
723 float log2e;
724 float minus_ln2;
725 float c3;
726 float c2;
727 float one;
728 XNN_ALIGN(64) float table[16];
729 } avx512_rr1_lut16_p3;
730 struct {
731 uint32_t sign_mask;
732 float magic_bias;
733 float log2e;
734 float minus_ln2_hi;
735 float minus_ln2_lo;
736 float c2;
737 float c1;
738 float one;
739 XNN_ALIGN(64) float table_lo[16];
740 XNN_ALIGN(64) float table_hi[16];
741 } avx512_rr2_lut32_p2;
742 struct {
743 uint32_t sign_mask;
744 float log2e;
745 float minus_ln2;
746 float c5;
747 float c4;
748 float c3;
749 float c2;
750 float c1;
751 float one;
752 } avx512_rr1_p5;
753 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
754 #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
// WAsm SIMD layouts: 2 entries per constant, 8-byte aligned.
755 struct {
756 XNN_ALIGN(8) float magic_bias[2];
757 XNN_ALIGN(8) float minus_log2e[2];
758 XNN_ALIGN(8) uint32_t index_mask[2];
759 XNN_ALIGN(8) float ln2_hi[2];
760 XNN_ALIGN(8) float ln2_lo[2];
761 XNN_ALIGN(8) float c2[2];
762 XNN_ALIGN(8) float one[2];
763 XNN_ALIGN(8) float denorm_cutoff[2];
764 } wasmsimd_rr2_lut64_p2;
765 struct {
766 XNN_ALIGN(8) float magic_bias[2];
767 XNN_ALIGN(8) float minus_log2e[2];
768 XNN_ALIGN(8) float ln2_hi[2];
769 XNN_ALIGN(8) float ln2_lo[2];
770 XNN_ALIGN(8) float c5[2];
771 XNN_ALIGN(8) float c4[2];
772 XNN_ALIGN(8) float c3[2];
773 XNN_ALIGN(8) float c2[2];
774 XNN_ALIGN(8) float c1[2];
775 XNN_ALIGN(8) float one[2];
776 XNN_ALIGN(8) float denorm_cutoff[2];
777 } wasmsimd_rr2_p5;
778 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
779 };
780
// Params for f32 square-root kernels. Only the x86 variants carry data.
781 union xnn_f32_sqrt_params {
782 char _; // Dummy member variable to comply with the C standard
783 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
784 struct {
785 int32_t mask_table[14];
786 } avx;
// FMA variant also carries a `half` constant (8 entries, 32-byte aligned).
787 struct {
788 XNN_ALIGN(32) float half[8];
789 int32_t mask_table[14];
790 } fma;
// AVX-512 variant keeps `half` as a single scalar.
791 struct {
792 float half;
793 } avx512;
794 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
795 };
796
// Params for f32 CHW-layout kernels: lane masks plus output clamping bounds.
// Member order differs per architecture (scalar puts masks first, NEON/SSE put min/max first);
// each order is a layout contract with its kernels, so do not harmonize.
797 union xnn_f32_chw_params {
798 struct {
799 XNN_ALIGN(16) int32_t mask_even[4]; // used by stride 2 kernels
800 XNN_ALIGN(16) int32_t mask_odd[4]; // used by stride 2 kernels
801 XNN_ALIGN(16) int32_t mask[4]; // used by stride 1 kernels
802 float min;
803 float max;
804 } scalar;
805 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
806 struct {
807 float min;
808 float max;
809 XNN_ALIGN(16) uint32_t mask_even[4]; // used by stride 2 kernels
810 XNN_ALIGN(16) uint32_t mask_odd[4]; // used by stride 2 kernels
811 XNN_ALIGN(16) uint32_t mask[4]; // used by stride 1 kernels
812 } neon;
813 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
814 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
815 struct {
816 XNN_ALIGN(16) float min[4];
817 XNN_ALIGN(16) float max[4];
818 XNN_ALIGN(16) uint32_t mask_even[4]; // used by stride 2 kernels
819 XNN_ALIGN(16) uint32_t mask_odd[4]; // used by stride 2 kernels
820 XNN_ALIGN(16) uint32_t mask[4]; // used by stride 1 kernels
821 } sse;
822 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
823 };
824
// Params for signed 8-bit clamp kernels.
// The scalar layout widens the bounds to int32_t; SIMD layouts store byte-sized entries.
825 union xnn_s8_minmax_params {
826 struct {
827 int32_t min;
828 int32_t max;
829 } scalar;
830 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
// SSE2 stores the bounds as unsigned bytes together with a `bias`.
// NOTE(review): presumably a sign-flip trick so unsigned byte min/max can clamp signed data — confirm in kernels.
831 struct {
832 XNN_ALIGN(16) uint8_t bias[16];
833 XNN_ALIGN(16) uint8_t min_with_bias[16];
834 XNN_ALIGN(16) uint8_t max_with_bias[16];
835 } sse2;
836 struct {
837 XNN_ALIGN(16) int8_t min[16];
838 XNN_ALIGN(16) int8_t max[16];
839 } sse4;
840 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
841 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
842 struct {
843 int8_t min;
844 int8_t max;
845 } neon;
846 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
847 #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
848 struct {
849 XNN_ALIGN(8) int8_t min[8];
850 XNN_ALIGN(8) int8_t max[8];
851 } wasmsimd;
852 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
853 };
854
// Params for unsigned 8-bit clamp kernels.
// The scalar layout widens the bounds to uint32_t; SIMD layouts store byte-sized entries.
855 union xnn_u8_minmax_params {
856 struct {
857 uint32_t min;
858 uint32_t max;
859 } scalar;
860 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
861 struct {
862 XNN_ALIGN(16) uint8_t min[16];
863 XNN_ALIGN(16) uint8_t max[16];
864 } sse2;
865 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
866 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
867 struct {
868 uint8_t min;
869 uint8_t max;
870 } neon;
871 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
872 #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
873 struct {
874 XNN_ALIGN(8) uint8_t min[8];
875 XNN_ALIGN(8) uint8_t max[8];
876 } wasmsimd;
877 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
878 };
879
// Params for f32 kernels that scale and then clamp to [min, max].
880 union xnn_f32_scaleminmax_params {
881 struct {
882 float scale;
883 float min;
884 float max;
885 } scalar;
886 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
// SSE: 4 entries per constant, 16-byte aligned.
887 struct {
888 XNN_ALIGN(16) float scale[4];
889 XNN_ALIGN(16) float min[4];
890 XNN_ALIGN(16) float max[4];
891 } sse;
892 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
893 };
894
// Params for f32 global-average-pooling kernels: an output multiplier, clamping bounds, and a lane mask.
895 union xnn_f32_gavgpool_params {
896 struct {
897 XNN_ALIGN(16) int32_t mask[4];
898 float multiplier;
899 float output_min;
900 float output_max;
901 } scalar;
902 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
903 struct {
904 XNN_ALIGN(16) float multiplier[4];
905 XNN_ALIGN(16) float output_min[4];
906 XNN_ALIGN(16) float output_max[4];
907 XNN_ALIGN(16) uint32_t mask[4];
908 } sse;
909 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
910 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
// NOTE(review): each scalar float below is individually 16-byte aligned (each occupies its own 16-byte slot);
// presumably kernels depend on those offsets — confirm before compacting.
911 struct {
912 XNN_ALIGN(16) float multiplier;
913 XNN_ALIGN(16) float output_min;
914 XNN_ALIGN(16) float output_max;
915 XNN_ALIGN(16) uint32_t mask[4];
916 } neon;
917 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
918 };
919
// Parameters for f16 hard-swish (hswish) micro-kernels.
//
// Each member is the constant layout used by one kernel family and is compiled
// only for the matching architecture. Because every specialized member is
// conditional, the always-present dummy member `_` below keeps the union
// non-empty on all other targets (an empty union violates the C standard).
union xnn_f16_hswish_params {
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // Constants stored as raw 16-bit values.
  // NOTE(review): presumably fp16 bit patterns — confirm against the init function.
  struct {
    uint16_t sixth;
    uint16_t three;
    uint16_t six;
    uint16_t pad; // pad to 8 bytes for neonfp16arith assembly.
  } neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // fp32 `sixth`/`three` with 8 entries each (32-byte aligned); `six` kept as 8 16-bit values.
  struct {
    XNN_ALIGN(32) float sixth[8];
    XNN_ALIGN(32) float three[8];
    XNN_ALIGN(16) uint16_t six[8];
  } avx;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
  // Dummy member variable to comply with the C standard on targets with no
  // specialized layout above (e.g. WAsm). Declared last, not first like in
  // sibling unions, so existing non-designated initializers still target the
  // first architecture-specific member.
  char _;
};
937
// Params for f32 hard-swish kernels.
// The scalar and WAsm layouts carry sixth/three/six; the x86 layouts carry sixth/half/one.
// NOTE(review): presumably two equivalent formulations of hswish — confirm against the kernels.
938 union xnn_f32_hswish_params {
939 struct {
940 float sixth;
941 float three;
942 float six;
943 } scalar;
944 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
945 struct {
946 XNN_ALIGN(16) float sixth[4];
947 XNN_ALIGN(16) float half[4];
948 XNN_ALIGN(16) float one[4];
949 } sse;
950 struct {
951 XNN_ALIGN(32) float sixth[8];
952 XNN_ALIGN(32) float half[8];
953 XNN_ALIGN(32) float one[8];
954 int32_t mask_table[14];
955 } avx;
// AVX-512 keeps the constants as scalars.
956 struct {
957 float sixth;
958 float half;
959 float one;
960 } avx512;
961 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
962 #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
963 struct {
964 XNN_ALIGN(8) float sixth[2];
965 XNN_ALIGN(8) float three[2];
966 XNN_ALIGN(8) float six[2];
967 } wasmsimd;
968 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
969 };
970
// Requantization + clamping parameters for unsigned 8-bit (qu8) convolution kernels.
//
// Each member is the layout for one requantization variant:
// - fp32_*  members carry a float `scale`.
// - rndnu_neon carries fixed-point multiplier/shift fields.
// The scalar layouts are always present; others are compiled only on their architecture.
// Field order is a layout contract with the kernels; do not reorder.
union xnn_qu8_conv_minmax_params {
  struct {
    int32_t kernel_zero_point;
    float scale;
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    float magic_bias;
    int32_t magic_bias_less_output_zero_point;
  } fp32_scalar_fmagic;
  struct {
    int32_t kernel_zero_point;
    float scale;
    float magic_bias;
    int32_t magic_min;
    int32_t magic_max;
    int32_t magic_bias_less_zero_point;
  } fp32_scalar_imagic;
  struct {
    int32_t kernel_zero_point;
    float scale;
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    int32_t output_zero_point;
  } fp32_scalar_lrintf;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  struct {
    uint8_t kernel_zero_point[4];
    float scale;
    float magic_bias;
    int32_t magic_bias_less_output_zero_point;
    uint8_t output_min;
    uint8_t output_max;
  } fp32_neon;
  struct {
    uint8_t kernel_zero_point[4];
    float scale;
    int16_t output_zero_point;
    uint8_t output_min;
    uint8_t output_max;
  } fp32_neonv8;
  struct {
    uint8_t kernel_zero_point[4];
    int32_t right_pre_shift;
    int32_t multiplier;
    int32_t right_post_shift;
    int16_t output_zero_point;
    uint8_t output_min;
    uint8_t output_max;
  } rndnu_neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // SSE2 / AVX2 / AVX-512: entries sized for 16-, 32- and 64-byte vectors respectively.
  struct {
    XNN_ALIGN(16) int16_t kernel_zero_point[8];
    XNN_ALIGN(16) float scale[4];
    XNN_ALIGN(16) float output_max_less_zero_point[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_min[16];
  } fp32_sse2;
  struct {
    XNN_ALIGN(32) int16_t kernel_zero_point[16];
    XNN_ALIGN(32) float scale[8];
    XNN_ALIGN(32) float output_max_less_zero_point[8];
    XNN_ALIGN(32) int16_t output_zero_point[16];
    XNN_ALIGN(32) uint8_t output_min[32];
  } fp32_avx2;
  struct {
    XNN_ALIGN(64) int16_t kernel_zero_point[32];
    XNN_ALIGN(64) float scale[16];
    XNN_ALIGN(64) float output_max_less_zero_point[16];
    XNN_ALIGN(64) int16_t output_zero_point[32];
    XNN_ALIGN(64) uint8_t output_min[64];
  } fp32_avx512;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  struct {
    XNN_ALIGN(8) int16_t kernel_zero_point[4];
    XNN_ALIGN(8) float scale[2];
    XNN_ALIGN(8) float magic_bias[2];
    XNN_ALIGN(8) int32_t magic_min[2];
    XNN_ALIGN(8) int32_t magic_bias_less_output_zero_point[2];
    // Unsigned, like every other qu8 output clamp in this union; previously
    // declared int8_t, which misrepresented bounds above 127. Same size and layout.
    XNN_ALIGN(8) uint8_t output_max[8];
  } fp32_wasmsimd;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
1055
// Requantization output parameters for signed 8-bit (qs8) kernels.
//
// Each member is the layout for one kernel family. The scalar layouts are
// always present; the others are compiled only on their architecture.
// Field order is a layout contract with the kernels; do not reorder.
union xnn_qs8_minmax_params {
  struct {
    float magic_bias;
    int32_t magic_min;
    int32_t magic_max;
    int32_t magic_bias_less_zero_point;
  } scalar_imagic;
  struct {
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    float magic_bias;
    int32_t magic_bias_less_output_zero_point;
  } scalar_fmagic;
  struct {
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    int32_t output_zero_point;
  } scalar_lrintf;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  struct {
    float magic_bias;
    int32_t magic_bias_less_output_zero_point;
    int8_t output_min;
    int8_t output_max;
  } neon;
  struct {
    int16_t output_zero_point;
    // Signed, matching `neon` above and xnn_qs8_conv_minmax_params.fp32_neonv8;
    // previously declared uint8_t, which read a clamp of -128 back as 128.
    // Same size and layout.
    int8_t output_min;
    int8_t output_max;
  } neonv8;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // SSE2 keeps output_min as int16_t entries while SSE4/AVX2/AVX-512 use int8_t entries.
  struct {
    XNN_ALIGN(16) float output_max_less_zero_point[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) int16_t output_min[8];
  } sse2;
  struct {
    XNN_ALIGN(16) float output_max_less_zero_point[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) int8_t output_min[16];
  } sse4;
  struct {
    XNN_ALIGN(32) float output_max_less_zero_point[8];
    XNN_ALIGN(32) int16_t output_zero_point[16];
    XNN_ALIGN(32) int8_t output_min[32];
  } avx2;
  struct {
    XNN_ALIGN(64) float output_max_less_zero_point[16];
    XNN_ALIGN(64) int16_t output_zero_point[32];
    XNN_ALIGN(64) int8_t output_min[64];
  } avx512;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  struct {
    XNN_ALIGN(8) float magic_bias[2];
    XNN_ALIGN(8) int32_t magic_min[2];
    XNN_ALIGN(8) int32_t magic_bias_less_output_zero_point[2];
    XNN_ALIGN(8) int8_t output_max[8];
  } wasmsimd;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
1118
// Requantization + clamping parameters for signed 8-bit (qs8) convolution kernels.
// fp32_* members carry a float `scale`; rndnu_neon carries fixed-point multiplier/shift fields.
// Scalar layouts are always present; others are compiled only on their architecture.
// Field order is a layout contract with the kernels; do not reorder.
1119 union xnn_qs8_conv_minmax_params {
1120 struct {
1121 float scale;
1122 float output_min_less_zero_point;
1123 float output_max_less_zero_point;
1124 float magic_bias;
1125 int32_t magic_bias_less_output_zero_point;
1126 } fp32_scalar_fmagic;
1127 struct {
1128 float scale;
1129 float magic_bias;
1130 int32_t magic_min;
1131 int32_t magic_max;
1132 int32_t magic_bias_less_zero_point;
1133 } fp32_scalar_imagic;
1134 struct {
1135 float scale;
1136 float output_min_less_zero_point;
1137 float output_max_less_zero_point;
1138 int32_t output_zero_point;
1139 } fp32_scalar_lrintf;
1140 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
1141 struct {
1142 float scale;
1143 float magic_bias;
1144 int32_t magic_bias_less_output_zero_point;
1145 int8_t output_min;
1146 int8_t output_max;
1147 } fp32_neon;
1148 struct {
1149 float scale;
1150 int16_t output_zero_point;
1151 int8_t output_min;
1152 int8_t output_max;
1153 } fp32_neonv8;
1154 struct {
1155 int32_t right_pre_shift;
1156 int32_t multiplier;
1157 int32_t right_post_shift;
1158 int16_t output_zero_point;
1159 int8_t output_min;
1160 int8_t output_max;
1161 } rndnu_neon;
1162 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
1163 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
// SSE2 keeps output_min as int16_t entries; SSE4/AVX2/AVX-512 use int8_t entries sized to their vector width.
1164 struct {
1165 XNN_ALIGN(16) float scale[4];
1166 XNN_ALIGN(16) float output_max_less_zero_point[4];
1167 XNN_ALIGN(16) int16_t output_zero_point[8];
1168 XNN_ALIGN(16) int16_t output_min[8];
1169 } fp32_sse2;
1170 struct {
1171 XNN_ALIGN(16) float scale[4];
1172 XNN_ALIGN(16) float output_max_less_zero_point[4];
1173 XNN_ALIGN(16) int16_t output_zero_point[8];
1174 XNN_ALIGN(16) int8_t output_min[16];
1175 } fp32_sse4;
1176 struct {
1177 XNN_ALIGN(32) float scale[8];
1178 XNN_ALIGN(32) float output_max_less_zero_point[8];
1179 XNN_ALIGN(32) int16_t output_zero_point[16];
1180 XNN_ALIGN(32) int8_t output_min[32];
1181 } fp32_avx2;
1182 struct {
1183 XNN_ALIGN(64) float scale[16];
1184 XNN_ALIGN(64) float output_max_less_zero_point[16];
1185 XNN_ALIGN(64) int16_t output_zero_point[32];
1186 XNN_ALIGN(64) int8_t output_min[64];
1187 } fp32_avx512;
1188 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
1189 #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
1190 struct {
1191 XNN_ALIGN(8) float scale[2];
1192 XNN_ALIGN(8) float magic_bias[2];
1193 XNN_ALIGN(8) int32_t magic_min[2];
1194 XNN_ALIGN(8) int32_t magic_bias_less_output_zero_point[2];
1195 XNN_ALIGN(8) int8_t output_max[8];
1196 } fp32_wasmsimd;
1197 #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
1198 };
1199
// Parameters for QU8 (unsigned 8-bit) quantized add/subtract micro-kernels with
// output clamping (per the union name).
//
// Every variant is fixed-point: no member holds a float. The fields suggest the
// inputs are combined through a bias, per-input multipliers (a_multiplier,
// b_multiplier) and a shift, then offset by the output zero point and clamped.
// NOTE(review): the formula is inferred from field names; confirm against the
// kernels.
//
// NOTE(review): field order, element types and XNN_ALIGN values fix the
// in-memory layout. Confirm no kernel depends on byte offsets before editing.
union xnn_qu8_addsub_minmax_params {
  // Portable scalar variant. Unlike every SIMD variant in this union, it also
  // carries a rounding constant.
  struct {
    int32_t bias;
    int32_t a_multiplier;
    int32_t b_multiplier;
    int32_t rounding;
    uint32_t shift;
    int32_t output_min_less_zero_point;
    int32_t output_max_less_zero_point;
    int32_t output_zero_point;
  } scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // NEON keeps the raw input zero points instead of a precomputed bias, and
  // uses a single right_shift.
  struct {
    uint8_t a_zero_point;
    uint8_t b_zero_point;
    int16_t output_zero_point;
    int32_t a_multiplier;
    int32_t b_multiplier;
    int32_t right_shift;
    uint8_t output_min;
    uint8_t output_max;
  } neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // sse2 splits each multiplier into low/high uint16_t halves replicated 8x.
  // It also keeps scalar (non-replicated) copies of shift and b_multiplier.
  struct {
    XNN_ALIGN(16) int32_t bias[4];
    XNN_ALIGN(16) uint16_t a_multiplier_lo[8];
    XNN_ALIGN(16) uint16_t a_multiplier_hi[8];
    XNN_ALIGN(16) uint16_t b_multiplier_lo[8];
    XNN_ALIGN(16) uint16_t b_multiplier_hi[8];
    uint32_t shift;
    uint32_t b_multiplier;
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_min[16];
    XNN_ALIGN(16) uint8_t output_max[16];
  } sse2;
  // sse4/avx2/avx512 use full 32-bit multipliers and per-lane shift counts,
  // replicated across 16/32/64-byte arrays.
  struct {
    XNN_ALIGN(16) int32_t bias[4];
    XNN_ALIGN(16) int32_t a_multiplier[4];
    XNN_ALIGN(16) int32_t b_multiplier[4];
    XNN_ALIGN(16) uint32_t shift[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_min[16];
    XNN_ALIGN(16) uint8_t output_max[16];
  } sse4;
  // The output bounds here are 16 bytes wide, half the width of the other avx2
  // arrays.
  struct {
    XNN_ALIGN(32) int32_t bias[8];
    XNN_ALIGN(32) int32_t a_multiplier[8];
    XNN_ALIGN(32) int32_t b_multiplier[8];
    XNN_ALIGN(32) uint32_t shift[8];
    XNN_ALIGN(32) int16_t output_zero_point[16];
    XNN_ALIGN(16) uint8_t output_min[16];
    XNN_ALIGN(16) uint8_t output_max[16];
  } avx2;
  // The output bounds here are 32 bytes wide, half the width of the other
  // avx512 arrays.
  struct {
    XNN_ALIGN(64) int32_t bias[16];
    XNN_ALIGN(64) int32_t a_multiplier[16];
    XNN_ALIGN(64) int32_t b_multiplier[16];
    XNN_ALIGN(64) uint32_t shift[16];
    XNN_ALIGN(64) int16_t output_zero_point[32];
    XNN_ALIGN(32) uint8_t output_min[32];
    XNN_ALIGN(32) uint8_t output_max[32];
  } avx512;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  // WAsm SIMD: 8-byte replicated arrays plus a single scalar shift.
  struct {
    XNN_ALIGN(8) int32_t bias[2];
    XNN_ALIGN(8) int32_t a_multiplier[2];
    XNN_ALIGN(8) int32_t b_multiplier[2];
    uint32_t shift;
    XNN_ALIGN(8) int16_t output_zero_point[4];
    XNN_ALIGN(8) uint8_t output_min[8];
    XNN_ALIGN(8) uint8_t output_max[8];
  } wasmsimd;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
1276
// Parameters for QS8 (signed 8-bit) quantized add/subtract micro-kernels with
// output clamping (per the union name).
//
// Every variant is fixed-point: no member holds a float. The fields suggest the
// inputs are combined through a bias, per-input multipliers (a_multiplier,
// b_multiplier) and a shift, then offset by the output zero point and clamped.
// NOTE(review): the formula is inferred from field names; confirm against the
// kernels.
//
// NOTE(review): field order, element types and XNN_ALIGN values fix the
// in-memory layout. Confirm no kernel depends on byte offsets before editing.
union xnn_qs8_addsub_minmax_params {
  // Portable scalar variant (no separate rounding constant).
  struct {
    int32_t bias;
    int32_t a_multiplier;
    int32_t b_multiplier;
    uint32_t shift;
    int32_t output_min_less_zero_point;
    int32_t output_max_less_zero_point;
    int32_t output_zero_point;
  } scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // NEON keeps the raw input zero points instead of a precomputed bias, and
  // uses a single right_shift.
  struct {
    int8_t a_zero_point;
    int8_t b_zero_point;
    int16_t output_zero_point;
    int32_t a_multiplier;
    int32_t b_multiplier;
    int32_t right_shift;
    int8_t output_min;
    int8_t output_max;
  } neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // sse2 splits each multiplier into low/high uint16_t halves replicated 8x.
  // It also keeps scalar copies of shift and b_multiplier. Output bounds are
  // int16_t here. NOTE(review): presumably because SSE2 has no packed signed
  // 8-bit min/max; confirm.
  struct {
    XNN_ALIGN(16) int32_t bias[4];
    XNN_ALIGN(16) uint16_t a_multiplier_lo[8];
    XNN_ALIGN(16) uint16_t a_multiplier_hi[8];
    XNN_ALIGN(16) uint16_t b_multiplier_lo[8];
    XNN_ALIGN(16) uint16_t b_multiplier_hi[8];
    uint32_t shift;
    uint32_t b_multiplier;
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) int16_t output_min[8];
    XNN_ALIGN(16) int16_t output_max[8];
  } sse2;
  // Same multiplier split as sse2, but with int8_t output bounds.
  // The mul16/mul32 suffixes presumably name the multiply width the kernel
  // uses. TODO confirm.
  struct {
    XNN_ALIGN(16) int32_t bias[4];
    XNN_ALIGN(16) uint16_t a_multiplier_lo[8];
    XNN_ALIGN(16) uint16_t a_multiplier_hi[8];
    XNN_ALIGN(16) uint16_t b_multiplier_lo[8];
    XNN_ALIGN(16) uint16_t b_multiplier_hi[8];
    uint32_t shift;
    uint32_t b_multiplier;
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) int8_t output_min[16];
    XNN_ALIGN(16) int8_t output_max[16];
  } sse4_mul16;
  // Full 32-bit multipliers and per-lane shift counts.
  struct {
    XNN_ALIGN(16) int32_t bias[4];
    XNN_ALIGN(16) int32_t a_multiplier[4];
    XNN_ALIGN(16) int32_t b_multiplier[4];
    XNN_ALIGN(16) uint32_t shift[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) int8_t output_min[16];
    XNN_ALIGN(16) int8_t output_max[16];
  } sse4_mul32;
  // 32-byte replicated arrays. The output bounds are only 16 bytes wide.
  struct {
    XNN_ALIGN(32) int32_t bias[8];
    XNN_ALIGN(32) int32_t a_multiplier[8];
    XNN_ALIGN(32) int32_t b_multiplier[8];
    XNN_ALIGN(32) uint32_t shift[8];
    XNN_ALIGN(32) int16_t output_zero_point[16];
    XNN_ALIGN(16) int8_t output_min[16];
    XNN_ALIGN(16) int8_t output_max[16];
  } avx2;
  // 64-byte replicated arrays. The output bounds are only 32 bytes wide.
  struct {
    XNN_ALIGN(64) int32_t bias[16];
    XNN_ALIGN(64) int32_t a_multiplier[16];
    XNN_ALIGN(64) int32_t b_multiplier[16];
    XNN_ALIGN(64) uint32_t shift[16];
    XNN_ALIGN(64) int16_t output_zero_point[32];
    XNN_ALIGN(32) int8_t output_min[32];
    XNN_ALIGN(32) int8_t output_max[32];
  } avx512;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  // WAsm SIMD: 8-byte replicated arrays plus a single scalar shift.
  struct {
    XNN_ALIGN(8) int32_t bias[2];
    XNN_ALIGN(8) int32_t a_multiplier[2];
    XNN_ALIGN(8) int32_t b_multiplier[2];
    uint32_t shift;
    XNN_ALIGN(8) int16_t output_zero_point[4];
    XNN_ALIGN(8) int8_t output_min[8];
    XNN_ALIGN(8) int8_t output_max[8];
  } wasmsimd;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
1364
// Parameters for QU8 (unsigned 8-bit) quantized element-wise multiplication
// micro-kernels with output clamping (per the union name).
//
// Every variant carries both input zero points (a_zero_point, b_zero_point).
// The fp32_* variants requantize with a float scale. rndnu_neon instead uses an
// integer multiplier surrounded by pre/post shifts.
//
// NOTE(review): field order, element types and XNN_ALIGN values fix the
// in-memory layout. Confirm no kernel depends on byte offsets before editing.
union xnn_qu8_mul_minmax_params {
  // Portable scalar variant.
  struct {
    int32_t a_zero_point;
    int32_t b_zero_point;
    float scale;
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    float magic_bias;
    int32_t magic_bias_less_output_zero_point;
  } fp32_scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // NEON variants store each input zero point as a 2-element array.
  // NOTE(review): presumably so two copies can be loaded at once; confirm.
  struct {
    uint8_t a_zero_point[2];
    uint8_t b_zero_point[2];
    float scale;
    float magic_bias;
    int32_t magic_bias_less_output_zero_point;
    uint8_t output_min;
    uint8_t output_max;
  } fp32_neon;
  struct {
    uint8_t a_zero_point[2];
    uint8_t b_zero_point[2];
    float scale;
    int16_t output_zero_point;
    uint8_t output_min;
    uint8_t output_max;
  } fp32_neonv8;
  struct {
    uint8_t a_zero_point[2];
    uint8_t b_zero_point[2];
    int32_t left_pre_shift;
    int32_t multiplier;
    int32_t left_post_shift;
    int16_t output_zero_point;
    uint8_t output_min;
    uint8_t output_max;
  } rndnu_neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // Values replicated across 16-byte arrays. Input zero points are widened to
  // int16_t.
  struct {
    XNN_ALIGN(16) int16_t a_zero_point[8];
    XNN_ALIGN(16) int16_t b_zero_point[8];
    XNN_ALIGN(16) float scale[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_min[16];
    XNN_ALIGN(16) uint8_t output_max[16];
  } fp32_sse2;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  // Values replicated across 8-byte arrays. The lower bound is stored as
  // magic_min and the upper bound as output_max.
  struct {
    XNN_ALIGN(8) int16_t a_zero_point[4];
    XNN_ALIGN(8) int16_t b_zero_point[4];
    XNN_ALIGN(8) float scale[2];
    XNN_ALIGN(8) float magic_bias[2];
    XNN_ALIGN(8) int32_t magic_min[2];
    XNN_ALIGN(8) int32_t magic_bias_less_output_zero_point[2];
    XNN_ALIGN(8) uint8_t output_max[8];
  } fp32_wasmsimd;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
1426
// Parameters for QS8 (signed 8-bit) quantized element-wise multiplication
// micro-kernels with output clamping (per the union name).
//
// Every variant carries both input zero points (a_zero_point, b_zero_point).
// The fp32_* variants requantize with a float scale. rndnu_neon instead uses an
// integer multiplier surrounded by pre/post shifts.
//
// NOTE(review): field order, element types and XNN_ALIGN values fix the
// in-memory layout. Confirm no kernel depends on byte offsets before editing.
union xnn_qs8_mul_minmax_params {
  // Portable scalar variant.
  struct {
    int32_t a_zero_point;
    int32_t b_zero_point;
    float scale;
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    float magic_bias;
    int32_t magic_bias_less_output_zero_point;
  } fp32_scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // NEON variants store each input zero point as a 2-element array.
  // NOTE(review): presumably so two copies can be loaded at once; confirm.
  struct {
    int8_t a_zero_point[2];
    int8_t b_zero_point[2];
    float scale;
    float magic_bias;
    int32_t magic_bias_less_output_zero_point;
    int8_t output_min;
    int8_t output_max;
  } fp32_neon;
  struct {
    int8_t a_zero_point[2];
    int8_t b_zero_point[2];
    float scale;
    int16_t output_zero_point;
    int8_t output_min;
    int8_t output_max;
  } fp32_neonv8;
  struct {
    int8_t a_zero_point[2];
    int8_t b_zero_point[2];
    int32_t left_pre_shift;
    int32_t multiplier;
    int32_t left_post_shift;
    int16_t output_zero_point;
    int8_t output_min;
    int8_t output_max;
  } rndnu_neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // sse2 keeps the output bounds as int16_t, while sse4 uses int8_t.
  // NOTE(review): presumably because SSE2 has no packed signed 8-bit min/max;
  // confirm.
  struct {
    XNN_ALIGN(16) int16_t a_zero_point[8];
    XNN_ALIGN(16) int16_t b_zero_point[8];
    XNN_ALIGN(16) float scale[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) int16_t output_min[8];
    XNN_ALIGN(16) int16_t output_max[8];
  } fp32_sse2;
  struct {
    XNN_ALIGN(16) int16_t a_zero_point[8];
    XNN_ALIGN(16) int16_t b_zero_point[8];
    XNN_ALIGN(16) float scale[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) int8_t output_min[16];
    XNN_ALIGN(16) int8_t output_max[16];
  } fp32_sse4;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  // Values replicated across 8-byte arrays. The lower bound is stored as
  // magic_min and the upper bound as output_max.
  struct {
    XNN_ALIGN(8) int16_t a_zero_point[4];
    XNN_ALIGN(8) int16_t b_zero_point[4];
    XNN_ALIGN(8) float scale[2];
    XNN_ALIGN(8) float magic_bias[2];
    XNN_ALIGN(8) int32_t magic_min[2];
    XNN_ALIGN(8) int32_t magic_bias_less_output_zero_point[2];
    XNN_ALIGN(8) int8_t output_max[8];
  } fp32_wasmsimd;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
1496
// Parameters for QU8 (unsigned 8-bit) quantized average-pooling micro-kernels
// with output clamping (per the union name).
//
// The fp32_* and rndnu_neon members all begin with init_bias.
// NOTE(review): presumably the initial accumulator value; confirm.
// The trailing scalar/neon/sse2 members form a separate, legacy fixed-point
// layout, as the comment before them says.
//
// NOTE(review): field order, element types and XNN_ALIGN values fix the
// in-memory layout. Confirm no kernel depends on byte offsets before editing.
union xnn_qu8_avgpool_minmax_params {
  // Portable scalar variants. The suffixes presumably name the float-to-int
  // rounding technique (float magic bias, integer magic bias, lrintf). TODO
  // confirm.
  struct {
    int32_t init_bias;
    float scale;
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    float magic_bias;
    int32_t magic_bias_less_output_zero_point;
  } fp32_scalar_fmagic;
  struct {
    int32_t init_bias;
    float scale;
    float magic_bias;
    int32_t magic_min;
    int32_t magic_max;
    int32_t magic_bias_less_zero_point;
  } fp32_scalar_imagic;
  struct {
    int32_t init_bias;
    float scale;
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    int32_t output_zero_point;
  } fp32_scalar_lrintf;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // NEON variants store each value once (not replicated per lane).
  struct {
    int32_t init_bias;
    float scale;
    float magic_bias;
    int32_t magic_bias_less_output_zero_point;
    uint8_t output_min;
    uint8_t output_max;
  } fp32_neon;
  struct {
    int32_t init_bias;
    float scale;
    int16_t output_zero_point;
    uint8_t output_min;
    uint8_t output_max;
  } fp32_neonv8;
  // Fixed-point variant: an integer multiplier with pre/post shifts in place of
  // a float scale.
  struct {
    int32_t init_bias;
    int32_t left_pre_shift;
    int32_t multiplier;
    int32_t left_post_shift;
    int16_t output_zero_point;
    uint8_t output_min;
    uint8_t output_max;
  } rndnu_neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // fp32_sse2 and fp32_sse4 have identical layouts: 16-byte replicated arrays.
  // The upper bound is stored as a float relative to the zero point and the
  // lower bound as uint8_t.
  struct {
    XNN_ALIGN(16) int32_t init_bias[4];
    XNN_ALIGN(16) float scale[4];
    XNN_ALIGN(16) float output_max_less_zero_point[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_min[16];
  } fp32_sse2;
  struct {
    XNN_ALIGN(16) int32_t init_bias[4];
    XNN_ALIGN(16) float scale[4];
    XNN_ALIGN(16) float output_max_less_zero_point[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_min[16];
  } fp32_sse4;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  // Values replicated across 8-byte arrays. The lower bound is stored as
  // magic_min and the upper bound as output_max.
  struct {
    XNN_ALIGN(8) int32_t init_bias[2];
    XNN_ALIGN(8) float scale[2];
    XNN_ALIGN(8) float magic_bias[2];
    XNN_ALIGN(8) int32_t magic_min[2];
    XNN_ALIGN(8) int32_t magic_bias_less_output_zero_point[2];
    XNN_ALIGN(8) uint8_t output_max[8];
  } fp32_wasmsimd;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD

  // Legacy parameters used by QU8 AVGPOOL microkernels
  // Fixed-point layout: a 32-bit multiplier with a 64-bit rounding constant and
  // a shift, with no float scale.
  struct {
    int32_t bias;
    int32_t multiplier;
    int64_t rounding;
    uint32_t right_shift;
    int32_t output_min_less_zero_point;
    int32_t output_max_less_zero_point;
    int32_t output_zero_point;
  } scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // Note: the shift is stored here as a 64-bit left_shift.
  struct {
    int32_t bias;
    int32_t multiplier;
    int64_t left_shift;
    int16_t output_zero_point;
    uint8_t output_min;
    uint8_t output_max;
  } neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // rounding and right_shift are stored as two 64-bit lanes (16 bytes each).
  struct {
    XNN_ALIGN(16) int32_t bias[4];
    XNN_ALIGN(16) uint32_t multiplier[4];
    XNN_ALIGN(16) uint64_t rounding[2];
    XNN_ALIGN(16) uint64_t right_shift[2];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_min[16];
    XNN_ALIGN(16) uint8_t output_max[16];
  } sse2;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
};
1606
// Parameters for QS8 (signed 8-bit) quantized average-pooling micro-kernels
// with output clamping (per the union name).
//
// Every member begins with init_bias.
// NOTE(review): presumably the initial accumulator value; confirm.
//
// NOTE(review): field order, element types and XNN_ALIGN values fix the
// in-memory layout. Confirm no kernel depends on byte offsets before editing.
union xnn_qs8_avgpool_minmax_params {
  // Portable scalar variants. The suffixes presumably name the float-to-int
  // rounding technique (float magic bias, integer magic bias, lrintf). TODO
  // confirm.
  struct {
    int32_t init_bias;
    float scale;
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    float magic_bias;
    int32_t magic_bias_less_output_zero_point;
  } fp32_scalar_fmagic;
  struct {
    int32_t init_bias;
    float scale;
    float magic_bias;
    int32_t magic_min;
    int32_t magic_max;
    int32_t magic_bias_less_zero_point;
  } fp32_scalar_imagic;
  struct {
    int32_t init_bias;
    float scale;
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    int32_t output_zero_point;
  } fp32_scalar_lrintf;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // NEON variants store each value once (not replicated per lane).
  struct {
    int32_t init_bias;
    float scale;
    float magic_bias;
    int32_t magic_bias_less_output_zero_point;
    int8_t output_min;
    int8_t output_max;
  } fp32_neon;
  struct {
    int32_t init_bias;
    float scale;
    int16_t output_zero_point;
    int8_t output_min;
    int8_t output_max;
  } fp32_neonv8;
  // Fixed-point variant: an integer multiplier with pre/post shifts in place of
  // a float scale.
  struct {
    int32_t init_bias;
    int32_t left_pre_shift;
    int32_t multiplier;
    int32_t left_post_shift;
    int16_t output_zero_point;
    int8_t output_min;
    int8_t output_max;
  } rndnu_neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // 16-byte replicated arrays. sse2 stores output_min as int16_t and sse4 as
  // int8_t. NOTE(review): presumably because SSE2 has no packed signed 8-bit
  // max; confirm.
  struct {
    XNN_ALIGN(16) int32_t init_bias[4];
    XNN_ALIGN(16) float scale[4];
    XNN_ALIGN(16) float output_max_less_zero_point[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) int16_t output_min[8];
  } fp32_sse2;
  struct {
    XNN_ALIGN(16) int32_t init_bias[4];
    XNN_ALIGN(16) float scale[4];
    XNN_ALIGN(16) float output_max_less_zero_point[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) int8_t output_min[16];
  } fp32_sse4;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  // Values replicated across 8-byte arrays. The lower bound is stored as
  // magic_min and the upper bound as output_max.
  struct {
    XNN_ALIGN(8) int32_t init_bias[2];
    XNN_ALIGN(8) float scale[2];
    XNN_ALIGN(8) float magic_bias[2];
    XNN_ALIGN(8) int32_t magic_min[2];
    XNN_ALIGN(8) int32_t magic_bias_less_output_zero_point[2];
    XNN_ALIGN(8) int8_t output_max[8];
  } fp32_wasmsimd;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
1684
// Parameters for micro-kernels converting half-precision (F16) values to
// single precision (F32), per the union name.
//
// The field names (sign_mask, exp_offset, exp_scale, magic_mask, magic_bias,
// denorm_cutoff) suggest a bit-manipulation conversion. NOTE(review): this is
// inferred from naming; confirm against the kernels.
//
// NOTE(review): field order, element types and XNN_ALIGN values fix the
// in-memory layout. Confirm no kernel depends on byte offsets before editing.
union xnn_f16_f32_cvt_params {
  // Portable scalar variant.
  struct {
    uint32_t sign_mask;
    uint32_t exp_offset;
    float exp_scale;
    uint32_t magic_mask;
    float magic_bias;
    uint32_t denorm_cutoff;
  } scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // NEON needs only exp_scale from this union.
  struct {
    float exp_scale;
  } neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // *_int16 stores the integer constants as 16-bit lanes (x8). *_int32 stores
  // them as 32-bit lanes (x4), has no magic_mask, and types magic_bias as
  // uint32_t. The suffix presumably names the integer lane width the kernel
  // works in. TODO confirm.
  struct {
    XNN_ALIGN(16) uint16_t sign_mask[8];
    XNN_ALIGN(16) uint16_t exp_offset[8];
    XNN_ALIGN(16) float exp_scale[4];
    XNN_ALIGN(16) uint16_t magic_mask[8];
    XNN_ALIGN(16) float magic_bias[4];
    XNN_ALIGN(16) int16_t denorm_cutoff[8];
  } sse_int16;
  struct {
    XNN_ALIGN(16) uint32_t sign_mask[4];
    XNN_ALIGN(16) uint32_t exp_offset[4];
    XNN_ALIGN(16) float exp_scale[4];
    XNN_ALIGN(16) uint32_t magic_bias[4];
    XNN_ALIGN(16) int32_t denorm_cutoff[4];
  } sse_int32;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  // Same two layouts as the x86 variants, replicated to 8 bytes instead of 16.
  struct {
    XNN_ALIGN(8) uint16_t sign_mask[4];
    XNN_ALIGN(8) uint16_t exp_offset[4];
    XNN_ALIGN(8) float exp_scale[2];
    XNN_ALIGN(8) uint16_t magic_mask[4];
    XNN_ALIGN(8) float magic_bias[2];
    XNN_ALIGN(8) int16_t denorm_cutoff[4];
  } wasmsimd_int16;
  struct {
    XNN_ALIGN(8) uint32_t sign_mask[2];
    XNN_ALIGN(8) uint32_t exp_offset[2];
    XNN_ALIGN(8) float exp_scale[2];
    XNN_ALIGN(8) uint32_t magic_bias[2];
    XNN_ALIGN(8) int32_t denorm_cutoff[2];
  } wasmsimd_int32;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
1734
// Parameters for micro-kernels converting single-precision (F32) values to
// half precision (F16), per the union name.
//
// NOTE(review): field order, element types and XNN_ALIGN values fix the
// in-memory layout. Confirm no kernel depends on byte offsets before editing.
union xnn_f32_f16_cvt_params {
  // The two scalar variants hold almost the same constants but in different
  // orders, and only scalar_bitcast has nonsign_mask. Do not reorder one to
  // "match" the other. The suffixes presumably name how the magnitude is
  // obtained (bit mask vs fabsf). TODO confirm.
  struct {
    uint32_t nonsign_mask;
    uint32_t exp_bias;
    float scale_to_inf;
    uint32_t expw_max;
    float scale_to_zero;
    uint32_t bias_min;
    uint16_t exph_mask;
    uint16_t manth_mask;
    uint16_t nanh;
  } scalar_bitcast;
  struct {
    float scale_to_inf;
    uint32_t exp_bias;
    float scale_to_zero;
    uint32_t expw_max;
    uint32_t bias_min;
    uint16_t exph_mask;
    uint16_t manth_mask;
    uint16_t nanh;
  } scalar_fabsf;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // NEON needs only a subset of the scalar constants.
  struct {
    uint32_t exp_bias;
    float scale_to_inf;
    uint32_t expw_max;
    float scale_to_zero;
  } neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // 16-byte replicated arrays. Unlike the scalar variants, bias_min is stored as
  // int16_t lanes, while manth_mask and exph_mask are stored as uint32_t lanes.
  struct {
    XNN_ALIGN(16) uint32_t nonsign_mask[4];
    XNN_ALIGN(16) uint32_t exp_bias[4];
    XNN_ALIGN(16) float scale_to_inf[4];
    XNN_ALIGN(16) uint32_t expw_max[4];
    XNN_ALIGN(16) float scale_to_zero[4];
    XNN_ALIGN(16) int16_t bias_min[8];
    XNN_ALIGN(16) uint32_t manth_mask[4];
    XNN_ALIGN(16) uint32_t exph_mask[4];
    XNN_ALIGN(16) uint16_t nanh[8];
  } sse2;
  // F16C holds only a 14-entry table, with no XNN_ALIGN.
  // NOTE(review): presumably used to build a lane mask for a partial final
  // vector; confirm.
  struct {
    int32_t mask_table[14];
  } f16c;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  // 8-byte replicated arrays, with no nonsign_mask.
  struct {
    XNN_ALIGN(8) uint32_t exp_bias[2];
    XNN_ALIGN(8) float scale_to_inf[2];
    XNN_ALIGN(8) uint32_t expw_max[2];
    XNN_ALIGN(8) float scale_to_zero[2];
    XNN_ALIGN(8) int16_t bias_min[4];
    XNN_ALIGN(8) uint32_t manth_mask[2];
    XNN_ALIGN(8) uint32_t exph_mask[2];
    XNN_ALIGN(8) uint16_t nanh[4];
  } wasmsimd;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
1794
// Parameters for micro-kernels converting F32 values to QS8 (signed 8-bit
// quantized), per the union name. All variants hold a float scale plus output
// zero-point and clamping information.
//
// NOTE(review): field order, element types and XNN_ALIGN values fix the
// in-memory layout. Confirm no kernel depends on byte offsets before editing.
union xnn_f32_qs8_cvt_params {
  // Portable scalar variants. The suffixes presumably name the float-to-int
  // rounding technique (float magic bias, integer magic bias, lrintf). TODO
  // confirm.
  struct {
    float scale;
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    float magic_bias;
    int32_t magic_bias_less_zero_point;
  } scalar_fmagic;
  struct {
    float scale;
    float magic_bias;
    int32_t magic_min;
    int32_t magic_max;
    int32_t magic_bias_less_zero_point;
  } scalar_imagic;
  struct {
    float scale;
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    int32_t output_zero_point;
  } scalar_lrintf;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // NEON variants store each value once (not replicated per lane).
  struct {
    float scale;
    float magic_bias;
    int32_t magic_bias_less_zero_point;
    int8_t output_min;
    int8_t output_max;
  } neon;
  struct {
    float scale;
    int16_t output_zero_point;
    int8_t output_min;
    int8_t output_max;
  } neonv8;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // The upper bound is stored as a float relative to the zero point and the
  // lower bound as an integer. sse2 uses int16_t for output_min; the others use
  // int8_t. NOTE(review): presumably because SSE2 has no packed signed 8-bit
  // max; confirm.
  struct {
    XNN_ALIGN(16) float scale[4];
    XNN_ALIGN(16) float output_max_less_zero_point[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) int16_t output_min[8];
  } sse2;
  struct {
    XNN_ALIGN(16) float scale[4];
    XNN_ALIGN(16) float output_max_less_zero_point[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) int8_t output_min[16];
  } sse4;
  // avx: float constants are 32 bytes wide, but the integer constants are only
  // 16 bytes wide. mask_table (unaligned, 14 entries) presumably supports
  // masked loads for a partial final vector. TODO confirm.
  struct {
    XNN_ALIGN(32) float scale[8];
    XNN_ALIGN(32) float output_max_less_zero_point[8];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) int8_t output_min[16];
    int32_t mask_table[14];
  } avx;
  // avx2 adds a shuffle_mask.
  // NOTE(review): presumably to reorder lanes after packing; confirm.
  struct {
    XNN_ALIGN(32) float scale[8];
    XNN_ALIGN(32) float output_max_less_zero_point[8];
    XNN_ALIGN(32) int16_t output_zero_point[16];
    XNN_ALIGN(32) uint32_t shuffle_mask[8];
    XNN_ALIGN(32) int8_t output_min[32];
    int32_t mask_table[14];
  } avx2;
  // avx512 carries separate 512-bit and 256-bit shuffle masks and no
  // mask_table.
  struct {
    XNN_ALIGN(64) float scale[16];
    XNN_ALIGN(64) float output_max_less_zero_point[16];
    XNN_ALIGN(64) int16_t output_zero_point[32];
    XNN_ALIGN(64) int8_t output_min[64];
    XNN_ALIGN(64) uint32_t shuffle512_mask[16];
    XNN_ALIGN(32) uint32_t shuffle256_mask[8];
  } avx512;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  // wasmsimd_cvt stores the zero point with explicit min/max bounds.
  // wasmsimd_magic uses magic-bias constants, storing the lower bound as
  // magic_min and the upper bound as output_max.
  struct {
    XNN_ALIGN(8) float scale[2];
    XNN_ALIGN(8) int16_t output_zero_point[4];
    XNN_ALIGN(8) int8_t output_min[8];
    XNN_ALIGN(8) int8_t output_max[8];
  } wasmsimd_cvt;
  struct {
    XNN_ALIGN(8) float scale[2];
    XNN_ALIGN(8) float magic_bias[2];
    XNN_ALIGN(8) int32_t magic_min[2];
    XNN_ALIGN(8) int32_t magic_bias_less_zero_point[2];
    XNN_ALIGN(8) int8_t output_max[8];
  } wasmsimd_magic;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
1884
// Parameters for micro-kernels converting F32 values to QU8 (unsigned 8-bit
// quantized), per the union name. All variants hold a float scale plus output
// zero-point and clamping information.
//
// NOTE(review): field order, element types and XNN_ALIGN values fix the
// in-memory layout. Confirm no kernel depends on byte offsets before editing.
union xnn_f32_qu8_cvt_params {
  // Portable scalar variants. The suffixes presumably name the float-to-int
  // rounding technique (float magic bias, integer magic bias, lrintf). TODO
  // confirm.
  struct {
    float scale;
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    float magic_bias;
    int32_t magic_bias_less_zero_point;
  } scalar_fmagic;
  struct {
    float scale;
    float magic_bias;
    int32_t magic_min;
    int32_t magic_max;
    int32_t magic_bias_less_zero_point;
  } scalar_imagic;
  struct {
    float scale;
    float output_min_less_zero_point;
    float output_max_less_zero_point;
    int32_t output_zero_point;
  } scalar_lrintf;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // NEON variants store each value once (not replicated per lane).
  struct {
    float scale;
    float magic_bias;
    int32_t magic_bias_less_zero_point;
    uint8_t output_min;
    uint8_t output_max;
  } neon;
  struct {
    float scale;
    int16_t output_zero_point;
    uint8_t output_min;
    uint8_t output_max;
  } neonv8;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // The upper bound is stored as a float relative to the zero point and the
  // lower bound as uint8_t. sse2 already uses 8-bit bounds, and this union has
  // no separate sse4 member.
  struct {
    XNN_ALIGN(16) float scale[4];
    XNN_ALIGN(16) float output_max_less_zero_point[4];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_min[16];
  } sse2;
  // avx: float constants are 32 bytes wide, but the integer constants are only
  // 16 bytes wide. mask_table (unaligned, 14 entries) presumably supports
  // masked loads for a partial final vector. TODO confirm.
  struct {
    XNN_ALIGN(32) float scale[8];
    XNN_ALIGN(32) float output_max_less_zero_point[8];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_min[16];
    int32_t mask_table[14];
  } avx;
  // avx2 adds a shuffle_mask.
  // NOTE(review): presumably to reorder lanes after packing; confirm.
  struct {
    XNN_ALIGN(32) float scale[8];
    XNN_ALIGN(32) float output_max_less_zero_point[8];
    XNN_ALIGN(32) int16_t output_zero_point[16];
    XNN_ALIGN(32) uint32_t shuffle_mask[8];
    XNN_ALIGN(32) uint8_t output_min[32];
    int32_t mask_table[14];
  } avx2;
  // avx512 carries separate 512-bit and 256-bit shuffle masks and no
  // mask_table.
  struct {
    XNN_ALIGN(64) float scale[16];
    XNN_ALIGN(64) float output_max_less_zero_point[16];
    XNN_ALIGN(64) int16_t output_zero_point[32];
    XNN_ALIGN(64) uint8_t output_min[64];
    XNN_ALIGN(64) uint32_t shuffle512_mask[16];
    XNN_ALIGN(32) uint32_t shuffle256_mask[8];
  } avx512;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  // wasmsimd_cvt stores the zero point with explicit min/max bounds.
  // wasmsimd_magic uses magic-bias constants, storing the lower bound as
  // magic_min and the upper bound as output_max.
  struct {
    XNN_ALIGN(8) float scale[2];
    XNN_ALIGN(8) int16_t output_zero_point[4];
    XNN_ALIGN(8) uint8_t output_min[8];
    XNN_ALIGN(8) uint8_t output_max[8];
  } wasmsimd_cvt;
  struct {
    XNN_ALIGN(8) float scale[2];
    XNN_ALIGN(8) float magic_bias[2];
    XNN_ALIGN(8) int32_t magic_min[2];
    XNN_ALIGN(8) int32_t magic_bias_less_zero_point[2];
    XNN_ALIGN(8) uint8_t output_max[8];
  } wasmsimd_magic;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
1968
// Parameters for micro-kernels converting QS8 (signed 8-bit quantized) values
// to F32, per the union name. Every variant carries a float scale plus either
// the zero point or its negation (minus_zero_point).
//
// NOTE(review): field order, element types and XNN_ALIGN values fix the
// in-memory layout. Confirm no kernel depends on byte offsets before editing.
union xnn_qs8_f32_cvt_params {
  struct {
    int32_t zero_point;
    float scale;
  } scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // Negated zero point stored as a 2-element int16_t array.
  struct {
    int16_t minus_zero_point[2];
    float scale;
  } neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // sse2 has no zero-point field. Its sign_mask/magic_exp/magic_bias constants
  // suggest the int-to-float conversion is done with exponent bit tricks.
  // NOTE(review): inferred from naming; confirm.
  struct {
    XNN_ALIGN(16) uint8_t sign_mask[16];
    XNN_ALIGN(16) uint16_t magic_exp[8];
    XNN_ALIGN(16) float magic_bias[4];
    XNN_ALIGN(16) float scale[4];
  } sse2;
  // sse4/avx/avx512: negated zero point as int32_t lanes, replicated to
  // 16/32/64 bytes.
  struct {
    XNN_ALIGN(16) int32_t minus_zero_point[4];
    XNN_ALIGN(16) float scale[4];
  } sse4;
  struct {
    XNN_ALIGN(32) int32_t minus_zero_point[8];
    XNN_ALIGN(32) float scale[8];
  } avx;
  struct {
    XNN_ALIGN(64) int32_t minus_zero_point[16];
    XNN_ALIGN(64) float scale[16];
  } avx512;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  // Negated zero point as int16_t lanes, replicated to 8 bytes.
  struct {
    XNN_ALIGN(8) int16_t minus_zero_point[4];
    XNN_ALIGN(8) float scale[2];
  } wasmsimd;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
2007
// Parameters for micro-kernels converting QU8 (unsigned 8-bit quantized)
// values to F32, per the union name. Every variant carries a float scale plus
// either the zero point or its negation (minus_zero_point).
//
// NOTE(review): field order, element types and XNN_ALIGN values fix the
// in-memory layout. Confirm no kernel depends on byte offsets before editing.
union xnn_qu8_f32_cvt_params {
  struct {
    int32_t zero_point;
    float scale;
  } scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // Negated zero point stored as a 2-element int16_t array.
  struct {
    int16_t minus_zero_point[2];
    float scale;
  } neon;
#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // sse2 has no zero-point field. Its magic_exp/magic_bias constants suggest the
  // int-to-float conversion is done with exponent bit tricks (no sign mask is
  // needed for unsigned input). NOTE(review): inferred from naming; confirm.
  struct {
    XNN_ALIGN(16) uint16_t magic_exp[8];
    XNN_ALIGN(16) float magic_bias[4];
    XNN_ALIGN(16) float scale[4];
  } sse2;
  // sse4/avx/avx512: negated zero point as int32_t lanes, replicated to
  // 16/32/64 bytes.
  struct {
    XNN_ALIGN(16) int32_t minus_zero_point[4];
    XNN_ALIGN(16) float scale[4];
  } sse4;
  struct {
    XNN_ALIGN(32) int32_t minus_zero_point[8];
    XNN_ALIGN(32) float scale[8];
  } avx;
  struct {
    XNN_ALIGN(64) int32_t minus_zero_point[16];
    XNN_ALIGN(64) float scale[16];
  } avx512;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
  // Negated zero point as int16_t lanes, replicated to 8 bytes.
  struct {
    XNN_ALIGN(8) int16_t minus_zero_point[4];
    XNN_ALIGN(8) float scale[2];
  } wasmsimd;
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
};
2045
// Type-erased PPMM micro-kernel signature. a, w, c and params are all untyped
// (void*), so this type can hold a kernel of any datatype.
// NOTE(review): "ppmm" presumably denotes a GEMM variant that reads A in a
// pre-packed layout; confirm.
//   mr, nc, kc - by naming convention, presumably the rows of the output tile,
//                the number of output columns, and the reduction extent.
//                Units (elements vs bytes) are not visible here.
//   a, w, c    - input, weights and output buffers.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
//   params     - kernel-specific parameters, opaque at this level.
typedef void (*xnn_ppmm_ukernel_function)(
    size_t mr,
    size_t nc,
    size_t kc,
    const void* a,
    const void* w,
    void* c,
    size_t cm_stride,
    size_t cn_stride,
    const void* params);
2056
// F32 PPMM micro-kernel signature. Its params are union xnn_f32_minmax_params,
// so by name the output is presumably clamped to [min, max].
// NOTE(review): "ppmm" presumably denotes a GEMM variant that reads A in a
// pre-packed layout; confirm.
//   mr, nc, kc - presumably output-tile rows, output columns, and reduction
//                extent (units not visible here).
//   a, w       - input and weight data.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
typedef void (*xnn_f32_ppmm_minmax_ukernel_function)(
    size_t mr,
    size_t nc,
    size_t kc,
    const float* a,
    const float* w,
    float* c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f32_minmax_params* params);
2067
// F16 PPMM micro-kernel signature. F16 data is passed as void*.
// NOTE(review): presumably because C has no portable half-precision type;
// confirm. Params are union xnn_f16_scaleminmax_params (scale plus min/max, per
// the type name).
//   mr, nc, kc - presumably output-tile rows, output columns, and reduction
//                extent (units not visible here).
//   a, w       - input and weight data.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
typedef void (*xnn_f16_ppmm_ukernel_function)(
    size_t mr,
    size_t nc,
    size_t kc,
    const void* a,
    const void* w,
    void* c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f16_scaleminmax_params* params);
2078
// Type-erased GEMM micro-kernel signature. Data and params are all void*, so
// this type can hold a GEMM kernel of any datatype.
//   mr, nr     - presumably the rows and columns of the output tile this call
//                computes; confirm.
//   k          - reduction extent (units, elements vs bytes, not visible here).
//   a          - input matrix.
//   a_stride   - stride between rows of a.
//   w          - weight data. NOTE(review): presumably pre-packed; layout not
//                visible here.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
//   params     - kernel-specific parameters, opaque at this level.
typedef void (*xnn_gemm_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const void* a,
    size_t a_stride,
    const void* w,
    void* c,
    size_t cm_stride,
    size_t cn_stride,
    const void* params);
2090
// F32 GEMM micro-kernel signature for kernels without fused activation. Params
// are union xnn_f32_default_params, which this header declares as an empty
// placeholder type.
//   mr, nr     - presumably the rows and columns of the output tile; confirm.
//   k          - reduction extent (units not visible here).
//   a          - input matrix.
//   a_stride   - stride between rows of a.
//   w          - weight data. NOTE(review): presumably pre-packed.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
typedef void (*xnn_f32_gemm_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const float* a,
    size_t a_stride,
    const float* w,
    float* c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f32_default_params* params);
2102
// Transpose micro-kernel signature for 8-bit elements (per the uint8_t
// pointer types). There is no params argument.
//   a             - source block.
//   b             - destination block.
//   input_stride  - row stride of a.
//   output_stride - row stride of b.
//                   NOTE(review): stride units (bytes vs elements) are not
//                   visible here; confirm.
//   block_width, block_height - extent of the block to transpose.
typedef void (*xnn_x8_transpose_ukernel_function)(
    const uint8_t* a,
    uint8_t* b,
    size_t input_stride,
    size_t output_stride,
    size_t block_width,
    size_t block_height);
2110
// Transpose micro-kernel signature for 16-bit elements (per the uint16_t
// pointer types). There is no params argument.
//   a             - source block.
//   b             - destination block.
//   input_stride  - row stride of a.
//   output_stride - row stride of b.
//                   NOTE(review): stride units (bytes vs elements) are not
//                   visible here; confirm.
//   block_width, block_height - extent of the block to transpose.
typedef void (*xnn_x16_transpose_ukernel_function)(
    const uint16_t* a,
    uint16_t* b,
    size_t input_stride,
    size_t output_stride,
    size_t block_width,
    size_t block_height);
2118
// Transpose micro-kernel signature for 32-bit elements (per the uint32_t
// pointer types). There is no params argument.
//   a             - source block.
//   b             - destination block.
//   input_stride  - row stride of a.
//   output_stride - row stride of b.
//                   NOTE(review): stride units (bytes vs elements) are not
//                   visible here; confirm.
//   block_width, block_height - extent of the block to transpose.
typedef void (*xnn_x32_transpose_ukernel_function)(
    const uint32_t* a,
    uint32_t* b,
    size_t input_stride,
    size_t output_stride,
    size_t block_width,
    size_t block_height);
2126
// Transpose micro-kernel signature for 64-bit elements (per the uint64_t
// pointer types). There is no params argument.
//   a             - source block.
//   b             - destination block.
//   input_stride  - row stride of a.
//   output_stride - row stride of b.
//                   NOTE(review): stride units (bytes vs elements) are not
//                   visible here; confirm.
//   block_width, block_height - extent of the block to transpose.
typedef void (*xnn_x64_transpose_ukernel_function)(
    const uint64_t* a,
    uint64_t* b,
    size_t input_stride,
    size_t output_stride,
    size_t block_width,
    size_t block_height);
2134
// F32 GEMM micro-kernel signature for kernels with a fused ReLU (params are
// union xnn_f32_relu_params, per the type name).
//   mr, nr     - presumably the rows and columns of the output tile; confirm.
//   k          - reduction extent (units not visible here).
//   a          - input matrix.
//   a_stride   - stride between rows of a.
//   w          - weight data. NOTE(review): presumably pre-packed.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
typedef void (*xnn_f32_gemm_relu_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const float* a,
    size_t a_stride,
    const float* w,
    float* c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f32_relu_params* params);
2146
// F32 GEMM micro-kernel signature for kernels that clamp their output. Params
// are union xnn_f32_minmax_params, so by name the clamp is to [min, max].
//   mr, nr     - presumably the rows and columns of the output tile; confirm.
//   k          - reduction extent (units not visible here).
//   a          - input matrix.
//   a_stride   - stride between rows of a.
//   w          - weight data. NOTE(review): presumably pre-packed.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
typedef void (*xnn_f32_gemm_minmax_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const float* a,
    size_t a_stride,
    const float* w,
    float* c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f32_minmax_params* params);
2158
// F32 "GEMM-inc" micro-kernel signature with min/max clamping (params are
// union xnn_f32_minmax_params). It matches the F32 min/max GEMM signature
// except for the extra acc argument.
//   mr, nr     - presumably the rows and columns of the output tile; confirm.
//   k          - reduction extent (units not visible here).
//   a          - input matrix.
//   a_stride   - stride between rows of a.
//   w          - weight data. NOTE(review): presumably pre-packed.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
//   acc        - NOTE(review): presumably initial accumulator values, so a
//                reduction split across calls can continue ("inc" =
//                incremental); confirm.
typedef void (*xnn_f32_gemminc_minmax_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const float* a,
    size_t a_stride,
    const float* w,
    float* c,
    size_t cm_stride,
    size_t cn_stride,
    const float* acc,
    const union xnn_f32_minmax_params* params);
2171
// F16 GEMM micro-kernel signature with output clamping. F16 data is passed as
// void*. Params are union xnn_f16_scaleminmax_params (scale plus min/max, per
// the type name).
//   mr, nr     - presumably the rows and columns of the output tile; confirm.
//   k          - reduction extent (units not visible here).
//   a          - input matrix.
//   a_stride   - stride between rows of a.
//   w          - weight data. NOTE(review): presumably pre-packed.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
typedef void (*xnn_f16_gemm_minmax_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const void* a,
    size_t a_stride,
    const void* w,
    void* c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f16_scaleminmax_params* params);
2183
// F16 IGEMM micro-kernel signature with output clamping. F16 data is passed as
// void*. Params are union xnn_f16_scaleminmax_params. Unlike GEMM, the input is
// an array of row pointers (const void**) rather than a strided matrix.
//   mr, nr     - presumably the rows and columns of the output tile; confirm.
//   kc         - reduction extent per input pointer (units not visible here).
//   ks         - presumably the number of input pointers consumed per output
//                row; confirm.
//   a          - array of input pointers.
//   w          - weight data. NOTE(review): presumably pre-packed.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
//   a_offset   - NOTE(review): presumably added to each pointer from a except
//                those equal to zero; confirm.
//   zero       - NOTE(review): presumably a zero-filled buffer used for padding
//                positions; confirm.
typedef void (*xnn_f16_igemm_minmax_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t kc,
    size_t ks,
    const void** a,
    const void* w,
    void* c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const void* zero,
    const union xnn_f16_scaleminmax_params* params);
2197
// QC8 GEMM micro-kernel signature: int8_t input and output, weights passed as
// void*. Params are union xnn_qs8_minmax_params, not the
// xnn_qs8_conv_minmax_params used by the QS8 GEMM signature.
// NOTE(review): "qc8" presumably means per-channel quantized int8, with the
// per-channel scales stored alongside the weights in w; confirm.
//   mr, nr     - presumably the rows and columns of the output tile; confirm.
//   k          - reduction extent (units not visible here).
//   a          - input matrix.
//   a_stride   - stride between rows of a.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
typedef void (*xnn_qc8_gemm_minmax_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const int8_t* a,
    size_t a_stride,
    const void* w,
    int8_t* c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_qs8_minmax_params* params);
2209
// QS8 GEMM micro-kernel signature: int8_t input and output, weights passed as
// void*. Requantization and clamping parameters come from
// union xnn_qs8_conv_minmax_params.
//   mr, nr     - presumably the rows and columns of the output tile; confirm.
//   k          - reduction extent (units not visible here).
//   a          - input matrix.
//   a_stride   - stride between rows of a.
//   w          - weight data. NOTE(review): presumably pre-packed with bias.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
typedef void (*xnn_qs8_gemm_minmax_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const int8_t* a,
    size_t a_stride,
    const void* w,
    int8_t* c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_qs8_conv_minmax_params* params);
2221
// QU8 GEMM micro-kernel signature: uint8_t input and output, weights passed as
// void*. Requantization and clamping parameters come from
// union xnn_qu8_conv_minmax_params.
//   mr, nr     - presumably the rows and columns of the output tile; confirm.
//   k          - reduction extent (units not visible here).
//   a          - input matrix.
//   a_stride   - stride between rows of a.
//   w          - weight data. NOTE(review): presumably pre-packed with bias.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
typedef void (*xnn_qu8_gemm_minmax_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t k,
    const uint8_t* a,
    size_t a_stride,
    const void* w,
    uint8_t* c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_qu8_conv_minmax_params* params);
2233
// Type-erased IGEMM micro-kernel signature. Data and params are void*. Unlike
// GEMM, the input is an array of row pointers (const void**) rather than a
// strided matrix.
//   mr, nr     - presumably the rows and columns of the output tile; confirm.
//   kc         - reduction extent per input pointer (units not visible here).
//   ks         - presumably the number of input pointers consumed per output
//                row; confirm.
//   a          - array of input pointers.
//   w          - weight data. NOTE(review): presumably pre-packed.
//   c          - output matrix.
//   cm_stride  - stride between output rows.
//   cn_stride  - presumably the stride between output column tiles; confirm.
//   a_offset   - NOTE(review): presumably added to each pointer from a except
//                those equal to zero; confirm.
//   zero       - NOTE(review): presumably a zero-filled buffer used for padding
//                positions; confirm.
//   params     - kernel-specific parameters, opaque at this level.
typedef void (*xnn_igemm_ukernel_function)(
    size_t mr,
    size_t nr,
    size_t kc,
    size_t ks,
    const void** a,
    const void* w,
    void* c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const void* zero,
    const void* params);
2247
2248 typedef void (*xnn_f32_igemm_ukernel_function)(
2249 size_t mr,
2250 size_t nr,
2251 size_t kc,
2252 size_t ks,
2253 const float** a,
2254 const float* w,
2255 float* c,
2256 size_t cm_stride,
2257 size_t cn_stride,
2258 size_t a_offset,
2259 const float* zero,
2260 const union xnn_f32_default_params* params);
2261
2262 typedef void (*xnn_f32_igemm_relu_ukernel_function)(
2263 size_t mr,
2264 size_t nr,
2265 size_t kc,
2266 size_t ks,
2267 const float** a,
2268 const float* w,
2269 float* c,
2270 size_t cm_stride,
2271 size_t cn_stride,
2272 size_t a_offset,
2273 const float* zero,
2274 const union xnn_f32_relu_params* params);
2275
2276 typedef void (*xnn_f32_igemm_minmax_ukernel_function)(
2277 size_t mr,
2278 size_t nr,
2279 size_t kc,
2280 size_t ks,
2281 const float** a,
2282 const float* w,
2283 float* c,
2284 size_t cm_stride,
2285 size_t cn_stride,
2286 size_t a_offset,
2287 const float* zero,
2288 const union xnn_f32_minmax_params* params);
2289
2290 typedef void (*xnn_qu8_igemm_minmax_ukernel_function)(
2291 size_t mr,
2292 size_t nr,
2293 size_t kc,
2294 size_t ks,
2295 const uint8_t** a,
2296 const void* w,
2297 uint8_t* c,
2298 size_t cm_stride,
2299 size_t cn_stride,
2300 size_t a_offset,
2301 const uint8_t* zero,
2302 const union xnn_qu8_conv_minmax_params* params);
2303
2304 typedef void (*xnn_qc8_igemm_minmax_ukernel_function)(
2305 size_t mr,
2306 size_t nr,
2307 size_t kc,
2308 size_t ks,
2309 const int8_t** a,
2310 const void* w,
2311 int8_t* c,
2312 size_t cm_stride,
2313 size_t cn_stride,
2314 size_t a_offset,
2315 const int8_t* zero,
2316 const union xnn_qs8_minmax_params* params);
2317
2318 typedef void (*xnn_qs8_igemm_minmax_ukernel_function)(
2319 size_t mr,
2320 size_t nr,
2321 size_t kc,
2322 size_t ks,
2323 const int8_t** a,
2324 const void* w,
2325 int8_t* c,
2326 size_t cm_stride,
2327 size_t cn_stride,
2328 size_t a_offset,
2329 const int8_t* zero,
2330 const union xnn_qs8_conv_minmax_params* params);
2331
// Direct-convolution micro-kernel signatures taking HWC input. The `hwc`
// pair writes output addressed by height and width strides; the `hwc2chw`
// pair instead takes a height stride and a per-channel stride.
// NOTE(review): layout names are taken from the typedef names — confirm the
// actual output layout in the kernel implementations.

typedef void (*xnn_conv_hwc_ukernel_function)(
    size_t input_height, size_t input_width,
    size_t output_y_start, size_t output_y_end,
    const void* input, const void* zero, const void* weights,
    void* output,
    size_t input_padding_top,
    size_t output_channels,
    size_t output_height_stride, size_t output_width_stride,
    const void* params);

typedef void (*xnn_f32_conv_hwc_ukernel_function)(
    size_t input_height, size_t input_width,
    size_t output_y_start, size_t output_y_end,
    const float* input, const float* zero, const float* weights,
    float* output,
    size_t input_padding_top,
    size_t output_channels,
    size_t output_height_stride, size_t output_width_stride,
    const union xnn_f32_minmax_params* params);

typedef void (*xnn_conv_hwc2chw_ukernel_function)(
    size_t input_height, size_t input_width,
    size_t output_y_start, size_t output_y_end,
    const void* input, const void* zero, const void* weights,
    void* output,
    size_t input_padding_top,
    size_t output_channels,
    size_t output_height_stride, size_t output_channel_stride,
    const void* params);

typedef void (*xnn_f32_conv_hwc2chw_ukernel_function)(
    size_t input_height, size_t input_width,
    size_t output_y_start, size_t output_y_end,
    const float* input, const float* zero, const float* weights,
    float* output,
    size_t input_padding_top,
    size_t output_channels,
    size_t output_height_stride, size_t output_channel_stride,
    const union xnn_f32_minmax_params* params);
2391
// Sparse-matrix/dense-matrix (SpMM) micro-kernel signatures. Besides the
// dense input/weights/output pointers, each kernel receives two index maps:
// `widx_dmap` (int32_t) and `nidx_nnzmap` (uint32_t).
// NOTE(review): the encoding of both maps is defined by the weight-packing
// code, not here — confirm before relying on it.

typedef void (*xnn_spmm_ukernel_function)(
    size_t batch_size, size_t output_channels,
    const void* input, const void* weights,
    const int32_t* widx_dmap, const uint32_t* nidx_nnzmap,
    void* output, size_t output_stride,
    const void* params);

typedef void (*xnn_f16_spmm_minmax_ukernel_function)(
    size_t batch_size, size_t output_channels,
    const void* input, const void* weights,
    const int32_t* widx_dmap, const uint32_t* nidx_nnzmap,
    void* output, size_t output_stride,
    const union xnn_f16_scaleminmax_params* params);

typedef void (*xnn_f32_spmm_minmax_ukernel_function)(
    size_t batch_size, size_t output_channels,
    const float* input, const float* weights,
    const int32_t* widx_dmap, const uint32_t* nidx_nnzmap,
    float* output, size_t output_stride,
    const union xnn_f32_minmax_params* params);
2424
// Data-movement micro-kernel signatures: packing, filling, depth-to-space,
// padding and unpooling. Each has an untyped (void) variant and an x32
// variant whose data pointers are uint32_t.
//
// By-value parameters are declared without a top-level `const`. That
// qualifier is ignored when comparing function types (C11 6.7.6.3p15), so it
// has no meaning in a prototype. Kernels that define the parameter as
// `const uint32_t` remain assignable to these pointer types.

typedef void (*xnn_packx_ukernel_function)(
    size_t m,
    size_t k,
    const void* x,
    size_t x_stride,
    void* y);

typedef void (*xnn_x32_packx_ukernel_function)(
    size_t m,
    size_t k,
    const uint32_t* x,
    size_t x_stride,
    uint32_t* y);

// `fill_pattern` is a 32-bit value.
// NOTE(review): how it is replicated across `output` is up to the kernel.
typedef void (*xnn_fill_ukernel_function)(
    size_t rows,
    size_t channels,
    void* output,
    size_t output_stride,
    uint32_t fill_pattern);

// The last parameter is named `output_channel_stride`, matching the x32
// variant below (it was previously spelled `output_channels_stride`).
typedef void (*xnn_depthtospace2d_chw2hwc_ukernel_function)(
    size_t output_channels,
    size_t input_height,
    size_t input_width,
    size_t block_size,
    const void* input,
    void* output,
    size_t output_channel_stride);

typedef void (*xnn_x32_depthtospace2d_chw2hwc_ukernel_function)(
    size_t output_channels,
    size_t input_height,
    size_t input_width,
    size_t block_size,
    const uint32_t* input,
    uint32_t* output,
    size_t output_channel_stride);

typedef void (*xnn_pad_ukernel_function)(
    size_t rows,
    size_t channels,
    size_t pre_padding,
    size_t post_padding,
    const void* input,
    size_t input_stride,
    void* output,
    size_t output_stride,
    uint32_t fill_value);

// `output` is an array of output pointers; `index` holds uint32_t indices.
typedef void (*xnn_unpool_ukernel_function)(
    size_t p,
    size_t c,
    uint32_t f,
    const void* input,
    const uint32_t* index,
    void** output);

typedef void (*xnn_x32_unpool_ukernel_function)(
    size_t p,
    size_t c,
    uint32_t f,
    const uint32_t* input,
    const uint32_t* index,
    uint32_t** output);
2490
// Zip micro-kernel signatures. The `zipc` kernels take a single count `n`.
// The `zipv` kernels take an extra `m`. Each comes in untyped, x8 (uint8_t)
// and x32 (uint32_t) forms.
// NOTE(review): the exact interleaving performed is defined by the kernels.
typedef void (*xnn_zipc_ukernel_function)(
    size_t n, const void* x, void* y);

typedef void (*xnn_x8_zipc_ukernel_function)(
    size_t n, const uint8_t* x, uint8_t* y);

typedef void (*xnn_x32_zipc_ukernel_function)(
    size_t n, const uint32_t* x, uint32_t* y);

typedef void (*xnn_zipv_ukernel_function)(
    size_t n, size_t m, const void* x, void* y);

typedef void (*xnn_x8_zipv_ukernel_function)(
    size_t n, size_t m, const uint8_t* x, uint8_t* y);

typedef void (*xnn_x32_zipv_ukernel_function)(
    size_t n, size_t m, const uint32_t* x, uint32_t* y);

// Byte lookup-table kernel: input `x`, output `y`, and a table `t`, all
// uint8_t.
typedef void (*xnn_x8_lut_ukernel_function)(
    size_t n, const uint8_t* x, uint8_t* y, const uint8_t* t);
2529
// Depthwise-convolution micro-kernel signatures.
//
// The `dwconv2d_chw` pair takes input dimensions and a plain input pointer.
// The `dwconv` unipass/multipass kernels take an array of input pointers
// together with stride/increment/offset arguments. The multipass form adds a
// scratch `buffer`.

typedef void (*xnn_dwconv2d_chw_ukernel_function)(
    size_t input_height, size_t input_width,
    const void* input, const void* weights, const void* zero,
    void* output,
    uint32_t padding_top,
    const void* params);

typedef void (*xnn_f32_dwconv2d_chw_ukernel_function)(
    size_t input_height, size_t input_width,
    const float* input, const float* weights, const float* zero,
    float* output,
    uint32_t padding_top,
    const union xnn_f32_chw_params* params);

typedef void (*xnn_dwconv_unipass_ukernel_function)(
    size_t channels, size_t output_width,
    const void** input, const void* weights,
    void* output,
    size_t input_stride, size_t output_increment, size_t input_offset,
    const void* zero,
    const void* params);

typedef void (*xnn_f32_dwconv_unipass_ukernel_function)(
    size_t channels, size_t output_width,
    const float** input, const float* weights,
    float* output,
    size_t input_stride, size_t output_increment, size_t input_offset,
    const float* zero,
    const union xnn_f32_default_params* params);

typedef void (*xnn_f32_dwconv_minmax_unipass_ukernel_function)(
    size_t channels, size_t output_width,
    const float** input, const float* weights,
    float* output,
    size_t input_stride, size_t output_increment, size_t input_offset,
    const float* zero,
    const union xnn_f32_minmax_params* params);

typedef void (*xnn_f16_dwconv_minmax_unipass_ukernel_function)(
    size_t channels, size_t output_width,
    const void** input, const void* weights,
    void* output,
    size_t input_stride, size_t output_increment, size_t input_offset,
    const void* zero,
    const union xnn_f16_minmax_params* params);

// For 8-bit variants, weights stay untyped (void) while activations are typed.
typedef void (*xnn_qc8_dwconv_minmax_unipass_ukernel_function)(
    size_t channels, size_t output_width,
    const int8_t** input, const void* weights,
    int8_t* output,
    size_t input_stride, size_t output_increment, size_t input_offset,
    const int8_t* zero,
    const union xnn_qs8_minmax_params* params);

typedef void (*xnn_qs8_dwconv_minmax_unipass_ukernel_function)(
    size_t channels, size_t output_width,
    const int8_t** input, const void* weights,
    int8_t* output,
    size_t input_stride, size_t output_increment, size_t input_offset,
    const int8_t* zero,
    const union xnn_qs8_conv_minmax_params* params);

typedef void (*xnn_qu8_dwconv_minmax_unipass_ukernel_function)(
    size_t channels, size_t output_width,
    const uint8_t** input, const void* weights,
    uint8_t* output,
    size_t input_stride, size_t output_increment, size_t input_offset,
    const uint8_t* zero,
    const union xnn_qu8_conv_minmax_params* params);

typedef void (*xnn_dwconv_multipass_ukernel_function)(
    size_t channels, size_t output_width,
    const void** input, const void* weights,
    void* buffer, void* output,
    size_t input_stride, size_t output_increment, size_t input_offset,
    const void* zero,
    const void* params);
2646
// Indirect bilinear-interpolation micro-kernel signatures. Inputs arrive as
// an array of pointers plus an `input_offset`. The 8-bit variants take
// int16_t weights, while f32 takes float weights.
// The NHWC-style kernels end with `output_increment`; the `_chw` kernels end
// with `input_increment` instead.

typedef void (*xnn_f32_ibilinear_ukernel_function)(
    size_t output_pixels, size_t channels,
    const float** input, size_t input_offset,
    const float* weights,
    float* output, size_t output_increment);

typedef void (*xnn_s8_ibilinear_ukernel_function)(
    size_t output_pixels, size_t channels,
    const int8_t** input, size_t input_offset,
    const int16_t* weights,
    int8_t* output, size_t output_increment);

typedef void (*xnn_u8_ibilinear_ukernel_function)(
    size_t output_pixels, size_t channels,
    const uint8_t** input, size_t input_offset,
    const int16_t* weights,
    uint8_t* output, size_t output_increment);

typedef void (*xnn_ibilinear_ukernel_function)(
    size_t output_pixels, size_t channels,
    const void** input, size_t input_offset,
    const void* weights,
    void* output, size_t output_increment);

typedef void (*xnn_f32_ibilinear_chw_ukernel_function)(
    size_t output_pixels, size_t channels,
    const float** input, size_t input_offset,
    const float* weights,
    float* output, size_t input_increment);

typedef void (*xnn_ibilinear_chw_ukernel_function)(
    size_t output_pixels, size_t channels,
    const void** input, size_t input_offset,
    const void* weights,
    void* output, size_t input_increment);
2700
// Global average-pooling micro-kernel signatures.
//
// Unipass kernels go straight from input to output. Multipass kernels also
// receive a scratch `buffer`, which is int32_t for the 8-bit variants and
// element-typed for f16/f32. The `_cw` pair has a separate, shorter
// signature.

typedef void (*xnn_gavgpool_unipass_ukernel_function)(
    size_t rows, size_t channels,
    const void* input, size_t input_stride,
    const void* zero,
    void* output,
    const void* params);

typedef void (*xnn_f16_gavgpool_minmax_unipass_ukernel_function)(
    size_t rows, size_t channels,
    const void* input, size_t input_stride,
    const void* zero,
    void* output,
    const union xnn_f16_scaleminmax_params* params);

typedef void (*xnn_f32_gavgpool_minmax_unipass_ukernel_function)(
    size_t rows, size_t channels,
    const float* input, size_t input_stride,
    const float* zero,
    float* output,
    const union xnn_f32_scaleminmax_params* params);

typedef void (*xnn_qu8_gavgpool_minmax_unipass_ukernel_function)(
    size_t rows, size_t channels,
    const uint8_t* input, size_t input_stride,
    const uint8_t* zero,
    uint8_t* output,
    const union xnn_qu8_avgpool_minmax_params* params);

typedef void (*xnn_qs8_gavgpool_minmax_unipass_ukernel_function)(
    size_t rows, size_t channels,
    const int8_t* input, size_t input_stride,
    const int8_t* zero,
    int8_t* output,
    const union xnn_qs8_avgpool_minmax_params* params);

typedef void (*xnn_gavgpool_multipass_ukernel_function)(
    size_t rows, size_t channels,
    const void* input, size_t input_stride,
    const void* zero,
    void* buffer, void* output,
    const void* params);

typedef void (*xnn_f16_gavgpool_minmax_multipass_ukernel_function)(
    size_t rows, size_t channels,
    const void* input, size_t input_stride,
    const void* zero,
    void* buffer, void* output,
    const union xnn_f16_scaleminmax_params* params);

typedef void (*xnn_f32_gavgpool_minmax_multipass_ukernel_function)(
    size_t rows, size_t channels,
    const float* input, size_t input_stride,
    const float* zero,
    float* buffer, float* output,
    const union xnn_f32_scaleminmax_params* params);

typedef void (*xnn_qu8_gavgpool_minmax_multipass_ukernel_function)(
    size_t rows, size_t channels,
    const uint8_t* input, size_t input_stride,
    const uint8_t* zero,
    int32_t* buffer, uint8_t* output,
    const union xnn_qu8_avgpool_minmax_params* params);

typedef void (*xnn_qs8_gavgpool_minmax_multipass_ukernel_function)(
    size_t rows, size_t channels,
    const int8_t* input, size_t input_stride,
    const int8_t* zero,
    int32_t* buffer, int8_t* output,
    const union xnn_qs8_avgpool_minmax_params* params);

// NOTE(review): unlike the other untyped typedefs in this header, the
// "generic" cw form uses float data pointers; only `params` is untyped.
typedef void (*xnn_gavgpool_cw_ukernel_function)(
    size_t elements, size_t channels,
    const float* input, float* output,
    const void* params);

typedef void (*xnn_f32_gavgpool_cw_ukernel_function)(
    size_t elements, size_t channels,
    const float* input, float* output,
    const union xnn_f32_gavgpool_params* params);
2809
// Average-pooling (avgpool) and pixelwise average-pooling (pavgpool)
// micro-kernel signatures. pavgpool kernels additionally take a `multiplier`
// pointer. Multipass variants add a scratch `buffer`, which is int32_t for
// qu8 and element-typed for f32.
// NOTE(review): the f32 avgpool kernels take scaleminmax params while the f32
// pavgpool kernels take plain minmax params — as declared, not a bug here.

typedef void (*xnn_avgpool_unipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const void** input, size_t input_offset,
    const void* zero,
    void* output,
    size_t input_increment, size_t output_increment,
    const void* params);

typedef void (*xnn_f32_avgpool_minmax_unipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const float** input, size_t input_offset,
    const float* zero,
    float* output,
    size_t input_increment, size_t output_increment,
    const union xnn_f32_scaleminmax_params* params);

typedef void (*xnn_qu8_avgpool_minmax_unipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const uint8_t** input, size_t input_offset,
    const uint8_t* zero,
    uint8_t* output,
    size_t input_increment, size_t output_increment,
    const union xnn_qu8_avgpool_minmax_params* params);

typedef void (*xnn_avgpool_multipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const void** input, size_t input_offset,
    const void* zero,
    void* buffer, void* output,
    size_t input_increment, size_t output_increment,
    const void* params);

typedef void (*xnn_f32_avgpool_minmax_multipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const float** input, size_t input_offset,
    const float* zero,
    float* buffer, float* output,
    size_t input_increment, size_t output_increment,
    const union xnn_f32_scaleminmax_params* params);

typedef void (*xnn_qu8_avgpool_minmax_multipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const uint8_t** input, size_t input_offset,
    const uint8_t* zero,
    int32_t* buffer, uint8_t* output,
    size_t input_increment, size_t output_increment,
    const union xnn_qu8_avgpool_minmax_params* params);

typedef void (*xnn_pavgpool_unipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const void** input, size_t input_offset,
    const void* zero, const void* multiplier,
    void* output,
    size_t input_increment, size_t output_increment,
    const void* params);

typedef void (*xnn_f32_pavgpool_minmax_unipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const float** input, size_t input_offset,
    const float* zero, const float* multiplier,
    float* output,
    size_t input_increment, size_t output_increment,
    const union xnn_f32_minmax_params* params);

typedef void (*xnn_pavgpool_multipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const void** input, size_t input_offset,
    const void* zero, const void* multiplier,
    void* buffer, void* output,
    size_t input_increment, size_t output_increment,
    const void* params);

typedef void (*xnn_f32_pavgpool_minmax_multipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const float** input, size_t input_offset,
    const float* zero, const float* multiplier,
    float* buffer, float* output,
    size_t input_increment, size_t output_increment,
    const union xnn_f32_minmax_params* params);
2938
// Max-pooling and arg-max-pooling micro-kernel signatures.
//
// maxpool kernels take a params union. argmaxpool kernels take no params but
// write a uint32_t `index` array alongside the output. The multipass
// argmaxpool form adds separate accumulation and index scratch buffers.

typedef void (*xnn_maxpool_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const void** input, size_t input_offset,
    void* output,
    size_t input_increment, size_t output_increment,
    const void* params);

typedef void (*xnn_f16_maxpool_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const void** input, size_t input_offset,
    void* output,
    size_t input_increment, size_t output_increment,
    const union xnn_f16_minmax_params* params);

typedef void (*xnn_f32_maxpool_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const float** input, size_t input_offset,
    float* output,
    size_t input_increment, size_t output_increment,
    const union xnn_f32_minmax_params* params);

typedef void (*xnn_s8_maxpool_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const int8_t** input, size_t input_offset,
    int8_t* output,
    size_t input_increment, size_t output_increment,
    const union xnn_s8_minmax_params* params);

typedef void (*xnn_u8_maxpool_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const uint8_t** input, size_t input_offset,
    uint8_t* output,
    size_t input_increment, size_t output_increment,
    const union xnn_u8_minmax_params* params);

typedef void (*xnn_argmaxpool_unipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const void** input, size_t input_offset,
    void* output, uint32_t* index,
    size_t input_increment, size_t output_increment);

typedef void (*xnn_f32_argmaxpool_unipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const float** input, size_t input_offset,
    float* output, uint32_t* index,
    size_t input_increment, size_t output_increment);

typedef void (*xnn_argmaxpool_multipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const void** input, size_t input_offset,
    void* accumulation_buffer, uint32_t* index_buffer,
    void* output, uint32_t* index,
    size_t input_increment, size_t output_increment);

typedef void (*xnn_f32_argmaxpool_multipass_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const float** input, size_t input_offset,
    float* accumulation_buffer, uint32_t* index_buffer,
    float* output, uint32_t* index,
    size_t input_increment, size_t output_increment);
3041
// Single-input vector micro-kernel signatures: an element count `n`, an input
// `x`, an output `y`, and a params union specific to each operation. The
// first typedef is the untyped form; f16 variants also use void data
// pointers.
typedef void (*xnn_univector_ukernel_function)(
    size_t n, const void* x, void* y, const void* params);

typedef void (*xnn_f16_vclamp_ukernel_function)(
    size_t n, const void* x, void* y,
    const union xnn_f16_minmax_params* params);

typedef void (*xnn_f32_vclamp_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_minmax_params* params);

typedef void (*xnn_s8_vclamp_ukernel_function)(
    size_t n, const int8_t* x, int8_t* y,
    const union xnn_s8_minmax_params* params);

typedef void (*xnn_u8_vclamp_ukernel_function)(
    size_t n, const uint8_t* x, uint8_t* y,
    const union xnn_u8_minmax_params* params);

typedef void (*xnn_f32_vrelu_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_relu_params* params);

typedef void (*xnn_f16_vhswish_ukernel_function)(
    size_t n, const void* x, void* y,
    const union xnn_f16_hswish_params* params);

typedef void (*xnn_f32_vabs_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_abs_params* params);

typedef void (*xnn_f32_vhswish_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_hswish_params* params);

typedef void (*xnn_f32_vlrelu_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_lrelu_params* params);

typedef void (*xnn_f32_vneg_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_neg_params* params);

typedef void (*xnn_f32_vround_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_rnd_params* params);

typedef void (*xnn_f32_vsigmoid_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_sigmoid_params* params);
3119
// Reduction-max micro-kernel signatures (count `n`, input `x`, result `y`),
// plus the u8 lookup-table normalization kernel, which reads a uint32_t
// table `t`.
typedef void (*xnn_rmax_ukernel_function)(
    size_t n, const void* x, void* y);

typedef void (*xnn_u8_rmax_ukernel_function)(
    size_t n, const uint8_t* x, uint8_t* y);

typedef void (*xnn_f32_rmax_ukernel_function)(
    size_t n, const float* x, float* y);

typedef void (*xnn_u8_lut32norm_ukernel_function)(
    size_t n, const uint8_t* x, const uint32_t* t, uint8_t* y);
3140
// Quantized add/sub/mul micro-kernel signatures, preceded by the untyped
// two-input form and followed by several f32 single-input kernels
// (elu, sqr, sqrt). Each typed kernel takes its own params union.
typedef void (*xnn_vadd_ukernel_function)(
    size_t n, const void* a, const void* b, void* y, const void* params);

typedef void (*xnn_qu8_vaddsub_minmax_ukernel_function)(
    size_t n, const uint8_t* input_x, const uint8_t* input_y, uint8_t* output,
    const union xnn_qu8_addsub_minmax_params* params);

typedef void (*xnn_qs8_vaddsub_minmax_ukernel_function)(
    size_t n, const int8_t* input_x, const int8_t* input_y, int8_t* output,
    const union xnn_qs8_addsub_minmax_params* params);

typedef void (*xnn_qu8_vmul_minmax_ukernel_function)(
    size_t n, const uint8_t* input_x, const uint8_t* input_y, uint8_t* output,
    const union xnn_qu8_mul_minmax_params* params);

typedef void (*xnn_qs8_vmul_minmax_ukernel_function)(
    size_t n, const int8_t* input_x, const int8_t* input_y, int8_t* output,
    const union xnn_qs8_mul_minmax_params* params);

typedef void (*xnn_f32_velu_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_elu_params* params);

typedef void (*xnn_f32_vsqr_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_default_params* params);

typedef void (*xnn_f32_vsqrt_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_sqrt_params* params);
3194
// Two-input elementwise ("vbinary") micro-kernel signatures: count `n`,
// inputs `a` and `b`, output `y`, and a params union. f16 data is passed via
// void pointers.
typedef void (*xnn_vbinary_ukernel_function)(
    size_t n, const void* a, const void* b, void* y, const void* params);

typedef void (*xnn_f16_vbinary_ukernel_function)(
    size_t n, const void* a, const void* b, void* y,
    const union xnn_f16_default_params* params);

typedef void (*xnn_f16_vbinary_minmax_ukernel_function)(
    size_t n, const void* a, const void* b, void* y,
    const union xnn_f16_minmax_params* params);

typedef void (*xnn_f32_vbinary_ukernel_function)(
    size_t n, const float* a, const float* b, float* y,
    const union xnn_f32_default_params* params);

typedef void (*xnn_f32_vbinary_minmax_ukernel_function)(
    size_t n, const float* a, const float* b, float* y,
    const union xnn_f32_minmax_params* params);

typedef void (*xnn_f32_vbinary_relu_ukernel_function)(
    size_t n, const float* a, const float* b, float* y,
    const union xnn_f32_relu_params* params);
3236
// Generic single-input ("vunary") micro-kernel signatures. All of them take
// an untyped `params` pointer; only the element type of `x`/`y` differs.
typedef void (*xnn_vunary_ukernel_function)(
    size_t n, const void* x, void* y, const void* params);

typedef void (*xnn_s8_vunary_ukernel_function)(
    size_t n, const int8_t* x, int8_t* y, const void* params);

typedef void (*xnn_u8_vunary_ukernel_function)(
    size_t n, const uint8_t* x, uint8_t* y, const void* params);

// NOTE(review): this f16 typedef uses uint16_t data pointers, whereas the
// other f16 kernel typedefs in this header use void pointers.
typedef void (*xnn_f16_vunary_ukernel_function)(
    size_t n, const uint16_t* x, uint16_t* y, const void* params);

typedef void (*xnn_f32_vunary_ukernel_function)(
    size_t n, const float* x, float* y, const void* params);
3266
// Element-type conversion ("vcvt") micro-kernel signatures. Each reads `n`
// from `input`, writes `output`, and takes a conversion-specific params
// union. f16 data is passed via void pointers.
typedef void (*xnn_f16_f32_vcvt_ukernel_function)(
    size_t n, const void* input, float* output,
    const union xnn_f16_f32_cvt_params* params);

typedef void (*xnn_f32_f16_vcvt_ukernel_function)(
    size_t n, const float* input, void* output,
    const union xnn_f32_f16_cvt_params* params);

typedef void (*xnn_f32_qs8_vcvt_ukernel_function)(
    size_t n, const float* input, int8_t* output,
    const union xnn_f32_qs8_cvt_params* params);

typedef void (*xnn_f32_qu8_vcvt_ukernel_function)(
    size_t n, const float* input, uint8_t* output,
    const union xnn_f32_qu8_cvt_params* params);

typedef void (*xnn_qs8_f32_vcvt_ukernel_function)(
    size_t n, const int8_t* input, float* output,
    const union xnn_qs8_f32_cvt_params* params);

typedef void (*xnn_qu8_f32_vcvt_ukernel_function)(
    size_t n, const uint8_t* input, float* output,
    const union xnn_qu8_f32_cvt_params* params);
3302
// Strided two-dimensional kernels over rows of `x`, with a weights pointer
// `w` and output `y`.
// vmulcaddc kernels take a params union. prelu kernels take no params.
// f16 variants use void pointers.

typedef void (*xnn_vmulcaddc_ukernel_function)(
    size_t m, size_t c,
    const void* x, size_t x_stride,
    const void* w,
    void* y, size_t y_stride,
    const void* params);

typedef void (*xnn_f16_vmulcaddc_ukernel_function)(
    size_t m, size_t c,
    const void* x, size_t x_stride,
    const void* w,
    void* y, size_t y_stride,
    const union xnn_f16_minmax_params* params);

typedef void (*xnn_f32_vmulcaddc_ukernel_function)(
    size_t m, size_t c,
    const float* x, size_t x_stride,
    const float* w,
    float* y, size_t y_stride,
    const union xnn_f32_minmax_params* params);

typedef void (*xnn_prelu_ukernel_function)(
    size_t mr, size_t n,
    const void* x, size_t x_stride,
    const void* w,
    void* y, size_t y_stride);

typedef void (*xnn_f16_prelu_ukernel_function)(
    size_t mr, size_t n,
    const void* x, size_t x_stride,
    const void* w,
    void* y, size_t y_stride);

typedef void (*xnn_f32_prelu_ukernel_function)(
    size_t mr, size_t n,
    const float* x, size_t x_stride,
    const float* w,
    float* y, size_t y_stride);
3359
// f32 exponential reduction and scaling kernels.
// Most take their scalars (max, scale, c) by value. The "store" variant
// instead takes `max` by pointer plus a params union.

typedef void (*xnn_f32_raddexpminusmax_ukernel_function)(
    size_t n, const float* input, float* sum, float max);

typedef void (*xnn_f32_raddstoreexpminusmax_ukernel_function)(
    size_t n,
    const float* input, const float* max,
    float* output, float* sum,
    const union xnn_f32_expminus_params* params);

typedef void (*xnn_f32_vscaleexpminusmax_ukernel_function)(
    size_t n, const float* input, float* output, float max, float scale);

typedef void (*xnn_f32_vscale_ukernel_function)(
    size_t n, const float* x, float* y, float c);

// Reduce-add of exponentials kept in an extended form, split into a
// "mantissa" part and an "exponent" part.
typedef void (*xnn_f32_raddextexp_ukernel_function)(
    size_t n, const float* input, float* sum);

// Vector scaling of exponentials in the same extended form; the scale is
// supplied as separate mantissa and exponent floats.
typedef void (*xnn_f32_vscaleextexp_ukernel_function)(
    size_t n, const float* input, float* output,
    float scale_mantissa, float scale_exponent);
3400
3401 typedef void (*xnn_init_f16_f32_cvt_params_fn)(
3402 union xnn_f16_f32_cvt_params params[XNN_MIN_ELEMENTS(1)]);
3403
3404 typedef void (*xnn_init_f32_f16_cvt_params_fn)(
3405 union xnn_f32_f16_cvt_params params[XNN_MIN_ELEMENTS(1)]);
3406
3407 typedef void (*xnn_init_f32_qs8_cvt_params_fn)(
3408 union xnn_f32_qs8_cvt_params params[XNN_MIN_ELEMENTS(1)],
3409 float scale,
3410 int8_t output_zero_point,
3411 int8_t output_min,
3412 int8_t output_max);
3413
3414 typedef void (*xnn_init_f32_qu8_cvt_params_fn)(
3415 union xnn_f32_qu8_cvt_params params[XNN_MIN_ELEMENTS(1)],
3416 float scale,
3417 uint8_t output_zero_point,
3418 uint8_t output_min,
3419 uint8_t output_max);
3420
3421 typedef void (*xnn_init_qs8_f32_cvt_params_fn)(
3422 union xnn_qs8_f32_cvt_params params[XNN_MIN_ELEMENTS(1)],
3423 float scale,
3424 int8_t zero_point);
3425
3426 typedef void (*xnn_init_qu8_f32_cvt_params_fn)(
3427 union xnn_qu8_f32_cvt_params params[XNN_MIN_ELEMENTS(1)],
3428 float scale,
3429 uint8_t zero_point);
3430
// Signatures of the helpers that fill in parameter unions for quantized (QS8 =
// signed 8-bit, QU8 = unsigned 8-bit) micro-kernels. In each, the first
// argument is the union to populate; XNN_MIN_ELEMENTS(1) (defined elsewhere)
// marks that array parameter as requiring at least one element.

// QS8 min/max params: output zero point and output [min, max] bounds (no scale).
typedef void (*xnn_init_qs8_minmax_params_fn)(
    union xnn_qs8_minmax_params params[XNN_MIN_ELEMENTS(1)],
    int8_t output_zero_point,
    int8_t output_min,
    int8_t output_max);

// QS8 convolution/GEMM params: requantization scale, output zero point and output bounds.
typedef void (*xnn_init_qs8_conv_minmax_params_fn)(
    union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
    float scale,
    int8_t output_zero_point,
    int8_t output_min,
    int8_t output_max);

// QU8 convolution/GEMM params. Unlike the QS8 variant, this also takes a kernel
// (weights) zero point.
typedef void (*xnn_init_qu8_conv_minmax_params_fn)(
    union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
    uint8_t kernel_zero_point,
    float scale,
    uint8_t output_zero_point,
    uint8_t output_min,
    uint8_t output_max);

// QS8 average-pooling params: accumulator bias, scale, output zero point and output bounds.
typedef void (*xnn_init_qs8_avgpool_minmax_params_fn)(
    union xnn_qs8_avgpool_minmax_params params[XNN_MIN_ELEMENTS(1)],
    int32_t bias,
    float scale,
    int8_t output_zero_point,
    int8_t output_min,
    int8_t output_max);

// QU8 average-pooling params: accumulator bias, scale, output zero point and output bounds.
typedef void (*xnn_init_qu8_avgpool_minmax_params_fn)(
    union xnn_qu8_avgpool_minmax_params params[XNN_MIN_ELEMENTS(1)],
    int32_t bias,
    float scale,
    uint8_t output_zero_point,
    uint8_t output_min,
    uint8_t output_max);

// Update helpers for average-pooling params: take only `bias` and `scale`,
// presumably rewriting those fields of an already-initialized union while
// keeping the zero point and bounds — confirm against the implementations.
typedef void (*xnn_update_qs8_avgpool_minmax_params_fn)(
    union xnn_qs8_avgpool_minmax_params params[XNN_MIN_ELEMENTS(1)],
    int32_t bias,
    float scale);

typedef void (*xnn_update_qu8_avgpool_minmax_params_fn)(
    union xnn_qu8_avgpool_minmax_params params[XNN_MIN_ELEMENTS(1)],
    int32_t bias,
    float scale);

// QS8 add/subtract params: zero points of both inputs and the output, the scale
// of each input relative to the output, and output bounds.
typedef void (*xnn_init_qs8_addsub_minmax_params_fn)(
    union xnn_qs8_addsub_minmax_params params[XNN_MIN_ELEMENTS(1)],
    int8_t a_zero_point,
    int8_t b_zero_point,
    int8_t output_zero_point,
    float a_output_scale,
    float b_output_scale,
    int8_t output_min,
    int8_t output_max);

// QU8 add/subtract params; same layout of arguments as the QS8 variant.
typedef void (*xnn_init_qu8_addsub_minmax_params_fn)(
    union xnn_qu8_addsub_minmax_params params[XNN_MIN_ELEMENTS(1)],
    uint8_t a_zero_point,
    uint8_t b_zero_point,
    uint8_t output_zero_point,
    float a_output_scale,
    float b_output_scale,
    uint8_t output_min,
    uint8_t output_max);

// QS8 multiply params: zero points, a single product-to-output scale, and output bounds.
typedef void (*xnn_init_qs8_mul_minmax_params_fn)(
    union xnn_qs8_mul_minmax_params params[XNN_MIN_ELEMENTS(1)],
    int8_t a_zero_point,
    int8_t b_zero_point,
    int8_t output_zero_point,
    float product_output_scale,
    int8_t output_min,
    int8_t output_max);

// QU8 multiply params; same layout of arguments as the QS8 variant.
typedef void (*xnn_init_qu8_mul_minmax_params_fn)(
    union xnn_qu8_mul_minmax_params params[XNN_MIN_ELEMENTS(1)],
    uint8_t a_zero_point,
    uint8_t b_zero_point,
    uint8_t output_zero_point,
    float product_output_scale,
    uint8_t output_min,
    uint8_t output_max);
3515
// Signatures of the helpers that fill in parameter unions for F16 micro-kernels.
// F16 scalars are passed as uint16_t, presumably as raw IEEE half-precision bit
// patterns (the NEON params structs in this header also store them as uint16_t).

// F16 hard-swish params; no scalar arguments.
typedef void (*xnn_init_f16_hswish_params_fn)(
    union xnn_f16_hswish_params params[XNN_MIN_ELEMENTS(1)]);

// F16 output clamping bounds.
typedef void (*xnn_init_f16_minmax_params_fn)(
    union xnn_f16_minmax_params params[XNN_MIN_ELEMENTS(1)],
    uint16_t min,
    uint16_t max);

// F16 scale plus output clamping bounds.
typedef void (*xnn_init_f16_scaleminmax_params_fn)(
    union xnn_f16_scaleminmax_params params[XNN_MIN_ELEMENTS(1)],
    uint16_t scale,
    uint16_t min,
    uint16_t max);

// Takes only a new `scale`; presumably rewrites just that field of an
// already-initialized union — confirm against the implementations.
typedef void (*xnn_update_f16_scaleminmax_params_fn)(
    union xnn_f16_scaleminmax_params params[XNN_MIN_ELEMENTS(1)],
    uint16_t scale);
3533
// Signatures of the helpers that fill in parameter unions for F32, S8 and U8
// micro-kernels. In each, the first argument is the union to populate;
// XNN_MIN_ELEMENTS(1) (defined elsewhere) marks that array parameter as
// requiring at least one element. Initializers without extra arguments
// presumably fill in operation-specific constants only.

typedef void (*xnn_init_f32_abs_params_fn)(
    union xnn_f32_abs_params params[XNN_MIN_ELEMENTS(1)]);

typedef void (*xnn_init_f32_default_params_fn)(
    union xnn_f32_default_params params[XNN_MIN_ELEMENTS(1)]);

typedef void (*xnn_init_f32_expminus_params_fn)(
    union xnn_f32_expminus_params params[XNN_MIN_ELEMENTS(1)]);

// ELU takes three scalar coefficients: prescale, alpha and beta.
typedef void (*xnn_init_f32_elu_params_fn)(
    union xnn_f32_elu_params params[XNN_MIN_ELEMENTS(1)],
    float prescale,
    float alpha,
    float beta);

typedef void (*xnn_init_f32_hswish_params_fn)(
    union xnn_f32_hswish_params params[XNN_MIN_ELEMENTS(1)]);

// Leaky ReLU takes the slope applied to negative inputs.
typedef void (*xnn_init_f32_lrelu_params_fn)(
    union xnn_f32_lrelu_params params[XNN_MIN_ELEMENTS(1)],
    float slope);

// Output clamping bounds.
typedef void (*xnn_init_f32_minmax_params_fn)(
    union xnn_f32_minmax_params params[XNN_MIN_ELEMENTS(1)],
    float output_min,
    float output_max);

typedef void (*xnn_init_f32_neg_params_fn)(
    union xnn_f32_neg_params params[XNN_MIN_ELEMENTS(1)]);

typedef void (*xnn_init_f32_rnd_params_fn)(
    union xnn_f32_rnd_params params[XNN_MIN_ELEMENTS(1)]);

// Scale plus output clamping bounds.
typedef void (*xnn_init_f32_scaleminmax_params_fn)(
    union xnn_f32_scaleminmax_params params[XNN_MIN_ELEMENTS(1)],
    float scale,
    float output_min,
    float output_max);

// Takes only a new `scale`; presumably rewrites just that field of an
// already-initialized union — confirm against the implementations.
typedef void (*xnn_update_f32_scaleminmax_params_fn)(
    union xnn_f32_scaleminmax_params params[XNN_MIN_ELEMENTS(1)],
    float scale);

typedef void (*xnn_init_f32_sigmoid_params_fn)(
    union xnn_f32_sigmoid_params params[XNN_MIN_ELEMENTS(1)]);

typedef void (*xnn_init_f32_sqrt_params_fn)(
    union xnn_f32_sqrt_params params[XNN_MIN_ELEMENTS(1)]);

// Signed 8-bit output clamping bounds.
typedef void (*xnn_init_s8_minmax_params_fn)(
    union xnn_s8_minmax_params params[XNN_MIN_ELEMENTS(1)],
    int8_t output_min,
    int8_t output_max);

// Unsigned 8-bit output clamping bounds.
typedef void (*xnn_init_u8_minmax_params_fn)(
    union xnn_u8_minmax_params params[XNN_MIN_ELEMENTS(1)],
    uint8_t output_min,
    uint8_t output_max);
3592
// Signature of a helper that presumably writes per-channel QC8 scales from
// `scale` into the packed weight buffer `packed_w` (inferred from the name and
// arguments — confirm against the implementations).
// NOTE(review): `channels_tile` and `stride` presumably describe the packed
// layout (channels per packed group, bytes between groups) — confirm.
typedef void (*xnn_init_qc8_scale_params_fn)(
    size_t channels,
    size_t channels_tile,
    size_t stride,
    const float scale[XNN_MIN_ELEMENTS(1)],
    void* packed_w);
3599
// Forward declare to avoid circular includes between this and allocator.h.
struct xnn_code_buffer;

// Parameters for JIT GEMM code generation. Only an f32 min/max pair is defined
// here; how generators consume it is not visible in this header.
struct jit_gemm_params {
  struct {
    float min;
    float max;
  } f32_minmax;
};

// JIT code-generator signatures for GEMM and IGEMM. Each receives the target
// code buffer, the tile parameters (nc, kc and, for IGEMM, ks) and an untyped
// `params` pointer, and reports an xnn_status.
// NOTE(review): `params` presumably points at a struct jit_gemm_params — confirm at call sites.
typedef enum xnn_status (*xnn_jit_gemm_code_generator_function)(
    struct xnn_code_buffer *code, size_t nc, size_t kc, const void *params);
typedef enum xnn_status (*xnn_jit_igemm_code_generator_function)(
    struct xnn_code_buffer *code, size_t nc, size_t kc, size_t ks, const void *params);
3614
// A GEMM micro-kernel table with one entry per microarchitecture type
// (XNN_MAX_UARCH_TYPES), so different core types can be given different kernels.
// NOTE(review): "hmp" presumably stands for heterogeneous multi-processing.
struct xnn_hmp_gemm_ukernel {
  xnn_gemm_ukernel_function function[XNN_MAX_UARCH_TYPES];
};
3618
// Returns a per-microarchitecture GEMM table in which every slot holds the same
// micro-kernel `function`.
static inline struct xnn_hmp_gemm_ukernel xnn_init_hmp_gemm_ukernel(xnn_gemm_ukernel_function function) {
  struct xnn_hmp_gemm_ukernel result;
  for (size_t uarch = 0; uarch < XNN_MAX_UARCH_TYPES; uarch++) {
    result.function[uarch] = function;
  }
  return result;
}
3626
xnn_is_hmp_gemm_ukernel(struct xnn_hmp_gemm_ukernel ukernel)3627 static inline bool xnn_is_hmp_gemm_ukernel(struct xnn_hmp_gemm_ukernel ukernel) {
3628 #if XNN_MAX_UARCH_TYPES == 1
3629 return false;
3630 #else
3631 uintptr_t default_function = (uintptr_t) ukernel.function[XNN_UARCH_DEFAULT];
3632 uintptr_t difference = 0;
3633 for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) {
3634 difference |= (default_function ^ (uintptr_t) ukernel.function[i]);
3635 }
3636 return difference != 0;
3637 #endif
3638 }
3639
// An IGEMM micro-kernel table with one entry per microarchitecture type
// (XNN_MAX_UARCH_TYPES); the IGEMM counterpart of struct xnn_hmp_gemm_ukernel.
struct xnn_hmp_igemm_ukernel {
  xnn_igemm_ukernel_function function[XNN_MAX_UARCH_TYPES];
};
3643
// Returns a per-microarchitecture IGEMM table in which every slot holds the
// same micro-kernel `function`.
static inline struct xnn_hmp_igemm_ukernel xnn_init_hmp_igemm_ukernel(xnn_igemm_ukernel_function function) {
  struct xnn_hmp_igemm_ukernel result;
  for (size_t uarch = 0; uarch < XNN_MAX_UARCH_TYPES; uarch++) {
    result.function[uarch] = function;
  }
  return result;
}
3651
xnn_is_hmp_igemm_ukernel(struct xnn_hmp_igemm_ukernel ukernel)3652 static inline bool xnn_is_hmp_igemm_ukernel(struct xnn_hmp_igemm_ukernel ukernel) {
3653 #if XNN_MAX_UARCH_TYPES == 1
3654 return false;
3655 #else
3656 uintptr_t default_function = (uintptr_t) ukernel.function[XNN_UARCH_DEFAULT];
3657 uintptr_t difference = 0;
3658 for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) {
3659 difference |= (default_function ^ (uintptr_t) ukernel.function[i]);
3660 }
3661 return difference != 0;
3662 #endif
3663 }
3664
// The GEMM and IGEMM micro-kernel tables for one fused-activation variant
// (struct gemm_parameters holds one of these each for minmax, relu and linear).
struct gemm_fused_ukernels {
  struct xnn_hmp_gemm_ukernel gemm;
  struct xnn_hmp_igemm_ukernel igemm;
  // Optional GEMM and IGEMM micro-kernels with MR=1 and the same NR and KR parameters.
  struct xnn_hmp_gemm_ukernel gemm1;
  struct xnn_hmp_igemm_ukernel igemm1;
};
3672
#if XNN_PLATFORM_JIT
// Per-microarchitecture table of JIT GEMM code generators; the JIT counterpart
// of struct xnn_hmp_gemm_ukernel.
struct xnn_hmp_gemm_codegen {
  xnn_jit_gemm_code_generator_function function[XNN_MAX_UARCH_TYPES];
};
3677
// Returns a per-microarchitecture JIT GEMM generator table in which every slot
// holds the same generator `function`.
static inline struct xnn_hmp_gemm_codegen xnn_init_hmp_gemm_codegen(xnn_jit_gemm_code_generator_function function) {
  struct xnn_hmp_gemm_codegen result;
  for (size_t uarch = 0; uarch < XNN_MAX_UARCH_TYPES; uarch++) {
    result.function[uarch] = function;
  }
  return result;
}
3685
xnn_is_hmp_gemm_codegen(struct xnn_hmp_gemm_codegen ukernel)3686 static inline bool xnn_is_hmp_gemm_codegen(struct xnn_hmp_gemm_codegen ukernel) {
3687 #if XNN_MAX_UARCH_TYPES == 1
3688 return false;
3689 #else
3690 uintptr_t default_function = (uintptr_t) ukernel.function[XNN_UARCH_DEFAULT];
3691 uintptr_t difference = 0;
3692 for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) {
3693 difference |= (default_function ^ (uintptr_t) ukernel.function[i]);
3694 }
3695 return difference != 0;
3696 #endif
3697 }
3698
// Per-microarchitecture table of JIT IGEMM code generators; the JIT counterpart
// of struct xnn_hmp_igemm_ukernel.
struct xnn_hmp_igemm_codegen {
  xnn_jit_igemm_code_generator_function function[XNN_MAX_UARCH_TYPES];
};
3702
// Returns a per-microarchitecture JIT IGEMM generator table in which every slot
// holds the same generator `function`.
static inline struct xnn_hmp_igemm_codegen xnn_init_hmp_igemm_codegen(xnn_jit_igemm_code_generator_function function) {
  struct xnn_hmp_igemm_codegen result;
  for (size_t uarch = 0; uarch < XNN_MAX_UARCH_TYPES; uarch++) {
    result.function[uarch] = function;
  }
  return result;
}
3710
xnn_is_hmp_igemm_codegen(struct xnn_hmp_igemm_codegen ukernel)3711 static inline bool xnn_is_hmp_igemm_codegen(struct xnn_hmp_igemm_codegen ukernel) {
3712 #if XNN_MAX_UARCH_TYPES == 1
3713 return false;
3714 #else
3715 uintptr_t default_function = (uintptr_t) ukernel.function[XNN_UARCH_DEFAULT];
3716 uintptr_t difference = 0;
3717 for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) {
3718 difference |= (default_function ^ (uintptr_t) ukernel.function[i]);
3719 }
3720 return difference != 0;
3721 #endif
3722 }
3723
// JIT generator tables for GEMM and IGEMM, mirroring the layout of
// struct gemm_fused_ukernels.
struct gemm_codegens {
  struct xnn_hmp_gemm_codegen gemm;
  struct xnn_hmp_igemm_codegen igemm;
  // Optional JIT GEMM and IGEMM micro-kernels with MR=1 and the same NR and KR parameters.
  struct xnn_hmp_gemm_codegen gemm1;
  struct xnn_hmp_igemm_codegen igemm1;
};
#endif  // XNN_PLATFORM_JIT
3732
// One GEMM/IGEMM configuration: micro-kernels for each fused-activation
// variant, optional JIT generators, the params initializer, and tile geometry.
struct gemm_parameters {
  struct gemm_fused_ukernels minmax;
  struct gemm_fused_ukernels relu;
  struct gemm_fused_ukernels linear;
#if XNN_PLATFORM_JIT
  struct gemm_codegens generator;
#endif  // XNN_PLATFORM_JIT
  // Initializer for the micro-kernel params union. Which member is meaningful
  // depends on the datatype this instance belongs to (qc8/qs8/qu8/f16/f32).
  union {
    xnn_init_qs8_minmax_params_fn qc8;
    xnn_init_qs8_conv_minmax_params_fn qs8;
    xnn_init_qu8_conv_minmax_params_fn qu8;
    xnn_init_f16_scaleminmax_params_fn f16;
    xnn_init_f32_minmax_params_fn f32;
  } init;
  // MR: tile size along the M dimension (cf. the "MR=1" gemm1/igemm1 kernels).
  uint8_t mr;
  // NR: tile size along the N dimension.
  uint8_t nr;
  // log2 of KR.
  uint8_t log2_kr;
  // log2 of SR. NOTE(review): the meaning of SR is not visible in this header — confirm.
  uint8_t log2_sr;
};
3752
// An element-wise unary micro-kernel plus the initializer for its params.
// Which `init` member is meaningful depends on the operation this instance
// describes (e.g. f32_elu for xnn_parameters.f32.elu).
struct vunary_parameters {
  xnn_univector_ukernel_function ukernel;
  union {
    xnn_init_f16_f32_cvt_params_fn f16_f32_cvt;
    xnn_init_f16_hswish_params_fn f16_hswish;
    xnn_init_f32_abs_params_fn f32_abs;
    xnn_init_f32_default_params_fn f32_default;
    xnn_init_f32_elu_params_fn f32_elu;
    xnn_init_f32_f16_cvt_params_fn f32_f16_cvt;
    xnn_init_f32_hswish_params_fn f32_hswish;
    xnn_init_f32_lrelu_params_fn f32_lrelu;
    xnn_init_f32_minmax_params_fn f32_minmax;
    xnn_init_f32_neg_params_fn f32_neg;
    xnn_init_f32_qs8_cvt_params_fn f32_qs8_cvt;
    xnn_init_f32_qu8_cvt_params_fn f32_qu8_cvt;
    xnn_init_f32_rnd_params_fn f32_rnd;
    xnn_init_f32_sigmoid_params_fn f32_sigmoid;
    xnn_init_f32_sqrt_params_fn f32_sqrt;
    xnn_init_qs8_f32_cvt_params_fn qs8_f32_cvt;
    xnn_init_qu8_f32_cvt_params_fn qu8_f32_cvt;
    xnn_init_s8_minmax_params_fn s8_minmax;
    xnn_init_u8_minmax_params_fn u8_minmax;
  } init;
  // Number of elements in a tile.
  // For best efficiency, micro-kernel must process a multiple of this number of elements in each call.
  uint8_t element_tile;
};
3780
// The three operand-form variants of an element-wise binary micro-kernel.
// NOTE(review): presumably op = vector (op) vector, opc = vector (op) constant,
// and ropc = constant (op) vector with reversed operand order — confirm against
// the implementations.
struct vbinary_fused_ukernels {
  xnn_vbinary_ukernel_function op_ukernel;
  xnn_vbinary_ukernel_function opc_ukernel;
  xnn_vbinary_ukernel_function ropc_ukernel;
};
3786
// An element-wise binary operation: micro-kernels with min/max clamping and
// without (linear), plus the initializer for their params. Which `init` member
// is meaningful depends on the datatype and operation.
struct vbinary_parameters {
  struct vbinary_fused_ukernels minmax;
  struct vbinary_fused_ukernels linear;
  union {
    xnn_init_f16_minmax_params_fn f16_minmax;
    xnn_init_f32_default_params_fn f32_default;
    xnn_init_f32_minmax_params_fn f32_minmax;
    xnn_init_qs8_addsub_minmax_params_fn qs8_addsub;
    xnn_init_qs8_mul_minmax_params_fn qs8_mul;
    xnn_init_qu8_addsub_minmax_params_fn qu8_addsub;
    xnn_init_qu8_mul_minmax_params_fn qu8_mul;
  } init;
  // Number of elements in a tile.
  // For best efficiency, micro-kernel must process a multiple of this number of elements in each call.
  uint8_t element_tile;
};
3803
// A sparse-matrix x dense-matrix multiplication micro-kernel and its tile sizes.
struct spmm_parameters {
  xnn_spmm_ukernel_function ukernel;
  // Number of M-dimension elements in a tile.
  // Corresponds to a block of pixels in 1x1 Convolution and a block of batch size in Fully Connected operator.
  uint8_t mr;
  // Number of N-dimension elements in a tile.
  // Corresponds to a block of output channels/features in 1x1 Convolution and Fully Connected operator.
  uint8_t nr;
};
3813
// A direct convolution micro-kernel that reads HWC input and writes CHW output,
// plus its tile sizes.
struct conv_hwc2chw_parameters {
  // NOTE(review): presumably handles symmetric (equal left/right) padding — confirm.
  xnn_conv_hwc2chw_ukernel_function ukernel_with_symm_padding;
  // Number of output channels in a tile.
  // This parameter must be passed as is to weight packing function.
  uint8_t output_channel_tile;
  // Number of output height pixels in a tile.
  // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call.
  uint8_t output_height_tile;
  // Number of output width pixels in a tile.
  uint8_t output_width_tile;
};
3825
// A 2D depthwise-convolution micro-kernel operating in CHW layout, plus its tile sizes.
struct dwconv2d_chw_parameters {
  xnn_dwconv2d_chw_ukernel_function ukernel;
  // Number of output width pixels in a tile.
  uint8_t output_width_tile;
  // Number of output height pixels in a tile.
  // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call.
  uint8_t output_height_tile;
};
3834
// A global-average-pooling micro-kernel for CW layout, plus its channel tile.
struct gavgpool_cw_parameters {
  xnn_gavgpool_cw_ukernel_function ukernel;
  // Number of channels in a tile.
  // For best efficiency, micro-kernel must process a multiple of this number of channels in each call.
  uint8_t channel_tile;
};
3841
// Either a unipass or a multipass depthwise-convolution micro-kernel.
// NOTE(review): which member is valid is not recorded in this union; presumably
// it is implied by the tile sizes in struct dwconv_parameters — confirm.
union dwconv_fused_ukernels {
  xnn_dwconv_unipass_ukernel_function unipass;
  xnn_dwconv_multipass_ukernel_function multipass;
};
3846
// A depthwise-convolution configuration: micro-kernels with min/max clamping
// and without (linear), the params initializer, and tile sizes.
struct dwconv_parameters {
  union dwconv_fused_ukernels minmax;
  union dwconv_fused_ukernels linear;
  // Which member is meaningful depends on the datatype (qc8/qs8/qu8/f16/f32).
  union {
    xnn_init_qs8_minmax_params_fn qc8;
    xnn_init_qs8_conv_minmax_params_fn qs8;
    xnn_init_qu8_conv_minmax_params_fn qu8;
    xnn_init_f16_minmax_params_fn f16;
    xnn_init_f32_minmax_params_fn f32;
  } init;
  // Number of channels in a tile.
  uint8_t channel_tile;
  // NOTE(review): presumably the number of kernel taps handled in the first pass,
  // analogous to avgpool_parameters.primary_tile — confirm.
  uint8_t primary_tile;
  // NOTE(review): presumably the taps per additional multipass step,
  // analogous to avgpool_parameters.incremental_tile — confirm.
  uint8_t incremental_tile;
};
3861
// A 2D depth-to-space micro-kernel that converts CHW input to HWC output, plus its tile sizes.
struct depthtospace2d_chw2hwc_parameters {
  xnn_depthtospace2d_chw2hwc_ukernel_function ukernel;
  // Number of output pixels in a tile.
  // For best efficiency, micro-kernel must produce a multiple of this number of pixels in each call.
  uint8_t pixel_tile;
  // Number of channels in a tile.
  // For best efficiency, micro-kernel must process a multiple of this number of channels in each call.
  uint8_t channel_tile;
};
3871
// A global-average-pooling configuration: unipass and multipass micro-kernels,
// params initializer and updater, and tile sizes. Which `init`/`update` member
// is meaningful depends on the datatype (f16/f32/qs8/qu8).
struct gavgpool_parameters {
  xnn_gavgpool_unipass_ukernel_function unipass;
  xnn_gavgpool_multipass_ukernel_function multipass;
  union {
    xnn_init_f16_scaleminmax_params_fn f16;
    xnn_init_f32_scaleminmax_params_fn f32;
    xnn_init_qs8_avgpool_minmax_params_fn qs8;
    xnn_init_qu8_avgpool_minmax_params_fn qu8;
  } init;
  union {
    xnn_update_f16_scaleminmax_params_fn f16;
    xnn_update_f32_scaleminmax_params_fn f32;
    xnn_update_qs8_avgpool_minmax_params_fn qs8;
    xnn_update_qu8_avgpool_minmax_params_fn qu8;
  } update;
  // Number of rows in a tile.
  // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call.
  uint16_t row_tile;
  // Number of channels in a tile.
  // For best efficiency, micro-kernel must process a multiple of this number of channels in each call.
  uint16_t channel_tile;
};
3894
// An average-pooling configuration: unipass and multipass micro-kernels, the
// params initializer (f32 or qu8), and tile sizes.
struct avgpool_parameters {
  xnn_avgpool_unipass_ukernel_function unipass;
  xnn_avgpool_multipass_ukernel_function multipass;
  union {
    xnn_init_f32_scaleminmax_params_fn f32;
    xnn_init_qu8_avgpool_minmax_params_fn qu8;
  } init;
  // Number of rows in a primary tile.
  // Unipass micro-kernel must be called with this number of rows, or fewer.
  // Multipass micro-kernel must be called with more than this number of rows.
  uint8_t primary_tile;
  // Number of rows in an incremental tile.
  // For best efficiency, multipass micro-kernel must process the number of rows in the primary tile plus a multiple
  // of this number of rows in each call. This number has no meaning for the unipass micro-kernel.
  uint8_t incremental_tile;
  // Number of channels in a tile.
  // For best efficiency, micro-kernel must process a multiple of this number of channels in each call.
  uint16_t channel_tile;
};
3914
// A pixelwise average-pooling configuration (presumably "p" = pixelwise; confirm):
// unipass and multipass micro-kernels and tile sizes. No params initializer is stored here.
struct pavgpool_parameters {
  xnn_pavgpool_unipass_ukernel_function unipass;
  xnn_pavgpool_multipass_ukernel_function multipass;
  // Number of rows in a primary tile.
  // Unipass micro-kernel must be called with this number of rows, or fewer.
  // Multipass micro-kernel must be called with more than this number of rows.
  uint8_t primary_tile;
  // Number of rows in an incremental tile.
  // For best efficiency, multipass micro-kernel must process the number of rows in the primary tile plus a multiple
  // of this number of rows in each call. This number has no meaning for the unipass micro-kernel.
  uint8_t incremental_tile;
  // Number of channels in a tile.
  // For best efficiency, micro-kernel must process a multiple of this number of channels in each call.
  uint16_t channel_tile;
};
3930
// An argmax-pooling micro-kernel (unipass `up` or multipass `mp`, sharing
// storage via an anonymous union) and its tile sizes.
struct argmaxpool_parameters {
  union {
    xnn_argmaxpool_unipass_ukernel_function up;
    xnn_argmaxpool_multipass_ukernel_function mp;
  };
  // NOTE(review): presumably the primary tile (rows handled in the first pass),
  // analogous to avgpool_parameters.primary_tile — confirm.
  uint8_t mr;
  // NOTE(review): presumably the incremental tile for the multipass kernel,
  // analogous to avgpool_parameters.incremental_tile — confirm.
  uint8_t qr;
};
3939
// A max-pooling micro-kernel, the params initializer (meaningful member depends
// on the datatype: s8/u8/f32/f16), and tile sizes.
struct maxpool_parameters {
  xnn_maxpool_ukernel_function ukernel;
  union {
    xnn_init_s8_minmax_params_fn s8;
    xnn_init_u8_minmax_params_fn u8;
    xnn_init_f32_minmax_params_fn f32;
    xnn_init_f16_minmax_params_fn f16;
  } init;
  // NOTE(review): presumably the number of pooling-window rows in the first pass — confirm.
  uint8_t mr;
  // NOTE(review): presumably the number of rows per subsequent pass — confirm.
  uint8_t qr;
};
3951
// A 2D bilinear-interpolation micro-kernel plus its tile sizes.
struct ibilinear_parameters {
  xnn_ibilinear_ukernel_function ukernel;
  // Number of output pixels in a tile.
  // For best efficiency, micro-kernel must produce a multiple of this number of pixels in each call.
  uint8_t pixel_tile;
  // Number of channels in a tile.
  // For best efficiency, micro-kernel must process a multiple of this number of channels in each call.
  uint8_t channel_tile;
};
3961
// A 2D bilinear-interpolation micro-kernel for CHW layout plus its tile sizes.
struct ibilinear_chw_parameters {
  xnn_ibilinear_chw_ukernel_function ukernel;
  // Number of output pixels in a tile.
  // For best efficiency, micro-kernel must produce a multiple of this number of pixels in each call.
  uint8_t pixel_tile;
  // Number of channels in a tile.
  // For best efficiency, micro-kernel must process a multiple of this number of channels in each call.
  uint8_t channel_tile;
};
3971
// Zip (interleave) micro-kernels.
// NOTE(review): presumably x2/x3/x4 interleave a fixed 2, 3 or 4 inputs (zipc)
// and xm interleaves a variable number of inputs (zipv) — confirm.
struct zip_parameters {
  xnn_zipc_ukernel_function x2;
  xnn_zipc_ukernel_function x3;
  xnn_zipc_ukernel_function x4;
  xnn_zipv_ukernel_function xm;
};
3978
// A PReLU micro-kernel plus its tile sizes.
struct prelu_parameters {
  xnn_prelu_ukernel_function ukernel;
  // Number of rows in a tile.
  uint16_t row_tile;
  // Number of channels in a tile.
  uint16_t channel_tile;
};
3984
// An f32 "reduce-add and store exp(x - max)" micro-kernel, its params
// initializer, and its element tile.
struct raddstoreexpminusmax_parameters {
  xnn_f32_raddstoreexpminusmax_ukernel_function ukernel;
  xnn_init_f32_expminus_params_fn init;
  // Number of elements in a tile.
  // For best efficiency, micro-kernel must process a multiple of this number of elements in each call.
  uint8_t element_tile;
};
3992
// A fill micro-kernel plus its row tile.
struct fill_parameters {
  xnn_fill_ukernel_function ukernel;
  // Number of rows of inputs processed in one tile.
  // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call.
  uint8_t row_tile;
};
3999
// A pad micro-kernel plus its row tile.
struct pad_parameters {
  xnn_pad_ukernel_function ukernel;
  // Number of rows of inputs processed in one tile.
  // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call.
  uint8_t row_tile;
};
4006
// A per-channel multiply-add micro-kernel (presumably y = x * scale[c] + bias[c],
// inferred from the name — confirm), its params initializer (f16 or f32), and tile sizes.
struct vmulcaddc_parameters {
  xnn_vmulcaddc_ukernel_function ukernel;
  union {
    xnn_init_f16_minmax_params_fn f16;
    xnn_init_f32_minmax_params_fn f32;
  } init;
  // Number of channels in a tile.
  uint8_t channel_tile;
  // Number of rows in a tile.
  uint8_t row_tile;
};
4016
// Capacities of the per-datatype micro-kernel arrays in struct xnn_parameters
// (dwconv[] for each datatype, and f32.argmaxpool[]).
#define XNN_MAX_QC8_DWCONV_UKERNELS 2
#define XNN_MAX_QS8_DWCONV_UKERNELS 2
#define XNN_MAX_QU8_DWCONV_UKERNELS 2
#define XNN_MAX_F16_DWCONV_UKERNELS 3
#define XNN_MAX_F32_DWCONV_UKERNELS 4
#define XNN_MAX_F32_ARGMAXPOOL_UKERNELS 3
4023
// Bit flags recorded in xnn_parameters.init_flags. Each flag is a distinct bit,
// so any combination can be stored at once.

// Indicates that XNNPACK as a whole has initialized.
// This does not guarantee that any particular microkernels are available.
#define XNN_INIT_FLAG_XNNPACK 0x00000001
// Indicates that F32 XNNPACK microkernels are available for use.
#define XNN_INIT_FLAG_F32 0x00000002
// Indicates that X32 XNNPACK microkernels are available for use.
#define XNN_INIT_FLAG_X32 0x00000004
// Indicates that F16 XNNPACK microkernels are available for use.
#define XNN_INIT_FLAG_F16 0x00000008
// Indicates that X16 XNNPACK microkernels are available for use.
#define XNN_INIT_FLAG_X16 0x00000010
// Indicates that QC8 XNNPACK microkernels are available for use.
#define XNN_INIT_FLAG_QC8 0x00000020
// Indicates that QS8 XNNPACK microkernels are available for use.
#define XNN_INIT_FLAG_QS8 0x00000040
// Indicates that QU8 XNNPACK microkernels are available for use.
#define XNN_INIT_FLAG_QU8 0x00000080
// Indicates that S8 XNNPACK microkernels are available for use.
#define XNN_INIT_FLAG_S8 0x00000100
// Indicates that U8 XNNPACK microkernels are available for use.
#define XNN_INIT_FLAG_U8 0x00000200
// Indicates that X8 XNNPACK microkernels are available for use.
#define XNN_INIT_FLAG_X8 0x00000400
// Indicates that XX XNNPACK microkernels are available for use.
#define XNN_INIT_FLAG_XX 0x00000800
// Indicates that VCVT XNNPACK microkernels are available for use.
#define XNN_INIT_FLAG_VCVT 0x00001000
// Indicates that CHW XNNPACK microkernels are optimized for the host platform.
#define XNN_INIT_FLAG_CHW_OPT 0x00002000
4053
// Top-level table of micro-kernels and their configuration, grouped by datatype
// (qc8, qs8, qu8, s8, u8, x8, f16, f32, x32, xx) plus datatype conversions (vcvt).
struct xnn_parameters {
  // Bitwise combination of XNN_INIT_FLAG_* flags
  uint32_t init_flags;
  struct xnn_allocator allocator;
  struct {
    struct gemm_parameters gemm;
    struct dwconv_parameters dwconv[XNN_MAX_QC8_DWCONV_UKERNELS];
  } qc8;
  struct {
    struct gemm_parameters gemm;
    struct dwconv_parameters dwconv[XNN_MAX_QS8_DWCONV_UKERNELS];
    struct gavgpool_parameters gavgpool;
    struct vbinary_parameters vadd;
    struct vbinary_parameters vmul;
  } qs8;
  struct {
    struct gemm_parameters gemm;
    struct dwconv_parameters dwconv[XNN_MAX_QU8_DWCONV_UKERNELS];
    struct avgpool_parameters avgpool;
    struct gavgpool_parameters gavgpool;
    struct vbinary_parameters vadd;
    struct vbinary_parameters vmul;
  } qu8;
  struct {
    struct vunary_parameters clamp;
    // Bilinear interpolation (2D).
    struct ibilinear_parameters ibilinear;
    struct maxpool_parameters maxpool;
  } s8;
  struct {
    struct vunary_parameters clamp;
    // Bilinear interpolation (2D).
    struct ibilinear_parameters ibilinear;
    struct maxpool_parameters maxpool;
    xnn_u8_lut32norm_ukernel_function lut32norm;
    xnn_u8_rmax_ukernel_function rmax;
  } u8;
  struct {
    xnn_x8_lut_ukernel_function lut;
    struct zip_parameters zip;
  } x8;
  struct {
    struct gavgpool_parameters gavgpool;
    struct gemm_parameters gemm;
    // NOTE(review): second GEMM configuration; how it differs from `gemm` is not visible here — confirm.
    struct gemm_parameters gemm2;
    struct dwconv_parameters dwconv[XNN_MAX_F16_DWCONV_UKERNELS];
    struct maxpool_parameters maxpool;
    struct vunary_parameters hswish;
    struct prelu_parameters prelu;
    struct vbinary_parameters vadd;
    struct vbinary_parameters vmul;
    struct vmulcaddc_parameters vmulcaddc;
  } f16;
  struct {
    struct gemm_parameters gemm;
    // NOTE(review): second GEMM configuration; how it differs from `gemm` is not visible here — confirm.
    struct gemm_parameters gemm2;
    struct dwconv_parameters dwconv[XNN_MAX_F32_DWCONV_UKERNELS];
    struct avgpool_parameters avgpool;
    struct pavgpool_parameters pavgpool;
    struct gavgpool_parameters gavgpool;
    struct maxpool_parameters maxpool;
    struct argmaxpool_parameters argmaxpool[XNN_MAX_F32_ARGMAXPOOL_UKERNELS];
    // Bilinear interpolation (2D).
    struct ibilinear_parameters ibilinear;
    struct vunary_parameters abs;
    struct vunary_parameters clamp;
    struct vunary_parameters elu;
    struct vunary_parameters hswish;
    struct vunary_parameters lrelu;
    struct vunary_parameters neg;
    // Bare micro-kernel: unlike the neighbouring entries, no init function or tile size is stored.
    xnn_univector_ukernel_function relu;
    struct vunary_parameters rndne;
    struct vunary_parameters rndz;
    struct vunary_parameters rndu;
    struct vunary_parameters rndd;
    struct vunary_parameters sigmoid;
    struct vunary_parameters sqr;
    struct vunary_parameters sqrt;
    struct prelu_parameters prelu;
    struct vbinary_parameters vadd;
    struct vbinary_parameters vdiv;
    struct vbinary_parameters vmax;
    struct vbinary_parameters vmin;
    struct vbinary_parameters vmul;
    struct vbinary_parameters vsub;
    struct vbinary_parameters vsqrdiff;
    struct vmulcaddc_parameters vmulcaddc;
    struct raddstoreexpminusmax_parameters raddstoreexpminusmax;
    xnn_f32_rmax_ukernel_function rmax;
    // Sparse Matrix-Dense Matrix Multiplication (NR=1 block).
    struct spmm_parameters spmm;
    // Sparse Matrix-Dense Matrix Multiplication (NR=2 block).
    struct spmm_parameters spmm2;
    // Sparse Matrix-Dense Matrix Multiplication (NR=4 block).
    struct spmm_parameters spmm4;
    // Direct 3x3 stride-2 Convolution with 3 input channels and HWC->CHW layout conversion.
    struct conv_hwc2chw_parameters conv_hwc2chw_3x3c3s2;
    // Direct 3x3 stride-1 Depthwise Convolution with padding 1 on left and right in CHW layout.
    struct dwconv2d_chw_parameters dwconv2d_chw_3x3;
    // Direct 3x3 stride-2 Depthwise Convolution with padding 1 on left and right in CHW layout.
    struct dwconv2d_chw_parameters dwconv2d_chw_3x3s2;
    // Direct 5x5 stride-1 Depthwise Convolution with padding 2 on left and right in CHW layout.
    struct dwconv2d_chw_parameters dwconv2d_chw_5x5;
    // Direct 5x5 stride-2 Depthwise Convolution with padding 2 on left and right in CHW layout.
    struct dwconv2d_chw_parameters dwconv2d_chw_5x5s2;
    // Global Average Pooling in CW layout.
    struct gavgpool_cw_parameters gavgpool_cw;
    // Bilinear interpolation (2D) in CHW layout.
    struct ibilinear_chw_parameters ibilinear_chw;
  } f32;
  // Datatype conversion micro-kernels.
  struct {
    struct vunary_parameters f16_to_f32;
    struct vunary_parameters f32_to_f16;
    struct vunary_parameters f32_to_qs8;
    struct vunary_parameters f32_to_qu8;
    struct vunary_parameters qs8_to_f32;
    struct vunary_parameters qu8_to_f32;
  } vcvt;
  struct {
    xnn_unpool_ukernel_function unpool;
    struct zip_parameters zip;
    // Depth To Space 2D with CHW->HWC layout conversion.
    struct depthtospace2d_chw2hwc_parameters depthtospace2d_chw2hwc;
  } x32;
  struct {
    xnn_univector_ukernel_function copy;
    struct fill_parameters fill;
    struct pad_parameters pad;
  } xx;
};
4184
// The single global xnn_parameters table. It is defined in one translation unit
// elsewhere (presumably filled in during library initialization — confirm).
// The C++ branch adds extern "C" so C++ code links against the C symbol name.
#ifdef __cplusplus
extern "C" XNN_INTERNAL struct xnn_parameters xnn_params;
#else
extern XNN_INTERNAL struct xnn_parameters xnn_params;
#endif
4190