• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <stdbool.h>
12 #include <stddef.h>
13 #include <stdint.h>
14 
15 #include <xnnpack.h>
16 #include <xnnpack/common.h>
17 
// Output parameters for f16 micro-kernels (PPMM, GEMM and SPMM f16 typedefs
// take a pointer to this struct). All three values are kept as raw 16-bit
// patterns; NOTE(review): presumably IEEE half-precision bits — confirm.
struct xnn_f16_output_params {
  uint16_t scale, max, min;
};
23 
// max/min output bounds for f32 micro-kernels (used e.g. by the f32 GEMM,
// clamp, maxpool and vadd typedefs). `scalar` must stay the first member so
// that brace initializers target it.
// NOTE(review): SIMD member field order/alignment may be relied on by
// hand-written kernels; do not reorder.
union xnn_f32_output_params {
  // Portable layout: one float per bound.
  struct {
    float max, min;
  } scalar;
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // SSE layout: each bound stored as a 16-byte-aligned 4-float vector.
  struct {
    XNN_ALIGN(16) float max[4];
    XNN_ALIGN(16) float min[4];
  } sse;
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
};
36 
// Parameters for f32 depthwise-convolution micro-kernels in SpCHW layout
// (see xnn_f32_dwconv_spchw_ukernel_function). Note that the NEON variant
// stores min before max while the scalar and SSE variants store max first.
// NOTE(review): SIMD field order/alignment may be read by offset in kernels.
union xnn_f32_spchw_params {
  struct {
    float max, min;
  } scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  struct {
    float min, max;
    XNN_ALIGN(16) uint32_t mask_even[4];  // consumed by stride-2 kernels
    XNN_ALIGN(16) uint32_t mask_odd[4];   // consumed by stride-2 kernels
    XNN_ALIGN(16) uint32_t mask[4];       // consumed by stride-1 kernels
  } neon;
#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  struct {
    XNN_ALIGN(16) float max[4];
    XNN_ALIGN(16) float min[4];
    XNN_ALIGN(16) uint32_t mask_even[4];  // consumed by stride-2 kernels
    XNN_ALIGN(16) uint32_t mask_odd[4];   // consumed by stride-2 kernels
    XNN_ALIGN(16) uint32_t mask[4];       // consumed by stride-1 kernels
  } sse;
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
};
61 
// max/min output bounds for u8 micro-kernels (u8 maxpool and clamp typedefs).
// The scalar variant widens the bounds to int32_t; NEON keeps single bytes and
// SSE2 replicates each bound into a 16-byte-aligned 16-byte vector.
union xnn_u8_output_params {
  struct {
    int32_t max, min;
  } scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  struct {
    uint8_t max, min;
  } neon;
#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  struct {
    XNN_ALIGN(16) uint8_t max[16];
    XNN_ALIGN(16) uint8_t min[16];
  } sse2;
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
};
80 
// Parameters for f32 average-pooling micro-kernels (avgpool and gavgpool
// up/mp typedefs). Scalar stores output_min before output_max; both SIMD
// variants store output_max first.
// NOTE(review): SIMD field order/alignment may be read by offset in kernels.
union xnn_f32_avgpool_params {
  struct {
    float multiplier, output_min, output_max;
  } scalar;
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  struct {
    XNN_ALIGN(16) float multiplier[4];
    XNN_ALIGN(16) float output_max[4];
    XNN_ALIGN(16) float output_min[4];
  } sse2;
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // Each scalar individually aligned to 16 bytes.
  struct {
    XNN_ALIGN(16) float multiplier;
    XNN_ALIGN(16) float output_max;
    XNN_ALIGN(16) float output_min;
  } neon;
#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
};
102 
// Parameters for f32 global-average-pooling micro-kernels in SpCHW layout
// (xnn_f32_gavgpool_spchw_ukernel_function takes a pointer to this union).
// Scalar stores output_min before output_max; both SIMD variants store
// output_max first and add a 4-lane `mask`.
// NOTE(review): the mask's exact role (presumably masking a partial last
// vector) is not visible here — confirm against the kernels.
union xnn_f32_gavgpool_params {
  // Portable layout.
  struct {
    float multiplier;
    float output_min;
    float output_max;
  } scalar;
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  struct {
    XNN_ALIGN(16) float multiplier[4];
    XNN_ALIGN(16) float output_max[4];
    XNN_ALIGN(16) float output_min[4];
    XNN_ALIGN(16) uint32_t mask[4];
  } sse;
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // Each scalar individually aligned to 16 bytes.
  struct {
    XNN_ALIGN(16) float multiplier;
    XNN_ALIGN(16) float output_max;
    XNN_ALIGN(16) float output_min;
    XNN_ALIGN(16) uint32_t mask[4];
  } neon;
#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
};
126 
// Constants for f32 hard-swish micro-kernels
// (xnn_f32_hswish_ukernel_function takes a pointer to this union).
// Field names indicate the values 1/6, 1/2 and 1; callers fill them in.
union xnn_f32_hswish_params {
  struct {
    float sixth, half, one;
  } scalar;
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  struct {
    XNN_ALIGN(16) float sixth[4];
    XNN_ALIGN(16) float half[4];
    XNN_ALIGN(16) float one[4];
  } sse;
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
};
141 
// Quantization parameters for uint8 GEMM/IGEMM/depthwise micro-kernels
// (the q8 gemm, igemm and dwconv_up typedefs take a pointer to this union).
// NOTE(review): SIMD field order/alignment may be read by offset in kernels.
union xnn_q8_gemm_params {
  struct {
    int32_t kernel_zero_point, input_zero_point;
    int32_t multiplier, remainder_mask, remainder_threshold;
    uint32_t shift;
    int32_t output_min_less_zero_point, output_max_less_zero_point;
    int32_t output_zero_point;
  } scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  struct {
    int16_t kernel_zero_point, input_zero_point;
    int32_t multiplier, right_shift;
    int16_t output_zero_point;
    uint8_t output_max, output_min;
  } neon;
#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  // Every field is a full 16-byte-aligned vector.
  struct {
    XNN_ALIGN(16) int16_t kernel_zero_point[8];
    XNN_ALIGN(16) int16_t input_zero_point[8];
    XNN_ALIGN(16) uint32_t multiplier[4];
    XNN_ALIGN(16) uint64_t rounding[2];
    XNN_ALIGN(16) int32_t remainder_mask[4];
    XNN_ALIGN(16) int32_t remainder_threshold[4];
    XNN_ALIGN(16) uint64_t shift[2];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_max[16];
    XNN_ALIGN(16) uint8_t output_min[16];
  } sse2;
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
};
180 
// Parameters for uint8 elementwise-add micro-kernels
// (xnn_q8_vadd_ukernel_function takes a pointer to this union).
// NOTE(review): SIMD field order/alignment may be read by offset in kernels.
union xnn_q8_add_params {
  struct {
    int32_t zero_point_product;
    uint32_t a_multiplier, b_multiplier, shift;
    int32_t remainder_mask, remainder_threshold;
    int32_t y_zero_point, y_max, y_min;
  } scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  struct {
    uint8_t a_zero_point, b_zero_point;
    int16_t y_zero_point;
    int32_t a_multiplier, b_multiplier, right_shift;
    uint8_t y_max, y_min;
  } neon;
#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  struct {
    XNN_ALIGN(16) int32_t zero_point_product[4];
    XNN_ALIGN(16) uint16_t a_multiplier_lo[8];
    XNN_ALIGN(16) uint16_t a_multiplier_hi[8];
    XNN_ALIGN(16) uint16_t b_multiplier_lo[8];
    XNN_ALIGN(16) uint16_t b_multiplier_hi[8];
    XNN_ALIGN(16) int32_t remainder_mask[4];
    XNN_ALIGN(16) int32_t remainder_threshold[4];
    XNN_ALIGN(16) int16_t y_zero_point[8];
    XNN_ALIGN(16) uint8_t y_max[16];
    XNN_ALIGN(16) uint8_t y_min[16];
    // Trailing scalars, not vector-aligned.
    uint32_t shift, a_multiplier, b_multiplier;
  } sse2;
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
};
223 
// Parameters for uint8 average-pooling micro-kernels (q8 avgpool and
// gavgpool up/mp typedefs take a pointer to this union).
// NOTE(review): SIMD field order/alignment may be read by offset in kernels.
union xnn_q8_avgpool_params {
  struct {
    int32_t bias, multiplier;
    int64_t rounding;
    uint32_t right_shift;
    int32_t output_min_less_zero_point, output_max_less_zero_point;
    int32_t output_zero_point;
  } scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  struct {
    int32_t bias, multiplier;
    int64_t left_shift;
    int16_t output_zero_point;
    uint8_t output_max, output_min;
  } neon;
#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  struct {
    XNN_ALIGN(16) int32_t bias[4];
    XNN_ALIGN(16) uint32_t multiplier[4];
    XNN_ALIGN(16) uint64_t rounding[2];
    XNN_ALIGN(16) uint64_t right_shift[2];
    XNN_ALIGN(16) int16_t output_zero_point[8];
    XNN_ALIGN(16) uint8_t output_max[16];
    XNN_ALIGN(16) uint8_t output_min[16];
  } sse2;
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
};
256 
257 union xnn_fp32_requantization_params {
258   struct {
259     float scale;
260     float min_less_zero_point;
261     float max_less_zero_point;
262     float magic;
263     int32_t magic_less_zero_point;
264   } scalar;
265   struct {
266     float scale;
267     float max;
268     float min;
269     float magic;
270     int32_t magic_less_zero_point;
271   } neon;
272   struct {
273     float scale;
274     int16_t zero_point;
275     uint8_t max;
276     uint8_t min;
277   } neonv8;
278   struct {
279     XNN_ALIGN(16) float scale[4];
280     XNN_ALIGN(16) int16_t zero_point[8];
281     XNN_ALIGN(16) uint8_t max[16];
282     XNN_ALIGN(16) uint8_t min[16];
283   } sse2;
284   struct {
285     XNN_ALIGN(16) float scale[4];
286     XNN_ALIGN(16) float min_less_zero_point[4];
287     XNN_ALIGN(16) float max_less_zero_point[4];
288     XNN_ALIGN(16) float magic[4];
289     XNN_ALIGN(16) int32_t magic_less_zero_point[4];
290   } psimd;
291 };
292 
293 union xnn_precise_requantization_params {
294   struct {
295     uint32_t multiplier;
296     uint32_t rounding_lo;
297     uint32_t rounding_hi;
298     uint32_t shift_less_32;
299     int32_t min_less_zero_point;
300     int32_t max_less_zero_point;
301     int32_t zero_point;
302   } scalar;
303   struct {
304     int32_t multiplier;
305     int32_t right_shift;
306     int16_t zero_point;
307     uint8_t max;
308     uint8_t min;
309   } neon;
310   struct {
311     XNN_ALIGN(16) uint32_t multiplier[4];
312     XNN_ALIGN(16) uint64_t rounding[2];
313     XNN_ALIGN(16) uint32_t shift[4];
314     XNN_ALIGN(16) int16_t zero_point[8];
315     XNN_ALIGN(16) uint8_t max[16];
316     XNN_ALIGN(16) uint8_t min[16];
317   } sse2;
318 };
319 
// Parameters for the "q31" requantization scheme; the NEON and SSE2 members
// exist only on their respective architectures.
union xnn_q31_requantization_params {
  struct {
    int32_t multiplier, remainder_mask, remainder_threshold;
    uint32_t shift;
    int32_t min_less_zero_point, max_less_zero_point, zero_point;
  } scalar;
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  struct {
    int32_t multiplier, right_shift;
    int16_t zero_point;
    uint8_t max, min;
  } neon;
#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  struct {
    XNN_ALIGN(16) uint32_t multiplier[4];
    XNN_ALIGN(16) uint64_t rounding[2];
    XNN_ALIGN(16) int32_t remainder_mask[4];
    XNN_ALIGN(16) int32_t remainder_threshold[4];
    XNN_ALIGN(16) uint64_t shift[2];
    XNN_ALIGN(16) int16_t zero_point[8];
    XNN_ALIGN(16) uint8_t max[16];
    XNN_ALIGN(16) uint8_t min[16];
  } sse2;
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
};
352 
353 union xnn_requantization_params {
354   union xnn_precise_requantization_params precise;
355   union xnn_fp32_requantization_params fp32;
356   union xnn_q31_requantization_params q31;
357 };
358 
// PPMM micro-kernel signatures. The generic form uses untyped element
// pointers and an opaque params pointer; the typed forms change only the
// element pointer types and the params type.
typedef void (*xnn_ppmm_ukernel_function)(
    size_t mr, size_t nc, size_t kc,
    const void* a, const void* w, void* c,
    size_t cm_stride, size_t cn_stride,
    const void* params);

