// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once


#include <stddef.h>
#include <stdint.h>

#include <pthreadpool.h>

#include <xnnpack.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>


enum xnn_parallelization_type {
  xnn_parallelization_type_invalid = 0,
  xnn_parallelization_type_1d,
  xnn_parallelization_type_1d_tile_1d,
  xnn_parallelization_type_2d,
  xnn_parallelization_type_2d_tile_1d,
  xnn_parallelization_type_2d_tile_2d,
  xnn_parallelization_type_3d,
  xnn_parallelization_type_3d_tile_2d,
  xnn_parallelization_type_4d,
  xnn_parallelization_type_4d_tile_2d,
  xnn_parallelization_type_5d,
  xnn_parallelization_type_5d_tile_2d,
  xnn_parallelization_type_6d_tile_2d,
#if XNN_MAX_UARCH_TYPES > 1
  xnn_parallelization_type_2d_tile_2d_with_uarch,
  xnn_parallelization_type_3d_tile_2d_with_uarch,
  xnn_parallelization_type_4d_tile_2d_with_uarch,
#endif  // XNN_MAX_UARCH_TYPES > 1
};

struct compute_parameters {
  enum xnn_parallelization_type type;
  union {
    pthreadpool_task_1d_t task_1d;
    pthreadpool_task_1d_tile_1d_t task_1d_tile_1d;
    pthreadpool_task_2d_t task_2d;
    pthreadpool_task_2d_tile_1d_t task_2d_tile_1d;
    pthreadpool_task_2d_tile_2d_t task_2d_tile_2d;
    pthreadpool_task_3d_t task_3d;
    pthreadpool_task_3d_tile_2d_t task_3d_tile_2d;
    pthreadpool_task_4d_t task_4d;
    pthreadpool_task_4d_tile_2d_t task_4d_tile_2d;
    pthreadpool_task_5d_t task_5d;
    pthreadpool_task_5d_tile_2d_t task_5d_tile_2d;
    pthreadpool_task_6d_tile_2d_t task_6d_tile_2d;
#if XNN_MAX_UARCH_TYPES > 1
    pthreadpool_task_2d_tile_2d_with_id_t task_2d_tile_2d_with_id;
    pthreadpool_task_3d_tile_2d_with_id_t task_3d_tile_2d_with_id;
    pthreadpool_task_4d_tile_2d_with_id_t task_4d_tile_2d_with_id;
#endif  // XNN_MAX_UARCH_TYPES > 1
  };
  size_t range[6];
  size_t tile[2];
};
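
// Each operator fills one compute_parameters record during setup; at run time
// the library switches on `type` and passes the matching task pointer, the
// leading `range` entries, and the `tile` sizes to the corresponding
// pthreadpool_parallelize_*() call. A minimal sketch of that dispatch for the
// 2D-tiled-2D case (illustrative only; the real dispatch lives in the
// operator-run code, and the pthreadpool call shown here is an assumption of
// how the record is consumed):
//
//   switch (compute->type) {
//     case xnn_parallelization_type_2d_tile_2d:
//       pthreadpool_parallelize_2d_tile_2d(
//           threadpool,
//           compute->task_2d_tile_2d,
//           context,
//           compute->range[0], compute->range[1],
//           compute->tile[0], compute->tile[1],
//           0 /* flags */);
//       break;
//     // ... one case per xnn_parallelization_type value ...
//   }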

struct transpose_context {
  const void* x;
  void* y;
  union {
    xnn_transposec_ukernel_function const_size_ukernel;
    xnn_transposev_ukernel_function variable_size_ukernel;
  };
  union {
    size_t element_size;
    size_t log2_element_size;
  };
  size_t input_stride[XNN_MAX_TENSOR_DIMS];
  size_t output_stride[XNN_MAX_TENSOR_DIMS];
};
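
// Minimal sketch of pairing this context with the compute_parameters record
// above for a constant-element-size 2D transpose, using the per-tile wrapper
// declared below (illustrative only; which union members the operator
// initializes, and the hypothetical output_rows/output_cols/row_tile/col_tile
// values, are assumptions rather than part of this header):
//
//   struct compute_parameters compute = {
//     .type = xnn_parallelization_type_2d_tile_2d,
//     .task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_transposec_2d,
//     .range = { output_rows, output_cols },
//     .tile = { row_tile, col_tile },
//   };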

XNN_PRIVATE void xnn_compute_transposec_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j);

XNN_PRIVATE void xnn_compute_transposec_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k);

XNN_PRIVATE void xnn_compute_transposec_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l);

XN_PRIVATE_PLACEHOLDER
XNN_PRIVATE void xnn_compute_transposec_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m);

XNN_PRIVATE void xnn_compute_transposec_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n);

XNN_PRIVATE void xnn_compute_transposev_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j);

XNN_PRIVATE void xnn_compute_transposev_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k);

XNN_PRIVATE void xnn_compute_transposev_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l);

XNN_PRIVATE void xnn_compute_transposev_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m);

XNN_PRIVATE void xnn_compute_transposev_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n);

struct gemm_context {
  size_t k_scaled;
  const void* a;
  size_t a_stride;
  const void* packed_w;
  size_t w_stride;
  size_t wg_stride;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t cg_stride;
  uint32_t log2_csize;
  struct xnn_hmp_gemm_ukernel ukernel;
  void* fused_params;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  #if XNN_MAX_UARCH_TYPES > 1
    XNN_PRIVATE void xnn_compute_hmp_grouped_gemm(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_gemm(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);
  #endif  // XNN_MAX_UARCH_TYPES > 1
#endif
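
// Conceptual sketch of what a wrapper like xnn_compute_gemm does for one
// (mr_block_size x nr_block_size) output tile (illustrative pointer
// arithmetic derived from the field names above; the actual definitions live
// in the operator-run code and also handle uarch selection and grouping):
//
//   const void* a = (const void*)
//       ((uintptr_t) context->a + mr_block_start * context->a_stride);
//   const void* w = (const void*)
//       ((uintptr_t) context->packed_w + nr_block_start * context->w_stride);
//   void* c = (void*)
//       ((uintptr_t) context->c + mr_block_start * context->cm_stride
//                               + (nr_block_start << context->log2_csize));
//   // The GEMM micro-kernel selected from context->ukernel then consumes
//   // k_scaled (the K extent pre-scaled by the element size) per row and
//   // writes the tile using cm_stride/cn_stride, with &context->params
//   // supplying the output clamping bounds.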

// Context for Sparse Matrix-Dense Matrix Multiplication.
// C [MxN] := A [MxK] * B [KxN] + bias [N]
// A and C are dense matrices with row-major storage, B is a sparse matrix.
struct spmm_context {
  // N dimension of the B and C matrices.
  // Corresponds to the number of output channels in 1x1 convolution.
  size_t n;
  // M dimension of the A and C matrices, pre-scaled by sizeof(element size).
  // Corresponds to the stride, in bytes, between adjacent rows of the C matrix.
  size_t scaled_m;
  // Input matrix A.
  const void* input;
  // Packed bias elements and non-zero filter elements.
  const void* nonzero_weights;
  // Input pointer increments, in bytes, after each processed non-zero weight.
  const int32_t* input_increments;
  // Number of non-zero filter elements for each N (output channel) dimension.
  const uint32_t* output_channel_nonzeros;
  // Output matrix C.
  void* output;
  // Stride, in bytes, between matrices A corresponding to different images in batched 1x1 convolution.
  size_t batched_input_stride;
  // Stride, in bytes, between matrices C corresponding to different images in batched 1x1 convolution.
  size_t batched_output_stride;
  // Micro-kernel function pointer.
  xnn_spmm_ukernel_function ukernel;
  // Output activation parameters.
  union {
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_spmm(
      const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t mr_block_start,
      size_t mr_block_size);
#endif
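
// Rough sketch of the work done by one xnn_compute_spmm invocation
// (illustrative only; the batch/M offsets and the micro-kernel argument order
// shown here are assumptions based on the field descriptions above, not
// quoted from the actual definition):
//
//   // batch_index selects one image; mr_block_start/mr_block_size select a
//   // pre-scaled (byte) range of M, i.e. a block of rows of A and C.
//   const void* a = (const void*) ((uintptr_t) context->input
//       + batch_index * context->batched_input_stride + mr_block_start);
//   void* c = (void*) ((uintptr_t) context->output
//       + batch_index * context->batched_output_stride + mr_block_start);
//   context->ukernel(mr_block_size, context->n, a, context->nonzero_weights,
//                    context->input_increments, context->output_channel_nonzeros,
//                    c, context->scaled_m, &context->params);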

struct igemm_context {
  size_t ks;
  size_t ks_scaled;
  size_t kc;
  size_t w_stride;
  const void** indirect_a;
  size_t a_offset;
  void* zero;
  const void* packed_w;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  struct xnn_hmp_igemm_ukernel ukernel;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_grouped_batch_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_batch_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  #if XNN_MAX_UARCH_TYPES > 1
    XNN_PRIVATE void xnn_compute_hmp_grouped_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_grouped_batch_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t batch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_batch_hmp_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t batch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);
  #endif  // XNN_MAX_UARCH_TYPES > 1
#endif

struct subgemm_context {
  const struct subconvolution_params* subconvolution_params;
  size_t kc;
  const void* a;
  size_t ax_stride;
  size_t ay_stride;
  size_t cx_stride;
  size_t cy_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  struct xnn_hmp_gemm_ukernel ukernel;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);
#endif

struct subconv_context {
  const struct subconvolution_params* subconvolution_params;
  size_t kc;
  size_t a_offset;
  void* zero;
  size_t cx_stride;
  size_t cy_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  struct xnn_hmp_igemm_ukernel ukernel;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);
#endif

struct conv2d_context {
  size_t input_height;
  size_t input_width;
  const void* input;
  size_t input_batch_stride;
  const void* zero;
  const void* packed_weights;
  void* output;
  size_t output_batch_stride;
  size_t input_padding_top;
  size_t output_channels;
  size_t output_height_stride;
  size_t output_channel_stride;
  union {
    xnn_conv_hwc2chw_ukernel_function hwc2chw_ukernel;
  };
  union {
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_conv2d_hwc2chw(
      const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y_start,
      size_t output_y_slice);
#endif

struct dwconv_context {
  const void** indirect_input;
  size_t indirect_input_width_stride;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  const void* packed_weights;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t groups;
  const void* zero;
  size_t output_increment;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
  union {
    xnn_dwconv_unipass_ukernel_function unipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv_unipass(
      const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif
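
// Sketch of how one xnn_compute_dwconv_unipass invocation maps to the fields
// above (conceptual; the pointer arithmetic is inferred from the field names,
// not quoted from the definition). batch_index and output_y select one output
// row; the wrapper picks that row's slice of the indirection buffer,
//
//   const void** input_row = (const void**) ((uintptr_t) context->indirect_input
//       + output_y * context->indirect_input_height_stride);
//   void* output_row = (void*) ((uintptr_t) context->output
//       + batch_index * context->output_batch_stride
//       + output_y * context->output_height_stride);
//
// and calls unipass_ukernel once for the row, passing groups (channels),
// output_width, packed_weights, the per-pixel indirection stride
// (indirect_input_width_stride), input_offset + batch_index * input_batch_stride,
// the zero buffer, output_increment, and &context->params.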

struct dwconv2d_context {
  size_t input_height;
  size_t input_width;
  const void* input;
  const void* zero;
  uint32_t input_padding_top;
  size_t input_channel_stride;
  size_t input_batch_stride;
  const void* packed_weights;
  size_t weights_channel_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  union {
    union xnn_f32_chw_params f32;
  } params;
  union {
    xnn_dwconv2d_chw_ukernel_function chw_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv2d_chw(
      const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t channel);
#endif

struct max_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_u8_minmax_params u8;
    union xnn_f32_minmax_params f32;
  } params;
  xnn_maxpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_max_pooling(
      const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct unpooling_context {
  const void* input;
  size_t input_height_stride;
  size_t input_width_stride;
  const uint32_t* index;
  size_t index_height_stride;
  size_t index_width_stride;
  const void** indirect_output;
  size_t indirect_output_height_stride;
  size_t indirect_output_width_stride;
  size_t pooling_size;
  size_t channels;
  uint32_t fill_value;
  xnn_unpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_unpooling(
      const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t input_y,
      size_t input_x);
#endif

struct argmax_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  uint32_t* index;
  size_t index_batch_stride;
  size_t index_height_stride;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    xnn_argmaxpool_unipass_ukernel_function unipass_ukernel;
    xnn_argmaxpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_argmax_pooling_unipass(
      const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_argmax_pooling_multipass(
      const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_scaleminmax_params f32;
    union xnn_qu8_avgpool_minmax_params qu8;
  } params;
  union {
    xnn_avgpool_unipass_ukernel_function unipass_ukernel;
    xnn_avgpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_average_pooling_unipass(
      const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_average_pooling_multipass(
      const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct pixelwise_average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  const void* pixelwise_buffer;
  size_t pixelwise_buffer_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
    union xnn_u8_minmax_params u8;
  } params;
  union {
    xnn_pavgpool_unipass_ukernel_function unipass_ukernel;
    xnn_pavgpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_unipass(
      const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass(
      const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct global_average_pooling_nwc_context {
  const void* input;
  const void* zero;
  size_t input_pixel_stride;
  size_t input_batch_stride;
  size_t input_elements;
  size_t channels;
  void* output;
  size_t output_batch_stride;
  union {
    union xnn_qs8_avgpool_minmax_params qs8;
    union xnn_qu8_avgpool_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_scaleminmax_params f32;
  } params;
  union {
    xnn_gavgpool_unipass_ukernel_function unipass_ukernel;
    xnn_gavgpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_unipass(
      const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);

  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_multipass(
      const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

struct global_average_pooling_ncw_context {
  size_t input_elements;
  const void* input;
  size_t input_channel_stride;
  size_t input_batch_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  xnn_gavgpool_cw_ukernel_function ukernel;
  union {
    union xnn_f32_gavgpool_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_ncw(
      const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t channels_start,
      size_t channels_slice);
#endif

struct resize_bilinear_context {
  // Number of channels multiplied by sizeof(input element).
  size_t scaled_channels;
  // Indirection buffer with pointers related to rows of input pixels.
  const void** indirect_input;
  // Offset, in bytes, to be added to pointers in indirection buffer.
  size_t input_offset;
  // Stride, in bytes, between images of consecutive batches in the input.
  size_t input_batch_stride;
  // Packed pairs of (x, y) linear interpolation coefficients.
  const void* packed_weights;
  // Pointer to the output tensor.
  void* output;
  // Stride, in bytes, between adjacent pixels in the output.
  size_t output_pixel_stride;
  // Stride, in bytes, between images of consecutive batches in the output.
  size_t output_batch_stride;
  // log2(sizeof(weight element)).
  uint32_t log2_wsize;
  // Pointer to BILINEAR micro-kernel function.
  xnn_ibilinear_ukernel_function ukernel;
};

struct resize_bilinear_chw_context {
  // Number of pixels per output image plane.
  size_t output_pixels;
  // Number of channels multiplied by sizeof(input element).
  size_t channels;
  // Stride, in bytes, between adjacent channels in the input.
  size_t input_channel_stride;
  // Indirection buffer with pointers related to rows of input pixels.
  const void** indirect_input;
  // Offset, in bytes, to be added to pointers in indirection buffer.
  size_t input_offset;
  // Stride, in bytes, between images of consecutive batches in the input.
  size_t input_batch_stride;
  // Packed pairs of (x, y) linear interpolation coefficients.
  const void* packed_weights;
  // Pointer to the output tensor.
  void* output;
  // Stride, in bytes, between images of consecutive batches in the output.
  size_t output_batch_stride;
  // Stride, in bytes, between consecutive channels of an output image.
  size_t output_channel_stride;
  // Pointer to BILINEAR micro-kernel function.
  xnn_ibilinear_chw_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_resize_bilinear(
      const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t pixel_start,
      size_t pixel_range);

  XNN_PRIVATE void xnn_compute_resize_bilinear_chw(
      const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t pixel_start,
      size_t pixel_range);
#endif
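
// Conceptual sketch of one xnn_compute_resize_bilinear call for the NHWC path
// (illustrative only; the "four neighborhood pointers and one coefficient pair
// per output pixel" layout, and the micro-kernel argument order, are
// assumptions stated for illustration, not quoted from this header):
//
//   void* out = (void*) ((uintptr_t) context->output
//       + batch_index * context->output_batch_stride
//       + pixel_start * context->output_pixel_stride);
//   context->ukernel(
//       pixel_range,                                // output pixels to produce
//       context->scaled_channels,                   // bytes per output pixel
//       context->indirect_input + 4 * pixel_start,  // 2x2 neighborhood pointers
//       context->input_offset + batch_index * context->input_batch_stride,
//       (const void*) ((uintptr_t) context->packed_weights
//           + (pixel_start << (context->log2_wsize + 1))),  // (x, y) coefficient pairs
//       out,
//       context->output_pixel_stride);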

struct elementwise_binary_context {
  const void* a;
  size_t a_stride[XNN_MAX_TENSOR_DIMS - 1];
  const void* b;
  size_t b_stride[XNN_MAX_TENSOR_DIMS - 1];
  void* y;
  size_t y_stride[XNN_MAX_TENSOR_DIMS - 1];
  size_t elements;
  union {
    union xnn_qs8_add_minmax_params qs8_addsub;
    union xnn_qu8_add_minmax_params qu8_addsub;
    union xnn_qs8_mul_minmax_params qs8_mul;
    union xnn_qu8_mul_minmax_params qu8_mul;
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
  xnn_vbinary_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_elementwise_binary_1d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i);
  XNN_PRIVATE void xnn_compute_elementwise_binary_2d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j);
  XNN_PRIVATE void xnn_compute_elementwise_binary_3d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k);
  XNN_PRIVATE void xnn_compute_elementwise_binary_4d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l);
  XNN_PRIVATE void xnn_compute_elementwise_binary_5d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l, size_t m);
#endif
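
// Broadcasting works through the per-dimension byte strides: a dimension in
// which an operand is broadcast gets a stride of 0, so the same data is
// re-read for every index along that dimension, while the innermost contiguous
// run of `elements` is handled by the micro-kernel itself. A rough sketch of
// the 5D wrapper's pointer arithmetic (illustrative, written from the field
// names above rather than the actual definition):
//
//   const void* a = (const void*) ((uintptr_t) context->a
//       + i * context->a_stride[0] + j * context->a_stride[1]
//       + k * context->a_stride[2] + l * context->a_stride[3]
//       + m * context->a_stride[4]);
//   const void* b = ...;  // same indexing with b_stride
//   void* y = ...;        // same indexing with y_stride
//   context->ukernel(context->elements, a, b, y, &context->params);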

struct channel_shuffle_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  size_t n;
  size_t m;
  union {
    xnn_zipc_ukernel_function fixed_ukernel;
    xnn_zipv_ukernel_function variable_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_channel_shuffle_fixed(
      const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t index);

  XNN_PRIVATE void xnn_compute_channel_shuffle_variable(
      const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t index);
#endif

struct lut_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_strided(
      const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

struct lut_contiguous_context {
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_contiguous(
      const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t offset,
      size_t size);
#endif

struct univector_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_vunary_ukernel_function ukernel;
  union {
    union xnn_f16_abs_params f16_abs;
    union xnn_f16_default_params f16_default;
    union xnn_f16_f32_cvt_params f16_f32_cvt;
    union xnn_f16_hswish_params f16_hswish;
    union xnn_f16_lrelu_params f16_lrelu;
    union xnn_f16_minmax_params f16_minmax;
    union xnn_f16_neg_params f16_neg;
    union xnn_f16_sigmoid_params f16_sigmoid;
    union xnn_f32_abs_params f32_abs;
    union xnn_f32_default_params f32_default;
    union xnn_f32_elu_params f32_elu;
    union xnn_f32_f16_cvt_params f32_f16_cvt;
    union xnn_f32_hswish_params f32_hswish;
    union xnn_f32_lrelu_params f32_lrelu;
    union xnn_f32_minmax_params f32_minmax;
    union xnn_f32_neg_params f32_neg;
    union xnn_f32_qs8_cvt_params f32_qs8_cvt;
    union xnn_f32_qu8_cvt_params f32_qu8_cvt;
    union xnn_f32_rnd_params f32_rnd;
    union xnn_f32_sigmoid_params f32_sigmoid;
    union xnn_f32_sqrt_params f32_sqrt;
    union xnn_qs8_cvt_params qs8_cvt;
    union xnn_qs8_f32_cvt_params qs8_f32_cvt;
    union xnn_qs8_lrelu_params qs8_lrelu;
    union xnn_qu8_cvt_params qu8_cvt;
    union xnn_qu8_f32_cvt_params qu8_f32_cvt;
    union xnn_qu8_lrelu_params qu8_lrelu;
    union xnn_s8_minmax_params s8_minmax;
    union xnn_u8_minmax_params u8_minmax;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_strided(
      const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t batch_range);
#endif

struct univector_contiguous_context {
  const void* x;
  void* y;
  uint16_t log2_xsize;
  uint16_t log2_ysize;
  xnn_vunary_ukernel_function ukernel;
  union {
    union xnn_f16_abs_params f16_abs;
    union xnn_f16_default_params f16_default;
    union xnn_f16_f32_cvt_params f16_f32_cvt;
    union xnn_f16_hswish_params f16_hswish;
    union xnn_f16_lrelu_params f16_lrelu;
    union xnn_f16_minmax_params f16_minmax;
    union xnn_f16_neg_params f16_neg;
    union xnn_f16_sigmoid_params f16_sigmoid;
    union xnn_f32_abs_params f32_abs;
    union xnn_f32_default_params f32_default;
    union xnn_f32_elu_params f32_elu;
    union xnn_f32_f16_cvt_params f32_f16_cvt;
    union xnn_f32_hswish_params f32_hswish;
    union xnn_f32_lrelu_params f32_lrelu;
    union xnn_f32_minmax_params f32_minmax;
    union xnn_f32_neg_params f32_neg;
    union xnn_f32_qs8_cvt_params f32_qs8_cvt;
    union xnn_f32_qu8_cvt_params f32_qu8_cvt;
    union xnn_f32_rnd_params f32_rnd;
    union xnn_f32_sigmoid_params f32_sigmoid;
    union xnn_f32_sqrt_params f32_sqrt;
    union xnn_qs8_cvt_params qs8_cvt;
    union xnn_qs8_f32_cvt_params qs8_f32_cvt;
    union xnn_qs8_lrelu_params qs8_lrelu;
    union xnn_qu8_cvt_params qu8_cvt;
    union xnn_qu8_f32_cvt_params qu8_f32_cvt;
    union xnn_qu8_lrelu_params qu8_lrelu;
    union xnn_s8_minmax_params s8_minmax;
    union xnn_u8_minmax_params u8_minmax;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_contiguous(
      const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t offset,
      size_t size);
#endif
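
// Sketch of how the contiguous variant maps a byte range of the input onto
// the output when input and output element sizes differ, using the
// log2_xsize/log2_ysize fields (illustrative; treating offset and size as
// input-side byte counts is an assumption, not something this header states):
//
//   const void* x = (const void*) ((uintptr_t) context->x + offset);
//   void* y = (void*) ((uintptr_t) context->y
//       + ((offset >> context->log2_xsize) << context->log2_ysize));
//   context->ukernel(size, x, y, &context->params);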

struct prelu_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_prelu_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_prelu(
      const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_start,
      size_t batch_range);
#endif

struct vmulcaddc_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_vmulcaddc_ukernel_function ukernel;
  union {
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_vmulcaddc(
      const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_start,
      size_t batch_size);
#endif

struct pad_context {
  const void* input;
  size_t input_stride[XNN_MAX_TENSOR_DIMS - 1];
  void* output;
  size_t output_stride[XNN_MAX_TENSOR_DIMS - 1];
  size_t pre_paddings[XNN_MAX_TENSOR_DIMS];
  size_t post_paddings[1];
  size_t input_size[XNN_MAX_TENSOR_DIMS];
  size_t output_size[1];
  uint32_t padding_value;
  xnn_pad_ukernel_function pad_ukernel;
  xnn_fill_ukernel_function fill_ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pad_5d(
      const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l, size_t m);
#endif

struct u8_softmax_context {
  size_t n;
  const uint8_t* x;
  size_t x_stride;
  const uint32_t* t;
  uint8_t* y;
  size_t y_stride;
  xnn_u8_rmax_ukernel_function rmax_ukernel;
  xnn_u8_lut32norm_ukernel_function lut_norm_ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_u8_softmax(
      const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif
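
// The quantized softmax is a two-pass reduction over one row of n elements
// (informal summary based on the kernel names above; the exact argument lists
// of the two micro-kernels are not spelled out in this header):
// rmax_ukernel first scans x + batch_index * x_stride to find the row maximum;
// lut_norm_ukernel then re-reads the row, looks each value up in the 32-bit
// table t (with the table pointer shifted so the row maximum lands at the top
// entry), and writes normalized uint8_t results to y + batch_index * y_stride.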

typedef void (*xnn_compute_reciprocal_function)(const void* input, void* output);

struct floating_point_softmax_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_rmax_ukernel_function rmax_ukernel;
  xnn_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax_ukernel;
  xnn_compute_reciprocal_function compute_reciprocal;
  xnn_vbinary_ukernel_function vmulc_ukernel;
  union {
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } minmax_params;
  union {
    union xnn_f16_expminus_params f16;
    union xnn_f32_expminus_params f32;
  } expminus_params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_floating_point_softmax(
      const struct floating_point_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif
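
// The floating-point softmax pipelines four kernels per row (informal summary
// drawn from the fields above; the exact kernel argument lists are
// assumptions):
//
//   1. rmax_ukernel reduces the row x + batch_index * x_stride to its maximum.
//   2. raddstoreexpminusmax_ukernel stores exp(x[i] - max) into the output row
//      and simultaneously accumulates their sum, using expminus_params.
//   3. compute_reciprocal turns the accumulated sum into a scale factor 1/sum.
//   4. vmulc_ukernel multiplies the stored exponentials by that scalar in
//      place, with minmax_params for clamping, yielding the softmax outputs.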