// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/compute.h>

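// The xnn_compute_* functions below are per-tile callbacks: operator setup
// code hands them to the thread pool (pthreadpool), which invokes them in
// parallel across the iteration space. Each callback recovers operand
// pointers from a per-operator context struct and forwards one tile of work
// to a micro-kernel.
//
// xnn_compute_transposec_* cover constant-element-size transposes of 2D-6D
// arrays. All strides are in bytes, so indexing along the innermost
// dimension reduces to a shift by log2_element_size; e.g. in the 2D case
// the tile at (i, j) reads from
//   x + (i << log2_element_size) + j * input_stride[1]
// and writes the transposed tile to
//   y + (j << log2_element_size) + i * output_stride[0].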
void xnn_compute_transposec_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j)
{
  const size_t log2_element_size = context->log2_element_size;

  context->const_size_ukernel(
      (const void*) ((uintptr_t) context->x + (i << log2_element_size) + j * context->input_stride[1]),
      (void*) ((uintptr_t) context->y + (j << log2_element_size) + i * context->output_stride[0]),
      context->input_stride[1],
      context->output_stride[0],
      tile_i,
      tile_j);
}

void xnn_compute_transposec_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k)
{
  const size_t log2_element_size = context->log2_element_size;
  const size_t ld_input = context->input_stride[2];
  const size_t ld_output = context->output_stride[1];
  const void* x = (const void*) ((uintptr_t) context->x +
                                 (i * context->input_stride[0] + j * context->input_stride[1]) + k * ld_input);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
                     (k << log2_element_size));

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_j,
      tile_k);
}

void xnn_compute_transposec_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l)
{
  const size_t log2_element_size = context->log2_element_size;
  const size_t ld_input = context->input_stride[3];
  const size_t ld_output = context->output_stride[2];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                 k * context->input_stride[2] + l * ld_input);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
                     k * context->output_stride[2] + (l << log2_element_size));

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_k,
      tile_l);
}

void xnn_compute_transposec_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m)
{
  const size_t log2_element_size = context->log2_element_size;
  const size_t ld_input = context->input_stride[4];
  const size_t ld_output = context->output_stride[3];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                 k * context->input_stride[2] + l * context->input_stride[3] + m * ld_input);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
                     k * context->output_stride[2] + l * context->output_stride[3] + (m << log2_element_size));

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_l,
      tile_m);
}

void xnn_compute_transposec_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n)
{
  const size_t log2_element_size = context->log2_element_size;
  const size_t ld_input = context->input_stride[5];
  const size_t ld_output = context->output_stride[4];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                 k * context->input_stride[2] + l * context->input_stride[3] +
                                 m * context->input_stride[4] + n * ld_input);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
                     k * context->output_stride[2] + l * context->output_stride[3] + m * context->output_stride[4] +
                     (n << log2_element_size));

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_m,
      tile_n);
}

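// xnn_compute_transposev_* are the variable-element-size counterparts: the
// element size cannot be folded into the pointer math as a shift, so it is
// passed to the micro-kernel at run time together with the innermost
// input/output strides, and the kernel steps between elements itself.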
void xnn_compute_transposev_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j)
{
  const size_t element_size = context->element_size;
  const size_t ld_input = context->input_stride[1];
  const size_t ld_output = context->output_stride[0];
  const void* x = (const void*) ((uintptr_t) context->x +
                                 i * context->input_stride[0] + j * ld_input);
  void* y = (void*) ((uintptr_t) context->y + context->output_stride[1] * j + i * context->output_stride[0]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[0],
      context->output_stride[1],
      element_size,
      tile_i,
      tile_j);
}

void xnn_compute_transposev_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k)
{
  const size_t element_size = context->element_size;
  const size_t ld_input = context->input_stride[2];
  const size_t ld_output = context->output_stride[1];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                 k * ld_input);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
                     k * context->output_stride[2]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[1],
      context->output_stride[2],
      element_size,
      tile_j,
      tile_k);
}

void xnn_compute_transposev_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l)
{
  const size_t element_size = context->element_size;
  const size_t ld_input = context->input_stride[3];
  const size_t ld_output = context->output_stride[2];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                 k * context->input_stride[2] + l * ld_input);
  void* y = (void*) ((uintptr_t) context->y + context->output_stride[3] * l + i * context->output_stride[0] +
                     j * context->output_stride[1] + k * context->output_stride[2]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[2],
      context->output_stride[3],
      element_size,
      tile_k,
      tile_l);
}

void xnn_compute_transposev_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m)
{
  const size_t element_size = context->element_size;
  const size_t ld_input = context->input_stride[4];
  const size_t ld_output = context->output_stride[3];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                 k * context->input_stride[2] + l * context->input_stride[3] + m * ld_input);
  void* y = (void*) ((uintptr_t) context->y + context->output_stride[4] * m + i * context->output_stride[0] +
                     j * context->output_stride[1] + k * context->output_stride[2] + l * context->output_stride[3]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[3],
      context->output_stride[4],
      element_size,
      tile_l,
      tile_m);
}

void xnn_compute_transposev_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n)
{
  const size_t element_size = context->element_size;
  const size_t ld_input = context->input_stride[5];
  const size_t ld_output = context->output_stride[4];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
                                 k * context->input_stride[2] + l * context->input_stride[3] +
                                 m * context->input_stride[4] + n * ld_input);
  void* y = (void*) ((uintptr_t) context->y + context->output_stride[5] * n + i * context->output_stride[0] +
                     j * context->output_stride[1] + k * context->output_stride[2] + l * context->output_stride[3] +
                     m * context->output_stride[4]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[4],
      context->output_stride[5],
      element_size,
      tile_m,
      tile_n);
}

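// GEMM callbacks compute one mr_block_size x nr_block_size block of the
// output: A advances by whole rows (mr_block_start * a_stride), the packed
// weights advance by nr-sized column groups (nr_block_start * w_stride),
// and C lands at (mr_block_start, nr_block_start). The grouped variant adds
// per-group offsets (wg_stride, cg_stride), as used for grouped
// convolutions.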
void xnn_compute_grouped_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled  = context->k_scaled;
  const size_t a_stride  = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride  = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->fused_params);
}

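// Sparse inference (SpMM): the weights are stored sparsely as nonzero
// values plus per-output-channel nonzero counts and precomputed input
// increments, so each call only has to offset the dense input and output
// by the batch and by mr_block_start.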
void xnn_compute_spmm(
    const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{
  context->ukernel(
      mr_block_size,
      context->n,
      (const void*) ((uintptr_t) context->input + batch_index * context->batched_input_stride + mr_block_start),
      context->nonzero_weights,
      context->input_increments,
      context->output_channel_nonzeros,
      (void*) ((uintptr_t) context->output + batch_index * context->batched_output_stride + mr_block_start),
      context->scaled_m,
      &context->params);
}

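// IGEMM ("indirect GEMM") implements convolution without an im2col copy:
// the kernel consumes an array of row pointers (indirect_a, ks pointers per
// output row) instead of a dense A matrix, and a_offset relocates the
// pointed-to data per batch/group so one indirection buffer can be shared.
// The four variants below differ only in which of the batch and group
// dimensions are present.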
void xnn_compute_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks        = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks        = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      &context->params);
}

void xnn_compute_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks        = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks        = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset,
      context->zero,
      &context->params);
}

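// Subconvolution, as used for deconvolution: the output is partitioned into
// slices with their own subconvolution_params, and the parallelized
// iteration space can exceed a given slice, so calls that fall past a
// slice's height or width return early. subgemm reads the input densely;
// subconv (further below) goes through an indirection buffer, IGEMM-style.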
void xnn_compute_grouped_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nc_block_start,
      size_t slice_x_max,
      size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + group_index * context->ga_stride + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nc_block_start,
      size_t slice_x_max,
      size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_grouped_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nc_block_start,
      size_t slice_x_max,
      size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nc_block_start,
      size_t slice_x_max,
      size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

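// Direct 2D convolution with HWC input and CHW output, processed in slices
// of output rows; the layout change happens inside the micro-kernel.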
void xnn_compute_conv2d_hwc2chw(
      const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y_start,
      size_t output_y_slice)
{
  context->hwc2chw_ukernel(
      context->input_height,
      context->input_width,
      output_y_start,
      output_y_start + output_y_slice,
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
      context->zero,
      context->packed_weights,
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
      context->input_padding_top,
      context->output_channels,
      context->output_height_stride,
      context->output_channel_stride,
      &context->params);
}

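// Depthwise convolution: the unipass NHWC variant below produces one output
// row per call through an indirection buffer; the CHW variant after it
// processes one (batch, channel) image plane per call.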
void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
    context->groups, context->output_width,
    indirect_input, context->packed_weights, output,
    context->indirect_input_width_stride, context->output_increment,
    input_offset, context->zero,
    &context->params);
}

void xnn_compute_dwconv2d_chw(
    const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel)
{
  context->chw_ukernel(
    context->input_height,
    context->input_width,
    (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
    (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
    context->zero,
    (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
    context->input_padding_top,
    &context->params);
}

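// The pooling callbacks share one shape: each call produces a single output
// row, located through indirect input pointers plus a batch-dependent
// input_offset. Multipass variants additionally carve a scratch buffer off
// the stack with XNN_SIMD_ALLOCA, sized by the channel count plus
// XNN_EXTRA_BYTES of slack that micro-kernels are allowed to over-read.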
void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  context->unipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, output, index,
    context->input_increment, context->output_increment);
}

void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  void* multipass_accumulation_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(float) + XNN_EXTRA_BYTES);
  void* multipass_index_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(uint32_t) + XNN_EXTRA_BYTES);

  context->multipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
    context->input_increment, context->output_increment);
}

void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, output,
    context->input_increment, context->output_increment,
    &context->params);
}

void xnn_compute_unpooling(
    const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t input_y,
    size_t input_x)
{
  const void* input = (const void*) ((uintptr_t) context->input +
      input_y * context->input_height_stride + input_x * context->input_width_stride);
  const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
      input_y * context->index_height_stride + input_x * context->index_width_stride);
  void** indirect_output =
    (void**) ((uintptr_t) context->indirect_output +
      input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);

  context->ukernel(
    context->pooling_size,
    context->channels,
    context->fill_value,
    input, index, indirect_output);
}

void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, context->zero, output,
    context->input_increment, context->output_increment,
    &context->params);
}

void xnn_compute_average_pooling_multipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  void* multipass_buffer =
    XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, context->zero, multipass_buffer, output,
    context->input_increment, context->output_increment,
    &context->params);
}

void xnn_compute_pixelwise_average_pooling_unipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, context->zero, pixelwise_buffer, output,
    context->input_increment, context->output_increment,
    &context->params);
}

void xnn_compute_pixelwise_average_pooling_multipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  void* multipass_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
    context->output_width, context->pooling_size, context->channels,
    indirect_input, input_offset, context->zero, pixelwise_buffer, multipass_buffer, output,
    context->input_increment, context->output_increment,
    &context->params);
}

void xnn_compute_global_average_pooling_nwc_unipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  context->unipass_ukernel(
    context->input_elements,
    context->channels,
    input,
    context->input_pixel_stride,
    context->zero,
    output,
    &context->params);
}

void xnn_compute_global_average_pooling_nwc_multipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  void* multipass_buffer =
    XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
    context->input_elements,
    context->channels,
    input,
    context->input_pixel_stride,
    context->zero,
    multipass_buffer,
    output,
    &context->params);
}

void xnn_compute_global_average_pooling_ncw(
    const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channels_start,
    size_t channels_slice)
{
  const void* input = (const void*) ((uintptr_t) context->input +
    channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
  void* output = (void*) ((uintptr_t) context->output +
    channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
    context->input_elements,
    channels_slice,
    input,
    output,
    &context->params);
}

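// Bilinear resize: each output pixel interpolates 4 input pixels, hence the
// pixel_start * 4 step into the indirection buffer; packed_weights holds
// the per-pixel interpolation coefficients. The NHWC variant parallelizes
// over pixels, the CHW variant over channels.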
void xnn_compute_resize_bilinear(
    const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t pixel_start,
    size_t pixel_range)
{
  void* output =
    (void*) ((uintptr_t) context->output + pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
    pixel_range,
    context->scaled_channels,
    context->indirect_input + pixel_start * 4,
    context->input_offset + batch_index * context->input_batch_stride,
    (const void*) ((uintptr_t) context->packed_weights + (pixel_start << context->log2_wsize)),
    output,
    context->output_pixel_stride - context->scaled_channels);
}

void xnn_compute_resize_bilinear_chw(
    const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel_start,
    size_t channel_range)
{
  void* output =
    (void*) ((uintptr_t) context->output + channel_start * context->output_channel_stride + batch_index * context->output_batch_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride + channel_start * context->input_channel_stride;

  context->ukernel(
    context->output_pixels,
    channel_range,
    context->indirect_input,
    input_offset,
    context->packed_weights,
    output,
    context->input_channel_stride);
}

void xnn_compute_prelu(
    const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride);
}

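// Padding uses unsigned wrap-around for an interval test: with size_t
// arithmetic, `i - i_padding < i_size` holds exactly when i lies in
// [i_padding, i_padding + i_size), i.e. inside the copied interior; any
// other coordinate takes the fill path. pre_paddings/input_size are indexed
// from the innermost dimension outward, so loop index i pairs with element
// 5, and the innermost dimension (element 0) is handled inside the
// row-level pad/fill micro-kernels.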
void xnn_compute_pad_5d(
    const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  const void* input = (const void*) ((uintptr_t) context->input +
    i * context->input_stride[4] + j * context->input_stride[3] + k * context->input_stride[2] + l * context->input_stride[1] + m * context->input_stride[0]);
  void* output = (void*) ((uintptr_t) context->output +
    i * context->output_stride[4] + j * context->output_stride[3] + k * context->output_stride[2] + l * context->output_stride[1] + m * context->output_stride[0]);

  const size_t i_padding = context->pre_paddings[5];
  const size_t j_padding = context->pre_paddings[4];
  const size_t k_padding = context->pre_paddings[3];
  const size_t l_padding = context->pre_paddings[2];
  const size_t m_padding = context->pre_paddings[1];

  const size_t i_size = context->input_size[5];
  const size_t j_size = context->input_size[4];
  const size_t k_size = context->input_size[3];
  const size_t l_size = context->input_size[2];
  const size_t m_size = context->input_size[1];

  if XNN_LIKELY(i - i_padding < i_size && j - j_padding < j_size && k - k_padding < k_size &&
                l - l_padding < l_size && m - m_padding < m_size)
  {
    context->pad_ukernel(
      1 /* rows */,
      context->input_size[0], context->pre_paddings[0], context->post_paddings[0],
      input, 0 /* input stride */, output, 0 /* output stride */,
      context->padding_value);
  } else {
    context->fill_ukernel(1 /* rows */, context->output_size[0], output, 0 /* output stride */, context->padding_value);
  }
}

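// Elementwise binary operations with broadcasting: the setup code
// precomputes per-operand byte strides for up to five outer dimensions
// (presumably zeroing the stride of any broadcast dimension), and the _Nd
// variants consume the last N entries of those stride arrays while the
// micro-kernel sweeps the contiguous innermost elements.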
void xnn_compute_elementwise_binary_1d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i)
{
  const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_2d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j)
{
  const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[3] + j * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[3] + j * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[3] + j * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_3d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k)
{
  const void* a = (const void*) ((uintptr_t) context->a +
    i * context->a_stride[2] + j * context->a_stride[3] + k * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
    i * context->b_stride[2] + j * context->b_stride[3] + k * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
    i * context->y_stride[2] + j * context->y_stride[3] + k * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_4d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l)
{
  const void* a = (const void*) ((uintptr_t) context->a +
    i * context->a_stride[1] + j * context->a_stride[2] + k * context->a_stride[3] + l * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
    i * context->b_stride[1] + j * context->b_stride[2] + k * context->b_stride[3] + l * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
    i * context->y_stride[1] + j * context->y_stride[2] + k * context->y_stride[3] + l * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_5d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  const void* a = (const void*) ((uintptr_t) context->a +
    i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] + l * context->a_stride[3] + m * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
    i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] + l * context->b_stride[3] + m * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
    i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] + l * context->y_stride[3] + m * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->fixed_ukernel(context->n, x, y);
}

void xnn_compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->variable_ukernel(context->n, context->m, x, y);
}

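// Lookup-table and generic unary ("univector") operations come in strided
// (one batch row per call) and contiguous (flat byte range) flavors. In the
// contiguous univector case the y offset is derived from the x byte offset
// as (offset >> log2_xsize) << log2_ysize, which keeps the two pointers in
// step even when input and output element sizes differ.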
void xnn_compute_lut_strided(
    const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);

  context->ukernel(context->n, x, y, context->t);
}

void xnn_compute_lut_contiguous(
    const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);

  context->ukernel(size, x, y, context->t);
}

void xnn_compute_univector_strided(
    const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;

  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index);
  do {
    context->ukernel(context->n, x, y, &context->params);
    x = (const void*) ((uintptr_t) x + x_stride);
    y = (void*) ((uintptr_t) y + y_stride);
  } while (--batch_range != 0);
}

void xnn_compute_univector_contiguous(
    const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{
  const uint32_t log2_xsize = context->log2_xsize;
  const uint32_t log2_ysize = context->log2_ysize;
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + ((offset >> log2_xsize) << log2_ysize));
  context->ukernel(size, x, y, &context->params);
}

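// u8 softmax works off a precomputed exponent table: since x_max <= 255,
// the XOR x_max ^ 255 equals 255 - x_max, and advancing the table base by
// that amount makes t[x[i]], in effect, line up with exp(x[i] - x_max)
// without a per-element subtraction; the lut_norm kernel then normalizes by
// the running sum. The floating-point variant below is the standard
// three-pass max / exp-and-sum / scale formulation, with unions covering
// the f32 and f16 representations of the scalar temporaries.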
void xnn_compute_u8_softmax(
    const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  uint8_t x_max = 0;
  context->rmax_ukernel(n, x, &x_max);
  const size_t adjustment = x_max ^ 255;
  const uint32_t* t = (const uint32_t*) context->t + adjustment;
  context->lut_norm_ukernel(n, x, t, y);
}

void xnn_compute_floating_point_softmax(
    const struct floating_point_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  // First pass: reduce-max
  union {
    float as_float;
    uint16_t as_half;
  } x_max;
  context->rmax_ukernel(n, x, &x_max);

  // Second pass: reduce-add & store exp(x-x_max)
  union {
    float as_float;
    uint16_t as_half;
  } y_sum;
  context->raddstoreexpminusmax_ukernel(n, x, &x_max, y, &y_sum, &context->expminus_params);

  // Third pass: scale y
  union {
    float as_float;
    uint16_t as_half;
  } y_scale;
  context->compute_reciprocal(&y_sum, &y_scale);
  context->vmulc_ukernel(n, y, &y_scale, y, &context->minmax_params);
}

void xnn_compute_vmulcaddc(
    const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_size)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;

  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(
    batch_size,
    context->n,
    x, x_stride,
    context->w,
    y, y_stride,
    &context->params);
}

1177 #if XNN_MAX_UARCH_TYPES > 1
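  // The xnn_compute_hmp_* variants below are compiled only for
  // heterogeneous multi-processor (HMP) builds: each receives the calling
  // thread's microarchitecture index from pthreadpool and dispatches to
  // the microkernel specialized for that core type.

  // GEMM tile for one group of a grouped matrix multiplication,
  // dispatched by microarchitecture.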
  void xnn_compute_hmp_grouped_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t k_scaled  = context->k_scaled;
    const size_t a_stride  = context->a_stride;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        k_scaled,
        (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
        a_stride,
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
        (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
        cm_stride,
        context->cn_stride,
        &context->params);
  }

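  // GEMM tile without grouping, dispatched by microarchitecture; note that
  // it forwards context->fused_params rather than &context->params.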
  void xnn_compute_hmp_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t a_stride  = context->a_stride;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->k_scaled,
        (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
        a_stride,
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
        (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->fused_params);
  }

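  // Indirect GEMM (IGEMM) tile for one (batch, group) pair: the A matrix
  // is reached through an indirection pointer buffer plus a per-batch and
  // per-group byte offset, dispatched by microarchitecture.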
  void xnn_compute_hmp_grouped_batch_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t batch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks        = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
        (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
        context->zero,
        &context->params);
  }

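  // IGEMM tile for one group (no batch dimension), dispatched by
  // microarchitecture.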
  void xnn_compute_hmp_grouped_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks        = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
        (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset + group_index * context->ga_stride,
        context->zero,
        &context->params);
  }

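  // IGEMM tile for one batch element (no grouping), dispatched by
  // microarchitecture.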
  void xnn_compute_batch_hmp_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t batch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks        = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
        (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset + batch_index * context->ba_stride,
        context->zero,
        &context->params);
  }

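  // IGEMM tile with neither batch nor group dimension, dispatched by
  // microarchitecture.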
  void xnn_compute_hmp_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      uint32_t uarch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size)
  {
    const size_t ks        = context->ks;
    const size_t cm_stride = context->cm_stride;

    context->ukernel.function[uarch_index](
        mr_block_size,
        nr_block_size,
        context->kc,
        context->ks_scaled,
        (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
        (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
        cm_stride,
        context->cn_stride,
        context->a_offset,
        context->zero,
        &context->params);
  }
#endif  // XNN_MAX_UARCH_TYPES > 1

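// Runs an operator previously configured by the corresponding
// xnn_setup_*() call, dispatching its compute function over `threadpool`
// according to the parallelization type recorded at setup time.
//
// A minimal usage sketch (assuming an operator `op` already created and
// set up via the matching xnn_create_*/xnn_setup_* calls; error handling
// elided):
//
//   pthreadpool_t threadpool = pthreadpool_create(0 /* one thread per core */);
//   enum xnn_status status = xnn_run_operator(op, threadpool);
//   assert(status == xnn_status_success);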
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to run operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }
  switch (op->state) {
    case xnn_run_state_invalid:
      xnn_log_error("failed to run operator: operator was not successfully setup");
      return xnn_status_invalid_state;
    case xnn_run_state_ready:
      break;
    case xnn_run_state_skip:
      return xnn_status_success;
  }

  // Run worker threads with denormalized floating-point numbers disabled.
  uint32_t flags = PTHREADPOOL_FLAG_DISABLE_DENORMALS;
  if (op->flags & XNN_FLAG_YIELD_WORKERS) {
    // If requested by the operator, let worker threads yield rather than spin-wait.
    flags |= PTHREADPOOL_FLAG_YIELD_WORKERS;
  }
  switch (op->compute.type) {
    case xnn_parallelization_type_invalid:
      break;
    case xnn_parallelization_type_1d:
      assert(op->compute.range[0] != 0);
      pthreadpool_parallelize_1d(
          threadpool,
          op->compute.task_1d,
          &op->context,
          op->compute.range[0],
          flags);
      break;
    case xnn_parallelization_type_1d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_1d_tile_1d(
          threadpool,
          op->compute.task_1d_tile_1d,
          &op->context,
          op->compute.range[0],
          op->compute.tile[0],
          flags);
      break;
    case xnn_parallelization_type_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      pthreadpool_parallelize_2d(
          threadpool,
          op->compute.task_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          flags);
      break;
    case xnn_parallelization_type_2d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_2d_tile_1d(
          threadpool,
          op->compute.task_2d_tile_1d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0],
          flags);
      break;
    case xnn_parallelization_type_2d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d(
          threadpool,
          op->compute.task_2d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_3d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      pthreadpool_parallelize_3d(
          threadpool,
          op->compute.task_3d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          flags);
      break;
    case xnn_parallelization_type_3d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d(
          threadpool,
          op->compute.task_3d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_4d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      pthreadpool_parallelize_4d(
          threadpool,
          op->compute.task_4d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          flags);
      break;
    case xnn_parallelization_type_4d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d(
          threadpool,
          op->compute.task_4d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_5d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      pthreadpool_parallelize_5d(
          threadpool,
          op->compute.task_5d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          flags);
      break;
    case xnn_parallelization_type_5d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_5d_tile_2d(
          threadpool,
          op->compute.task_5d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_6d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.range[5] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_6d_tile_2d(
          threadpool,
          op->compute.task_6d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
#if XNN_MAX_UARCH_TYPES > 1
    case xnn_parallelization_type_2d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_2d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_3d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_3d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_4d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_4d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
#endif  // XNN_MAX_UARCH_TYPES > 1
    default:
      XNN_UNREACHABLE;
  }
  return xnn_status_success;
}