// f32 variant: float operands, f32 output params.
typedef void (*xnn_f32_ppmm_ukernel_function)(
    size_t mr, size_t nc, size_t kc,
    const float* a, const float* w, float* c,
    size_t cm_stride, size_t cn_stride,
    const union xnn_f32_output_params* params);

// f16 variant: operands passed untyped, f16 output params.
typedef void (*xnn_f16_ppmm_ukernel_function)(
    size_t mr, size_t nc, size_t kc,
    const void* a, const void* w, void* c,
    size_t cm_stride, size_t cn_stride,
    const struct xnn_f16_output_params* params);
391 
// GEMM micro-kernel signatures. Common shape:
// (mr, nr, k, a, a_stride, w, c, cm_stride, cn_stride, [acc,] params).
// The generic form is type-erased; typed forms differ in pointer/params types.
typedef void (*xnn_gemm_ukernel_function)(
    size_t mr, size_t nr, size_t k,
    const void* a, size_t a_stride,
    const void* w,
    void* c, size_t cm_stride, size_t cn_stride,
    const void* params);

// f32 GEMM.
typedef void (*xnn_f32_gemm_ukernel_function)(
    size_t mr, size_t nr, size_t k,
    const float* a, size_t a_stride,
    const float* w,
    float* c, size_t cm_stride, size_t cn_stride,
    const union xnn_f32_output_params* params);

// f32 GEMM with an extra `acc` input placed before params.
// NOTE(review): `acc` presumably supplies initial accumulator values — confirm.
typedef void (*xnn_f32_gemminc_ukernel_function)(
    size_t mr, size_t nr, size_t k,
    const float* a, size_t a_stride,
    const float* w,
    float* c, size_t cm_stride, size_t cn_stride,
    const float* acc,
    const union xnn_f32_output_params* params);

// f16 GEMM: element pointers stay untyped.
typedef void (*xnn_f16_gemm_ukernel_function)(
    size_t mr, size_t nr, size_t k,
    const void* a, size_t a_stride,
    const void* w,
    void* c, size_t cm_stride, size_t cn_stride,
    const struct xnn_f16_output_params* params);

// Quantized uint8 GEMM: uint8 input/output, packed weights untyped.
typedef void (*xnn_q8_gemm_ukernel_function)(
    size_t mr, size_t nr, size_t k,
    const uint8_t* a, size_t a_stride,
    const void* w,
    uint8_t* c, size_t cm_stride, size_t cn_stride,
    const union xnn_q8_gemm_params* params);
452 
// Indirect-GEMM micro-kernel signatures: `a` is an array of input pointers
// rather than a single strided matrix.
// NOTE(review): a_offset and zero presumably adjust/replace entries of the
// indirection array (padding) — confirm against kernel sources.
typedef void (*xnn_igemm_ukernel_function)(
    size_t mr, size_t nr, size_t kc, size_t ks,
    const void** a, const void* w,
    void* c, size_t cm_stride, size_t cn_stride,
    size_t a_offset, const void* zero,
    const void* params);

// f32 IGEMM.
typedef void (*xnn_f32_igemm_ukernel_function)(
    size_t mr, size_t nr, size_t kc, size_t ks,
    const float** a, const float* w,
    float* c, size_t cm_stride, size_t cn_stride,
    size_t a_offset, const float* zero,
    const union xnn_f32_output_params* params);

// Quantized uint8 IGEMM.
typedef void (*xnn_q8_igemm_ukernel_function)(
    size_t mr, size_t nr, size_t kc, size_t ks,
    const uint8_t** a, const void* w,
    uint8_t* c, size_t cm_stride, size_t cn_stride,
    size_t a_offset, const uint8_t* zero,
    const union xnn_q8_gemm_params* params);
494 
// Direct-convolution micro-kernel signatures reading HWC input.
// The HWC-output and SpCHW-output families differ only in their 12th
// parameter: output_width_stride vs. output_channel_stride.
typedef void (*xnn_conv_hwc_ukernel_function)(
    size_t input_height, size_t input_width,
    size_t output_y_start, size_t output_y_end,
    const void* input, const void* zero, const void* weights, void* output,
    size_t input_padding_top, size_t output_channels,
    size_t output_height_stride, size_t output_width_stride,
    const void* params);

typedef void (*xnn_f32_conv_hwc_ukernel_function)(
    size_t input_height, size_t input_width,
    size_t output_y_start, size_t output_y_end,
    const float* input, const float* zero, const float* weights, float* output,
    size_t input_padding_top, size_t output_channels,
    size_t output_height_stride, size_t output_width_stride,
    const union xnn_f32_output_params* params);

typedef void (*xnn_conv_hwc2spchw_ukernel_function)(
    size_t input_height, size_t input_width,
    size_t output_y_start, size_t output_y_end,
    const void* input, const void* zero, const void* weights, void* output,
    size_t input_padding_top, size_t output_channels,
    size_t output_height_stride, size_t output_channel_stride,
    const void* params);

typedef void (*xnn_f32_conv_hwc2spchw_ukernel_function)(
    size_t input_height, size_t input_width,
    size_t output_y_start, size_t output_y_end,
    const float* input, const float* zero, const float* weights, float* output,
    size_t input_padding_top, size_t output_channels,
    size_t output_height_stride, size_t output_channel_stride,
    const union xnn_f32_output_params* params);
554 
// Sparse-matrix × dense-matrix micro-kernel signatures. Dimensions are
// uint32_t here (unlike the size_t used by GEMM); dmap is signed, nmap unsigned.
typedef void (*xnn_spmm_ukernel_function)(
    uint32_t m, uint32_t n,
    const void* a, const void* w,
    const int32_t* dmap, const uint32_t* nmap,
    void* c,
    const void* params);

typedef void (*xnn_f16_spmm_ukernel_function)(
    uint32_t m, uint32_t n,
    const void* a, const void* w,
    const int32_t* dmap, const uint32_t* nmap,
    void* c,
    const struct xnn_f16_output_params* params);

typedef void (*xnn_f32_spmm_ukernel_function)(
    uint32_t m, uint32_t n,
    const float* a, const float* w,
    const int32_t* dmap, const uint32_t* nmap,
    float* c,
    const union xnn_f32_output_params* params);
584 
// Input-packing micro-kernel signatures: read m×k elements from strided x,
// write packed output to y. No params argument.
typedef void (*xnn_packx_ukernel_function)(
    size_t m, size_t k, const void* x, size_t x_stride, void* y);

// 32-bit element variant.
typedef void (*xnn_x32_packx_ukernel_function)(
    size_t m, size_t k, const uint32_t* x, size_t x_stride, uint32_t* y);
598 
// Padding micro-kernel signature. `c` is a 32-bit value.
// NOTE(review): l/r presumably left/right padding amounts and c the fill
// pattern — confirm.
typedef void (*xnn_pad_ukernel_function)(
    size_t m, size_t n, size_t l, size_t r, uint32_t c,
    const void* x, size_t x_stride,
    void* y, size_t y_stride);

// Unpooling micro-kernel signatures: output is an array of output pointers,
// index supplies uint32_t indices.
typedef void (*xnn_unpool_ukernel_function)(
    size_t p, size_t c, uint32_t f,
    const void* input, const uint32_t* index, void** output);

typedef void (*xnn_x32_unpool_ukernel_function)(
    size_t p, size_t c, uint32_t f,
    const uint32_t* input, const uint32_t* index, uint32_t** output);
625 
// Interleave ("zip") micro-kernel signatures. The zipc family takes only n;
// the zipv family takes an extra m. Generic, 8-bit and 32-bit forms.
typedef void (*xnn_zipc_ukernel_function)(size_t n, const void* x, void* y);
typedef void (*xnn_x8_zipc_ukernel_function)(size_t n, const uint8_t* x, uint8_t* y);
typedef void (*xnn_x32_zipc_ukernel_function)(size_t n, const uint32_t* x, uint32_t* y);

typedef void (*xnn_zipv_ukernel_function)(size_t n, size_t m, const void* x, void* y);
typedef void (*xnn_x8_zipv_ukernel_function)(size_t n, size_t m, const uint8_t* x, uint8_t* y);
typedef void (*xnn_x32_zipv_ukernel_function)(size_t n, size_t m, const uint32_t* x, uint32_t* y);

// 8-bit table-lookup micro-kernel: x is mapped through table t into y.
typedef void (*xnn_x8_lut_ukernel_function)(
    size_t n, const uint8_t* x, const uint8_t* t, uint8_t* y);
664 
// Depthwise-convolution micro-kernel signatures for SpCHW layout.
// Strides are passed explicitly for input/output tuples and rows.
typedef void (*xnn_dwconv_spchw_ukernel_function)(
    size_t output_height, size_t input_width,
    const void* input, const void* weights, void* output,
    size_t input_tuple_stride, size_t output_tuple_stride,
    size_t input_height_stride, size_t output_height_stride,
    const void* params);

// f32 variant; params use the spchw-specific union.
typedef void (*xnn_f32_dwconv_spchw_ukernel_function)(
    size_t output_height, size_t input_width,
    const float* input, const float* weights, float* output,
    size_t input_tuple_stride, size_t output_tuple_stride,
    size_t input_height_stride, size_t output_height_stride,
    const union xnn_f32_spchw_params* params);
688 
// Depthwise-convolution micro-kernel signatures with an indirection input
// (array of input pointers). The "mp" form adds a scratch `buffer` argument
// before `output`.
typedef void (*xnn_dwconv_up_ukernel_function)(
    size_t channels, size_t output_width,
    const void** input, const void* weights, void* output,
    size_t input_stride, size_t output_increment,
    const void* params);

typedef void (*xnn_f32_dwconv_up_ukernel_function)(
    size_t channels, size_t output_width,
    const float** input, const float* weights, float* output,
    size_t input_stride, size_t output_increment,
    const union xnn_f32_output_params* params);

typedef void (*xnn_q8_dwconv_up_ukernel_function)(
    size_t channels, size_t output_width,
    const uint8_t** input, const void* weights, uint8_t* output,
    size_t input_stride, size_t output_increment,
    const union xnn_q8_gemm_params* params);

typedef void (*xnn_dwconv_mp_ukernel_function)(
    size_t channels, size_t output_width,
    const void** input, const void* weights, void* buffer, void* output,
    size_t input_stride, size_t output_increment,
    const void* params);
729 
// Bilinear-interpolation micro-kernel signatures (no params argument).
// Generic, type-erased form.
typedef void (*xnn_bilinear_ukernel_function)(
    size_t output_pixels, size_t channels,
    const void** input, size_t input_offset,
    const void* weights, void* output,
    size_t output_increment);

// f32 form.
typedef void (*xnn_f32_bilinear_ukernel_function)(
    size_t output_pixels, size_t channels,
    const float** input, size_t input_offset,
    const float* weights, float* output,
    size_t output_increment);
747 
// Global-average-pooling micro-kernel signatures.
// "up" forms: (m, n, x, x_stride, zero, y, params).
// "mp" forms add a scratch `buffer` before y (int32_t for q8).
// The spchw forms take (elements, channels, input, output, params).
typedef void (*xnn_gavgpool_up_ukernel_function)(
    size_t m, size_t n, const void* x, size_t x_stride,
    const void* zero, void* y, const void* params);

typedef void (*xnn_f32_gavgpool_up_ukernel_function)(
    size_t m, size_t n, const float* x, size_t x_stride,
    const float* zero, float* y, const union xnn_f32_avgpool_params* params);

// NOTE(review): unlike other generic typedefs, this one keeps float element
// pointers; only params is type-erased.
typedef void (*xnn_gavgpool_spchw_ukernel_function)(
    size_t elements, size_t channels,
    const float* input, float* output, const void* params);

typedef void (*xnn_f32_gavgpool_spchw_ukernel_function)(
    size_t elements, size_t channels,
    const float* input, float* output,
    const union xnn_f32_gavgpool_params* params);

typedef void (*xnn_q8_gavgpool_up_ukernel_function)(
    size_t m, size_t n, const uint8_t* x, size_t x_stride,
    const uint8_t* zero, uint8_t* y, const union xnn_q8_avgpool_params* params);

typedef void (*xnn_gavgpool_mp_ukernel_function)(
    size_t m, size_t n, const void* x, size_t x_stride,
    const void* zero, void* buffer, void* y, const void* params);

typedef void (*xnn_f32_gavgpool_mp_ukernel_function)(
    size_t m, size_t n, const float* x, size_t x_stride,
    const float* zero, float* buffer, float* y,
    const union xnn_f32_avgpool_params* params);

typedef void (*xnn_q8_gavgpool_mp_ukernel_function)(
    size_t m, size_t n, const uint8_t* x, size_t x_stride,
    const uint8_t* zero, int32_t* buffer, uint8_t* y,
    const union xnn_q8_avgpool_params* params);
818 
// Average-pooling micro-kernel signatures with an indirection input `x`.
// "mp" forms add a scratch `buffer` before y (int32_t for q8).
typedef void (*xnn_avgpool_up_ukernel_function)(
    size_t n, size_t ks, size_t kc,
    const void** x, const void* zero, void* y,
    size_t x_increment, size_t y_increment,
    const void* params);

typedef void (*xnn_f32_avgpool_up_ukernel_function)(
    size_t n, size_t ks, size_t kc,
    const float** x, const float* zero, float* y,
    size_t x_increment, size_t y_increment,
    const union xnn_f32_avgpool_params* params);

typedef void (*xnn_q8_avgpool_up_ukernel_function)(
    size_t n, size_t ks, size_t kc,
    const uint8_t** x, const uint8_t* zero, uint8_t* y,
    size_t x_increment, size_t y_increment,
    const union xnn_q8_avgpool_params* params);

typedef void (*xnn_avgpool_mp_ukernel_function)(
    size_t n, size_t ks, size_t kc,
    const void** x, const void* zero, void* buffer, void* y,
    size_t x_increment, size_t y_increment,
    const void* params);

typedef void (*xnn_f32_avgpool_mp_ukernel_function)(
    size_t n, size_t ks, size_t kc,
    const float** x, const float* zero, float* buffer, float* y,
    size_t x_increment, size_t y_increment,
    const union xnn_f32_avgpool_params* params);

typedef void (*xnn_q8_avgpool_mp_ukernel_function)(
    size_t n, size_t ks, size_t kc,
    const uint8_t** x, const uint8_t* zero, int32_t* buffer, uint8_t* y,
    size_t x_increment, size_t y_increment,
    const union xnn_q8_avgpool_params* params);
887 
// Average-pooling signatures that take an explicit `multiplier` pointer
// after `zero` ("pavgpool"). f32 forms use the plain output params union.
typedef void (*xnn_pavgpool_up_ukernel_function)(
    size_t n, size_t ks, size_t kc,
    const void** x, const void* zero, const void* multiplier, void* y,
    size_t x_increment, size_t y_increment,
    const void* params);

typedef void (*xnn_f32_pavgpool_up_ukernel_function)(
    size_t n, size_t ks, size_t kc,
    const float** x, const float* zero, const float* multiplier, float* y,
    size_t x_increment, size_t y_increment,
    const union xnn_f32_output_params* params);

typedef void (*xnn_pavgpool_mp_ukernel_function)(
    size_t n, size_t ks, size_t kc,
    const void** x, const void* zero, const void* multiplier,
    void* buffer, void* y,
    size_t x_increment, size_t y_increment,
    const void* params);

typedef void (*xnn_f32_pavgpool_mp_ukernel_function)(
    size_t n, size_t ks, size_t kc,
    const float** x, const float* zero, const float* multiplier,
    float* buffer, float* y,
    size_t x_increment, size_t y_increment,
    const union xnn_f32_output_params* params);
937 
// Max-pooling micro-kernel signatures with an indirection input.
typedef void (*xnn_maxpool_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const void** input, size_t input_offset, void* output,
    size_t input_increment, size_t output_increment,
    const void* params);

typedef void (*xnn_f32_maxpool_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const float** input, size_t input_offset, float* output,
    size_t input_increment, size_t output_increment,
    const union xnn_f32_output_params* params);

typedef void (*xnn_u8_maxpool_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const uint8_t** input, size_t input_offset, uint8_t* output,
    size_t input_increment, size_t output_increment,
    const union xnn_u8_output_params* params);
970 
// Arg-max-pooling micro-kernel signatures: like max-pooling but also write
// uint32_t indices. "mp" forms add accumulation_buffer and index_buffer
// scratch arguments before output.
typedef void (*xnn_argmaxpool_up_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const void** input, size_t input_offset,
    void* output, uint32_t* index,
    size_t input_increment, size_t output_increment,
    const void* params);

typedef void (*xnn_f32_argmaxpool_up_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const float** input, size_t input_offset,
    float* output, uint32_t* index,
    size_t input_increment, size_t output_increment,
    const union xnn_f32_output_params* params);

typedef void (*xnn_argmaxpool_mp_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const void** input, size_t input_offset,
    void* accumulation_buffer, uint32_t* index_buffer,
    void* output, uint32_t* index,
    size_t input_increment, size_t output_increment,
    const void* params);

typedef void (*xnn_f32_argmaxpool_mp_ukernel_function)(
    size_t output_pixels, size_t kernel_elements, size_t channels,
    const float** input, size_t input_offset,
    float* accumulation_buffer, uint32_t* index_buffer,
    float* output, uint32_t* index,
    size_t input_increment, size_t output_increment,
    const union xnn_f32_output_params* params);
1022 
// Single-input elementwise micro-kernel signatures: (n, x, y, params).
typedef void (*xnn_univector_ukernel_function)(
    size_t n, const void* x, void* y, const void* params);

typedef void (*xnn_f32_clamp_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_output_params* params);

typedef void (*xnn_u8_clamp_ukernel_function)(
    size_t n, const uint8_t* x, uint8_t* y,
    const union xnn_u8_output_params* params);

typedef void (*xnn_f32_hswish_ukernel_function)(
    size_t n, const float* x, float* y,
    const union xnn_f32_hswish_params* params);
1046 
// Reduce-max micro-kernel signatures: (n, x, y), no params argument.
typedef void (*xnn_rmax_ukernel_function)(size_t n, const void* x, void* y);
typedef void (*xnn_u8_rmax_ukernel_function)(size_t n, const uint8_t* x, uint8_t* y);
typedef void (*xnn_f32_rmax_ukernel_function)(size_t n, const float* x, float* y);

// uint8 lookup through a 32-bit table `t`.
// NOTE(review): "norm" suggests the output is normalized — confirm.
typedef void (*xnn_u8_lut32norm_ukernel_function)(
    size_t n, const uint8_t* x, const uint32_t* t, uint8_t* y);
1067 
// Elementwise binary (a, b -> y) and unary (x -> y) micro-kernel signatures.
// xnn_vadd_ukernel_function and xnn_vbinary_ukernel_function have identical
// signatures.
typedef void (*xnn_vadd_ukernel_function)(
    size_t n, const void* a, const void* b, void* y, const void* params);

typedef void (*xnn_f32_vadd_ukernel_function)(
    size_t n, const float* a, const float* b, float* y,
    const union xnn_f32_output_params* params);

typedef void (*xnn_q8_vadd_ukernel_function)(
    size_t n, const uint8_t* a, const uint8_t* b, uint8_t* y,
    const union xnn_q8_add_params* params);

typedef void (*xnn_vbinary_ukernel_function)(
    size_t n, const void* a, const void* b, void* y, const void* params);

typedef void (*xnn_f32_vbinary_ukernel_function)(
    size_t n, const float* a, const float* b, float* y,
    const union xnn_f32_output_params* params);

typedef void (*xnn_vunary_ukernel_function)(
    size_t n, const void* x, void* y, const void* params);

// NOTE(review): the f32 unary form keeps params as an opaque pointer.
typedef void (*xnn_f32_vunary_ukernel_function)(
    size_t n, const float* x, float* y, const void* params);
1114 
// Row-strided micro-kernel signatures that combine input x with weights w:
// (rows, columns, x, x_stride, w, y, y_stride, params).
typedef void (*xnn_vmulcaddc_ukernel_function)(
    size_t m, size_t c,
    const void* x, size_t x_stride, const void* w,
    void* y, size_t y_stride,
    const void* params);

typedef void (*xnn_f32_vmulcaddc_ukernel_function)(
    size_t m, size_t c,
    const float* x, size_t x_stride, const float* w,
    float* y, size_t y_stride,
    const union xnn_f32_output_params* params);

typedef void (*xnn_prelu_ukernel_function)(
    size_t mr, size_t n,
    const void* x, size_t x_stride, const void* w,
    void* y, size_t y_stride,
    const void* params);

typedef void (*xnn_f32_prelu_ukernel_function)(
    size_t mr, size_t n,
    const float* x, size_t x_stride, const float* w,
    float* y, size_t y_stride,
    const union xnn_f32_output_params* params);
1154 
// Signatures of the f32 "exp minus max" micro-kernel family. All of them
// read a const input array and receive the value to subtract (max) by value.
// The name prefixes suggest r = reduce, add = accumulate into *sum,
// store = also write to output, vscale = multiply by scale — inferred from
// naming, TODO confirm against the kernel sources.

// Reduction only: the result is written through sum.
typedef void (*xnn_f32_raddexpminusmax_ukernel_function)(
    size_t n, const float* input,
    float* sum,
    float max);

// Reduction plus an element-wise output array.
typedef void (*xnn_f32_raddstoreexpminusmax_ukernel_function)(
    size_t n, const float* input,
    float* output, float* sum,
    float max);

// Element-wise output with an extra scale factor, no reduction pointer.
typedef void (*xnn_f32_vscaleexpminusmax_ukernel_function)(
    size_t n, const float* input,
    float* output,
    float max, float scale);
1174 
// Signature of an f32 "vscale" micro-kernel: reads the const array x, writes
// y, and takes the scalar c by value (presumably the scale factor).
typedef void (*xnn_f32_vscale_ukernel_function)(
    size_t n, const float* x, float* y, float c);
1180 
// Reduce-add of exponentials kept in extended ("mantissa" + "exponent")
// form. Reads the const input array and writes the result through sum.
typedef void (*xnn_f32_raddextexp_ukernel_function)(
    size_t n, const float* input, float* sum);

// Vector scale of exponentials in extended ("mantissa" + "exponent") form.
// The scale factor is passed split into its mantissa and exponent parts.
typedef void (*xnn_f32_vscaleextexp_ukernel_function)(
    size_t n, const float* input, float* output,
    float scale_mantissa, float scale_exponent);
1194 
1195 
1196 struct gemm_parameters {
1197   xnn_gemm_ukernel_function gemm;
1198   xnn_igemm_ukernel_function igemm;
1199   // Optional GEMM and IGEMM micro-kernels with MR=1 and the same NR and KR parameters.
1200   xnn_gemm_ukernel_function gemm1;
1201   xnn_igemm_ukernel_function igemm1;
1202   uint8_t mr;
1203   uint8_t nr;
1204   uint8_t log2_kr;
1205   uint8_t log2_sr;
1206 };
1207 
1208 struct vbinary_parameters {
1209   xnn_vbinary_ukernel_function op_ukernel;
1210   xnn_vbinary_ukernel_function opc_ukernel;
1211   xnn_vbinary_ukernel_function ropc_ukernel;
1212   // Number of elements in a tile.
1213   // For best efficiency, micro-kernel must process a multiple of this number of elements in each call.
1214   uint8_t element_tile;
1215 };
1216 
1217 struct spmm_parameters {
1218   xnn_spmm_ukernel_function ukernel;
1219   // Number of M-dimension elements in a tile.
1220   // Corresponds to a block of pixels in 1x1 Convolution and a block of batch size in Fully Connected operator.
1221   uint8_t mr;
1222   // Number of N-dimension elements in a tile.
1223   // Corresponds to a block of output channels/features in 1x1 Convolution and Fully Connected operator.
1224   uint8_t nr;
1225 };
1226 
1227 struct hwc2spchw_dconv_parameters {
1228   xnn_conv_hwc2spchw_ukernel_function ukernel_with_symm_padding;
1229   // Number of output channels in a tile.
1230   // This parameter must be passed as is to weight packing function.
1231   uint8_t output_channel_tile;
1232   // Number of output height pixels in a tile.
1233   // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call.
1234   uint8_t output_height_tile;
1235   // Number of output width pixes in a tile.
1236   uint8_t output_width_tile;
1237 };
1238 
1239 struct spchw_dwconv_parameters {
1240   xnn_dwconv_spchw_ukernel_function ukernel;
1241   // Number of input width pixels in a tile.
1242   uint8_t input_width_tile;
1243   // Number of output width pixels in a tile.
1244   uint8_t output_width_tile;
1245   // Number of output height pixels in a tile.
1246   // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call.
1247   uint8_t output_height_tile;
1248 };
1249 
1250 struct spchw_gavgpool_parameters {
1251   xnn_gavgpool_spchw_ukernel_function ukernel;
1252   // Number of channels in a tile.
1253   // For best efficiency, micro-kernel must process a multiple of this number of channels in each call.
1254   uint8_t channel_tile;
1255 };
1256 
1257 struct dwconv_parameters {
1258   union {
1259     xnn_dwconv_up_ukernel_function up;
1260     xnn_dwconv_mp_ukernel_function mp;
1261   };
1262   uint8_t cr;
1263   uint8_t mr;
1264   uint8_t qr;
1265 };
1266 
1267 struct gavgpool_parameters {
1268   xnn_gavgpool_up_ukernel_function up;
1269   xnn_gavgpool_mp_ukernel_function mp;
1270   uint8_t mr;
1271 };
1272 
1273 struct avgpool_parameters {
1274   xnn_avgpool_up_ukernel_function up;
1275   xnn_avgpool_mp_ukernel_function mp;
1276   uint8_t mr;
1277   uint8_t qr;
1278 };
1279 
1280 struct pavgpool_parameters {
1281   xnn_pavgpool_up_ukernel_function up;
1282   xnn_pavgpool_mp_ukernel_function mp;
1283   uint8_t mr;
1284   uint8_t qr;
1285 };
1286 
1287 struct argmaxpool_parameters {
1288   union {
1289     xnn_argmaxpool_up_ukernel_function up;
1290     xnn_argmaxpool_mp_ukernel_function mp;
1291   };
1292   uint8_t mr;
1293   uint8_t qr;
1294 };
1295 
1296 struct maxpool_parameters {
1297   xnn_maxpool_ukernel_function ukernel;
1298   uint8_t mr;
1299   uint8_t qr;
1300 };
1301 
1302 struct bilinear_parameters {
1303   xnn_bilinear_ukernel_function ukernel;
1304   // Number of output pixels in a tile.
1305   // For best efficiency, micro-kernel must produce a multiple of this number of pixels in each call.
1306   uint8_t pixel_tile;
1307   // Number of channels in a tile.
1308   // For best efficiency, micro-kernel must process a multiple of this number of channels in each call.
1309   uint8_t channel_tile;
1310 };
1311 
1312 struct zip_parameters {
1313   xnn_zipc_ukernel_function x2;
1314   xnn_zipc_ukernel_function x3;
1315   xnn_zipc_ukernel_function x4;
1316   xnn_zipv_ukernel_function xm;
1317 };
1318 
1319 struct prelu_parameters {
1320   xnn_prelu_ukernel_function ukernel;
1321   uint16_t row_tile;
1322   uint16_t channel_tile;
1323 };
1324 
1325 struct pad_parameters {
1326   xnn_pad_ukernel_function ukernel;
1327   uint8_t mr;
1328 };
1329 
1330 struct vmulcaddc_parameters {
1331   xnn_vmulcaddc_ukernel_function ukernel;
1332   uint8_t channel_tile;
1333   uint8_t row_tile;
1334 };
1335 
1336 #define XNN_MAX_Q8_DWCONV_UKERNELS 1
1337 #define XNN_MAX_F32_DWCONV_UKERNELS 3
1338 #define XNN_MAX_F32_ARGMAXPOOL_UKERNELS 3
1339 
1340 struct xnn_parameters {
1341   bool initialized;
1342   struct xnn_allocator allocator;
1343   struct {
1344     struct gemm_parameters gemm;
1345     struct dwconv_parameters dwconv[XNN_MAX_Q8_DWCONV_UKERNELS];
1346     struct avgpool_parameters avgpool;
1347     struct gavgpool_parameters gavgpool;
1348     xnn_vadd_ukernel_function vadd;
1349   } q8;
1350   struct {
1351     struct maxpool_parameters maxpool;
1352     xnn_univector_ukernel_function clamp;
1353     xnn_u8_lut32norm_ukernel_function lut32norm;
1354     xnn_u8_rmax_ukernel_function rmax;
1355   } u8;
1356   struct {
1357     xnn_x8_lut_ukernel_function lut;
1358     struct zip_parameters zip;
1359   } x8;
1360   struct {
1361     struct gemm_parameters gemm;
1362     struct gemm_parameters gemm2;
1363     struct dwconv_parameters dwconv[XNN_MAX_F32_DWCONV_UKERNELS];
1364     struct avgpool_parameters avgpool;
1365     struct pavgpool_parameters pavgpool;
1366     struct gavgpool_parameters gavgpool;
1367     struct maxpool_parameters maxpool;
1368     struct argmaxpool_parameters argmaxpool[XNN_MAX_F32_ARGMAXPOOL_UKERNELS];
1369     // Bilinear interpolation (2D).
1370     struct bilinear_parameters bilinear;
1371     xnn_univector_ukernel_function clamp;
1372     xnn_univector_ukernel_function hswish;
1373     xnn_univector_ukernel_function sigmoid;
1374     struct prelu_parameters prelu;
1375     struct vbinary_parameters vadd;
1376     struct vbinary_parameters vdiv;
1377     struct vbinary_parameters vmax;
1378     struct vbinary_parameters vmin;
1379     struct vbinary_parameters vmul;
1380     struct vbinary_parameters vsub;
1381     struct vmulcaddc_parameters vmulcaddc;
1382     xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax;
1383     xnn_f32_rmax_ukernel_function rmax;
1384     // Sparse Matrix-Dense Matrix Multiplication (NR=1 block).
1385     struct spmm_parameters spmm;
1386     // Sparse Matrix-Dense Matrix Multiplication (NR=2 block).
1387     struct spmm_parameters spmm2;
1388     // Sparse Matrix-Dense Matrix Multiplication (NR=4 block).
1389     struct spmm_parameters spmm4;
1390     // Direct 3x3 stride-2 Convolution with 3 input channels and HWC->SpCHW layout conversion.
1391     struct hwc2spchw_dconv_parameters hwc2spchw_dconv3x3c3s2;
1392     // Direct 3x3 stride-1 Convolution with padding 1 on left and right in SpCHW layout.
1393     struct spchw_dwconv_parameters spchw_dwconv3x3;
1394     // Direct 3x3 stride-2 Convolution with padding 1 on left and right in SpCHW layout.
1395     struct spchw_dwconv_parameters spchw_dwconv3x3s2;
1396     // Direct 5x5 stride-1 Convolution with padding 2 on left and right in SpCHW layout.
1397     struct spchw_dwconv_parameters spchw_dwconv5x5;
1398     // Direct 5x5 stride-2 Convolution with padding 2 on left and right in SpCHW layout.
1399     struct spchw_dwconv_parameters spchw_dwconv5x5s2;
1400     // Global Average Pooling in SpCHW layout.
1401     struct spchw_gavgpool_parameters spchw_gavgpool;
1402   } f32;
1403   struct {
1404     struct pad_parameters pad;
1405     xnn_unpool_ukernel_function unpool;
1406     struct zip_parameters zip;
1407   } x32;
1408 };
1409 
1410 extern XNN_INTERNAL struct xnn_parameters xnn_params;
1411