// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/compute.h>

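// Constant-element-size transpose dispatch. Each xnn_compute_transposec_Nd
// function resolves the byte offsets of one tile from the per-dimension
// strides in transpose_context and hands the innermost 2D tile to the
// constant-size transpose microkernel.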
void xnn_compute_transposec_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j)
{
  const size_t log2_element_size = context->log2_element_size;

  context->const_size_ukernel(
      (const void*) ((uintptr_t) context->x + (i << log2_element_size) + j * context->input_stride[1]),
      (void*) ((uintptr_t) context->y + (j << log2_element_size) + i * context->output_stride[0]),
      context->input_stride[1],
      context->output_stride[0],
      tile_i,
      tile_j);
}

void xnn_compute_transposec_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k)
{
  const size_t log2_element_size = context->log2_element_size;
  const size_t ld_input = context->input_stride[2];
  const size_t ld_output = context->output_stride[1];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
    k * ld_input);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
    (k << log2_element_size));

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_j,
      tile_k);
}

void xnn_compute_transposec_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l)
{
  const size_t log2_element_size = context->log2_element_size;
  const size_t ld_input = context->input_stride[3];
  const size_t ld_output = context->output_stride[2];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
    k * context->input_stride[2] + l * ld_input);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
    k * context->output_stride[2] + (l << log2_element_size));

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_k,
      tile_l);
}

void xnn_compute_transposec_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m)
{
  const size_t log2_element_size = context->log2_element_size;
  const size_t ld_input = context->input_stride[4];
  const size_t ld_output = context->output_stride[3];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
    k * context->input_stride[2] + l * context->input_stride[3] + m * ld_input);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
    k * context->output_stride[2] + l * context->output_stride[3] + (m << log2_element_size));

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_l,
      tile_m);
}

void xnn_compute_transposec_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n)
{
  const size_t log2_element_size = context->log2_element_size;
  const size_t ld_input = context->input_stride[5];
  const size_t ld_output = context->output_stride[4];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
    k * context->input_stride[2] + l * context->input_stride[3] +
    m * context->input_stride[4] + n * ld_input);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
    k * context->output_stride[2] + l * context->output_stride[3] + m * context->output_stride[4] +
    (n << log2_element_size));

  context->const_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      tile_m,
      tile_n);
}

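// Variable-element-size transpose dispatch: same tiling as the
// xnn_compute_transposec_Nd functions above, but the element size and the
// outer-dimension strides are passed to the microkernel at run time rather
// than folded into the tile addressing.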
void xnn_compute_transposev_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j)
{
  const size_t element_size = context->element_size;
  const size_t ld_input = context->input_stride[1];
  const size_t ld_output = context->output_stride[0];
  const void* x = (const void*) ((uintptr_t) context->x +
    i * context->input_stride[0] + j * ld_input);
  void* y = (void*) ((uintptr_t) context->y + context->output_stride[1] * j + i * context->output_stride[0]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[0],
      context->output_stride[1],
      element_size,
      tile_i,
      tile_j);
}

void xnn_compute_transposev_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k)
{
  const size_t element_size = context->element_size;
  const size_t ld_input = context->input_stride[2];
  const size_t ld_output = context->output_stride[1];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
    k * ld_input);
  void* y = (void*) ((uintptr_t) context->y + i * context->output_stride[0] + j * context->output_stride[1] +
    k * context->output_stride[2]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[1],
      context->output_stride[2],
      element_size,
      tile_j,
      tile_k);
}

void xnn_compute_transposev_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l)
{
  const size_t element_size = context->element_size;
  const size_t ld_input = context->input_stride[3];
  const size_t ld_output = context->output_stride[2];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
    k * context->input_stride[2] + l * ld_input);
  void* y = (void*) ((uintptr_t) context->y + context->output_stride[3] * l + i * context->output_stride[0] +
    j * context->output_stride[1] + k * context->output_stride[2]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[2],
      context->output_stride[3],
      element_size,
      tile_k,
      tile_l);
}

void xnn_compute_transposev_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m)
{
  const size_t element_size = context->element_size;
  const size_t ld_input = context->input_stride[4];
  const size_t ld_output = context->output_stride[3];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
    k * context->input_stride[2] + l * context->input_stride[3] + m * ld_input);
  void* y = (void*) ((uintptr_t) context->y + context->output_stride[4] * m + i * context->output_stride[0] +
    j * context->output_stride[1] + k * context->output_stride[2] + l * context->output_stride[3]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[3],
      context->output_stride[4],
      element_size,
      tile_l,
      tile_m);
}

void xnn_compute_transposev_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n)
{
  const size_t element_size = context->element_size;
  const size_t ld_input = context->input_stride[5];
  const size_t ld_output = context->output_stride[4];
  const void* x = (const void*) ((uintptr_t) context->x + i * context->input_stride[0] + j * context->input_stride[1] +
    k * context->input_stride[2] + l * context->input_stride[3] +
    m * context->input_stride[4] + n * ld_input);
  void* y = (void*) ((uintptr_t) context->y + context->output_stride[5] * n + i * context->output_stride[0] +
    j * context->output_stride[1] + k * context->output_stride[2] + l * context->output_stride[3] +
    m * context->output_stride[4]);

  context->variable_size_ukernel(
      x,
      y,
      ld_input,
      ld_output,
      context->input_stride[4],
      context->output_stride[5],
      element_size,
      tile_m,
      tile_n);
}

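// GEMM dispatch: each call computes one mr_block_size x nr_block_size block
// of the output matrix C, offsetting the A, packed-weight, and C pointers by
// the block (and, for the grouped variant, group) indices before invoking the
// microkernel for the default microarchitecture.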
void xnn_compute_grouped_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->fused_params);
}

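// Sparse-matrix/dense-matrix multiplication (SpMM) over one block of output
// rows of one batch element. The sparse weights are described by their
// non-zero values, input pointer increments, and per-output-channel non-zero
// counts.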
void xnn_compute_spmm(
    const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{
  context->ukernel(
      mr_block_size,
      context->n,
      (const void*) ((uintptr_t) context->input + batch_index * context->batched_input_stride + mr_block_start),
      context->nonzero_weights,
      context->input_increments,
      context->output_channel_nonzeros,
      (void*) ((uintptr_t) context->output + batch_index * context->batched_output_stride + mr_block_start),
      context->scaled_m,
      &context->params);
}

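// Indirect GEMM (IGEMM) dispatch: A is addressed through an indirection
// buffer of row pointers, so convolution runs without materializing an
// im2col buffer. The grouped/batch variants additionally offset the weights,
// output, and input offset by the group and batch indices.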
void xnn_compute_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      &context->params);
}

void xnn_compute_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset,
      context->zero,
      &context->params);
}

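// Subconvolution GEMM path (used for deconvolution): each subkernel owns one
// slice of the output. Out-of-range slice coordinates return early, and the
// remaining slice_x_size x nc_block_size tile is handled by the GEMM
// microkernel.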
void xnn_compute_grouped_subgemm2d(
    const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + group_index * context->ga_stride + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_subgemm2d(
    const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t ax_stride = context->ax_stride;
  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      (const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
      ax_stride,
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      &context->params);
}

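// Subconvolution IGEMM path: like the subgemm2d functions above, but the
// input is addressed through the subkernel's indirection buffer instead of
// direct strides.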
void xnn_compute_grouped_subconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t group_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

void xnn_compute_subconv2d(
    const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t subkernel_index,
    size_t slice_y,
    size_t slice_x_start,
    size_t nc_block_start,
    size_t slice_x_max,
    size_t nc_block_size)
{
  const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

  if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
    return;
  }

  const size_t slice_width = subconvolution_params->slice_width;
  if XNN_UNLIKELY(slice_x_start >= slice_width) {
    return;
  }
  const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

  const size_t cx_stride = context->cx_stride;
  context->ukernel.function[XNN_UARCH_DEFAULT](
      slice_x_size,
      nc_block_size,
      context->kc,
      subconvolution_params->scaled_kernel_size,
      (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
      (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
      (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
      cx_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

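// Direct 2D convolution reading HWC input and writing CHW output, for output
// rows [output_y_start, output_y_start + output_y_slice) of one batch
// element.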
void xnn_compute_conv2d_hwc2chw(
    const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y_start,
    size_t output_y_slice)
{
  context->hwc2chw_ukernel(
      context->input_height,
      context->input_width,
      output_y_start,
      output_y_start + output_y_slice,
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
      context->zero,
      context->packed_weights,
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
      context->input_padding_top,
      context->output_channels,
      context->output_height_stride,
      context->output_channel_stride,
      &context->params);
}

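// Depthwise convolution over one output row: the indirection buffer supplies
// the input rows, and the unipass microkernel covers the entire kernel window
// in a single pass.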
void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->groups, context->output_width,
      indirect_input, context->packed_weights, output,
      context->indirect_input_width_stride, context->output_increment,
      input_offset, context->zero,
      &context->params);
}

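// Depthwise convolution in CHW layout: processes one channel of one batch
// element across the whole 2D image.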
void xnn_compute_dwconv2d_chw(
    const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel)
{
  context->chw_ukernel(
      context->input_height,
      context->input_width,
      (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
      (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
      context->zero,
      (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
      context->input_padding_top,
      &context->params);
}

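// Argmax pooling: writes both the pooled values and the indices of the
// maxima. The multipass variant stages accumulation and index buffers on the
// stack for pooling windows too large for a single pass.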
void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, output, index,
      context->input_increment, context->output_increment);
}

void xnn_compute_argmax_pooling_multipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);
  uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
    batch_index * context->index_batch_stride + output_y * context->index_height_stride);

  void* multipass_accumulation_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(float) + XNN_EXTRA_BYTES);
  void* multipass_index_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(uint32_t) + XNN_EXTRA_BYTES);

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
      context->input_increment, context->output_increment);
}

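// Max pooling over one output row of one batch element.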
void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
    output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, output,
      context->input_increment, context->output_increment,
      &context->params);
}

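// Max unpooling: scatters each input pixel to the output position recorded in
// the index buffer; the microkernel receives fill_value for the output
// positions not selected by a maximum.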
void xnn_compute_unpooling(
    const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t input_y,
    size_t input_x)
{
  const void* input = (const void*) ((uintptr_t) context->input +
    input_y * context->input_height_stride + input_x * context->input_width_stride);
  const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
    input_y * context->index_height_stride + input_x * context->index_width_stride);
  void** indirect_output =
    (void**) ((uintptr_t) context->indirect_output +
      input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);

  context->ukernel(
      context->pooling_size,
      context->channels,
      context->fill_value,
      input, index, indirect_output);
}

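// Average pooling: the unipass variant reduces the whole pooling window at
// once; the multipass variant accumulates through a stack buffer sized for
// int32 accumulators plus the XNN_EXTRA_BYTES microkernel overread.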
void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, context->zero, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_average_pooling_multipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  void* multipass_buffer =
    XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, context->zero, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

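// Pixelwise average pooling: like average pooling above, but each output
// pixel additionally reads a precomputed per-pixel value from
// pixelwise_buffer, which accounts for pooling windows covering different
// numbers of valid input pixels.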
void xnn_compute_pixelwise_average_pooling_unipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, context->zero, pixelwise_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

void xnn_compute_pixelwise_average_pooling_multipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const void** indirect_input =
    (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
  const void* pixelwise_buffer =
    (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
  void* output = (void*) ((uintptr_t) context->output +
    batch_index * context->output_batch_stride + output_y * context->output_height_stride);

  void* multipass_buffer =
    XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, context->zero, pixelwise_buffer, multipass_buffer, output,
      context->input_increment, context->output_increment,
      &context->params);
}

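// Global average pooling in NWC layout: reduces the full spatial extent of
// one batch element; the multipass variant accumulates through a stack
// buffer.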
void xnn_compute_global_average_pooling_nwc_unipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  context->unipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      output,
      &context->params);
}

void xnn_compute_global_average_pooling_nwc_multipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* input =
    (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* output =
    (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);

  void* multipass_buffer =
    XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));

  context->multipass_ukernel(
      context->input_elements,
      context->channels,
      input,
      context->input_pixel_stride,
      context->zero,
      multipass_buffer,
      output,
      &context->params);
}

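// Global average pooling in NCW layout: reduces a contiguous slice of
// channels of one batch element.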
void xnn_compute_global_average_pooling_ncw(
    const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channels_start,
    size_t channels_slice)
{
  const void* input = (const void*) ((uintptr_t) context->input +
    channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
  void* output = (void*) ((uintptr_t) context->output +
    channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      context->input_elements,
      channels_slice,
      input,
      output,
      &context->params);
}

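// Bilinear resize: every output pixel blends four input pixels addressed
// through the indirection buffer using packed interpolation weights; the
// _chw variant walks channels instead of pixels.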
void xnn_compute_resize_bilinear(
    const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t pixel_start,
    size_t pixel_range)
{
  void* output =
    (void*) ((uintptr_t) context->output + pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride);

  context->ukernel(
      pixel_range,
      context->scaled_channels,
      context->indirect_input + pixel_start * 4,
      context->input_offset + batch_index * context->input_batch_stride,
      (const void*) ((uintptr_t) context->packed_weights + (pixel_start << context->log2_wsize)),
      output,
      context->output_pixel_stride - context->scaled_channels);
}

void xnn_compute_resize_bilinear_chw(
    const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel_start,
    size_t channel_range)
{
  void* output =
    (void*) ((uintptr_t) context->output + channel_start * context->output_channel_stride + batch_index * context->output_batch_stride);
  const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride + channel_start * context->input_channel_stride;

  context->ukernel(
      context->output_pixels,
      channel_range,
      context->indirect_input,
      input_offset,
      context->packed_weights,
      output,
      context->input_channel_stride);
}

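// PReLU over a range of batch rows, with per-channel negative slopes in w.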
void xnn_compute_prelu(
    const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;
  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride);
}

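// Constant padding of a 5D tensor, one innermost row per call. The unsigned
// comparison (i - i_padding < i_size) relies on wraparound to test
// i_padding <= i < i_padding + i_size in a single compare: rows inside the
// input region are copied with pre-/post-padding, all others are filled with
// the padding value.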
void xnn_compute_pad_5d(
    const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  const void* input = (const void*) ((uintptr_t) context->input +
    i * context->input_stride[4] + j * context->input_stride[3] + k * context->input_stride[2] + l * context->input_stride[1] + m * context->input_stride[0]);
  void* output = (void*) ((uintptr_t) context->output +
    i * context->output_stride[4] + j * context->output_stride[3] + k * context->output_stride[2] + l * context->output_stride[1] + m * context->output_stride[0]);

  const size_t i_padding = context->pre_paddings[5];
  const size_t j_padding = context->pre_paddings[4];
  const size_t k_padding = context->pre_paddings[3];
  const size_t l_padding = context->pre_paddings[2];
  const size_t m_padding = context->pre_paddings[1];

  const size_t i_size = context->input_size[5];
  const size_t j_size = context->input_size[4];
  const size_t k_size = context->input_size[3];
  const size_t l_size = context->input_size[2];
  const size_t m_size = context->input_size[1];

  if XNN_LIKELY(i - i_padding < i_size && j - j_padding < j_size && k - k_padding < k_size &&
      l - l_padding < l_size && m - m_padding < m_size)
  {
    context->pad_ukernel(
        1 /* rows */,
        context->input_size[0], context->pre_paddings[0], context->post_paddings[0],
        input, 0 /* input stride */, output, 0 /* output stride */,
        context->padding_value);
  } else {
    context->fill_ukernel(1 /* rows */, context->output_size[0], output, 0 /* output stride */, context->padding_value);
  }
}

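// Elementwise binary operations with broadcasting: the _Nd variants index the
// trailing N dimensions of a rank-5 stride layout, and broadcast dimensions
// carry zero strides.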
void xnn_compute_elementwise_binary_1d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i)
{
  const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_2d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j)
{
  const void* a = (const void*) ((uintptr_t) context->a + i * context->a_stride[3] + j * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b + i * context->b_stride[3] + j * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y + i * context->y_stride[3] + j * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_3d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k)
{
  const void* a = (const void*) ((uintptr_t) context->a +
    i * context->a_stride[2] + j * context->a_stride[3] + k * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
    i * context->b_stride[2] + j * context->b_stride[3] + k * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
    i * context->y_stride[2] + j * context->y_stride[3] + k * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_4d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l)
{
  const void* a = (const void*) ((uintptr_t) context->a +
    i * context->a_stride[1] + j * context->a_stride[2] + k * context->a_stride[3] + l * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
    i * context->b_stride[1] + j * context->b_stride[2] + k * context->b_stride[3] + l * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
    i * context->y_stride[1] + j * context->y_stride[2] + k * context->y_stride[3] + l * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

void xnn_compute_elementwise_binary_5d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  const void* a = (const void*) ((uintptr_t) context->a +
    i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] + l * context->a_stride[3] + m * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
    i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] + l * context->b_stride[3] + m * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
    i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] + l * context->y_stride[3] + m * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
}

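// Channel shuffle: the fixed variant uses a microkernel specialized for the
// group count, while the variable variant additionally passes the runtime
// parameter m.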
void xnn_compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->fixed_ukernel(context->n, x, y);
}

void xnn_compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
  void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);

  context->variable_ukernel(context->n, context->m, x, y);
}

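// Table lookup (LUT) kernels, over strided batch rows or one contiguous byte
// range.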
void xnn_compute_lut_strided(
    const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);

  context->ukernel(context->n, x, y, context->t);
}

void xnn_compute_lut_contiguous(
    const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + offset);

  context->ukernel(size, x, y, context->t);
}

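// Unary elementwise (univector) kernels, over strided batch rows or one
// contiguous range; the contiguous variant converts the input byte offset to
// the output offset via the log2 element-size fields.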
void xnn_compute_univector_strided(
    const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;

  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index);
  do {
    context->ukernel(context->n, x, y, &context->params);
    x = (const void*) ((uintptr_t) x + x_stride);
    y = (void*) ((uintptr_t) y + y_stride);
  } while (--batch_range != 0);
}

void xnn_compute_univector_contiguous(
    const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{
  const uint32_t log2_xsize = context->log2_xsize;
  const uint32_t log2_ysize = context->log2_ysize;
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y + ((offset >> log2_xsize) << log2_ysize));
  context->ukernel(size, x, y, &context->params);
}

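// Softmax on uint8 inputs: find the row maximum, then apply a shifted lookup
// table with normalization. Offsetting the table base by 255 - x_max
// (x_max ^ 255) makes every lookup index x + (255 - x_max) a function of
// x - x_max only.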
void xnn_compute_u8_softmax(
    const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
  uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  uint8_t x_max = 0;
  context->rmax_ukernel(n, x, &x_max);
  const size_t adjustment = x_max ^ 255;
  const uint32_t* t = (const uint32_t*) context->t + adjustment;
  context->lut_norm_ukernel(n, x, t, y);
}

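// Three-pass softmax for f32/f16 rows: reduce-max, fused exp(x - x_max) with
// reduce-add, then scale by the reciprocal of the sum. The unions carry the
// scalar in either floating-point representation.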
void xnn_compute_floating_point_softmax(
    const struct floating_point_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
  const size_t n = context->n;

  // First pass: reduce-max
  union {
    float as_float;
    uint16_t as_half;
  } x_max;
  context->rmax_ukernel(n, x, &x_max);

  // Second pass: reduce-add & store exp(x-x_max)
  union {
    float as_float;
    uint16_t as_half;
  } y_sum;
  context->raddstoreexpminusmax_ukernel(n, x, &x_max, y, &y_sum, &context->expminus_params);

  // Third pass: scale y
  union {
    float as_float;
    uint16_t as_half;
  } y_scale;
  context->compute_reciprocal(&y_sum, &y_scale);
  context->vmulc_ukernel(n, y, &y_scale, y, &context->minmax_params);
}

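// Per-channel multiply-add (vmulcaddc): scales each channel of x and adds a
// per-channel bias, with both constants packed in w, over a range of rows.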
void xnn_compute_vmulcaddc(
    const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_size)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;

  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);

  context->ukernel(
      batch_size,
      context->n,
      x, x_stride,
      context->w,
      y, y_stride,
      &context->params);
}

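// Heterogeneous multi-processor (HMP) variants: identical to the dispatch
// functions above, except the microkernel is selected by the caller-supplied
// uarch_index instead of XNN_UARCH_DEFAULT.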
#if XNN_MAX_UARCH_TYPES > 1
void xnn_compute_hmp_grouped_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
      cm_stride,
      context->cn_stride,
      &context->params);
}

void xnn_compute_hmp_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
      a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->fused_params);
}

void xnn_compute_hmp_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

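// Same as xnn_compute_hmp_grouped_batch_igemm for a single batch element:
// the batch_index terms (bc_stride / ba_stride) drop out of the C pointer
// and a_offset computations.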
void xnn_compute_hmp_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
      (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      &context->params);
}

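// Batched IGEMM without groups: the group_index terms (gw_stride / gc_stride /
// ga_stride) drop out, leaving only the per-batch offsets.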
void xnn_compute_batch_hmp_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}

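// IGEMM with neither groups nor batching: only the tile offsets
// (mr_block_start, nr_block_start) contribute to the operand addresses.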
void xnn_compute_hmp_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t ks = context->ks;
  const size_t cm_stride = context->cm_stride;

  context->ukernel.function[uarch_index](
      mr_block_size,
      nr_block_size,
      context->kc,
      context->ks_scaled,
      (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
      cm_stride,
      context->cn_stride,
      context->a_offset,
      context->zero,
      &context->params);
}
#endif // XNN_MAX_UARCH_TYPES > 1

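// Runs an operator that was previously created and set up. Setup records a
// parallelization type, an iteration range, and tile sizes in op->compute;
// this function validates the operator state and maps that record onto the
// matching pthreadpool_parallelize_* entry point.
//
// Typical call sequence (a sketch; the concrete create/setup functions
// depend on the operator type, e.g. xnn_create_convolution2d_nhwc_f32):
//   xnn_initialize(NULL /* allocator */);
//   xnn_create_<operator>(..., &op);
//   xnn_setup_<operator>(op, ...);   // fills op->compute and op->context
//   xnn_run_operator(op, threadpool);
//   xnn_delete_operator(op);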
enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to run operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }
  switch (op->state) {
    case xnn_run_state_invalid:
      xnn_log_error("failed to run operator: operator was not successfully set up");
      return xnn_status_invalid_state;
    case xnn_run_state_ready:
      break;
    case xnn_run_state_skip:
      return xnn_status_success;
  }

  uint32_t flags = PTHREADPOOL_FLAG_DISABLE_DENORMALS;
  if (op->flags & XNN_FLAG_YIELD_WORKERS) {
    flags |= PTHREADPOOL_FLAG_YIELD_WORKERS;
  }
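  // Dispatch on the parallelization geometry chosen at setup time. For an
  // N-dimensional type, compute.range holds the N loop extents; the
  // *_tile_* types additionally split the innermost one or two dimensions
  // into tiles of compute.tile[0] (and compute.tile[1]) iterations per task
  // invocation.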
  switch (op->compute.type) {
    case xnn_parallelization_type_invalid:
      break;
    case xnn_parallelization_type_1d:
      assert(op->compute.range[0] != 0);
      pthreadpool_parallelize_1d(
          threadpool,
          op->compute.task_1d,
          &op->context,
          op->compute.range[0],
          flags);
      break;
    case xnn_parallelization_type_1d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_1d_tile_1d(
          threadpool,
          op->compute.task_1d_tile_1d,
          &op->context,
          op->compute.range[0],
          op->compute.tile[0],
          flags);
      break;
    case xnn_parallelization_type_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      pthreadpool_parallelize_2d(
          threadpool,
          op->compute.task_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          flags);
      break;
    case xnn_parallelization_type_2d_tile_1d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      pthreadpool_parallelize_2d_tile_1d(
          threadpool,
          op->compute.task_2d_tile_1d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0],
          flags);
      break;
    case xnn_parallelization_type_2d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d(
          threadpool,
          op->compute.task_2d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_3d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      pthreadpool_parallelize_3d(
          threadpool,
          op->compute.task_3d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          flags);
      break;
    case xnn_parallelization_type_3d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d(
          threadpool,
          op->compute.task_3d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_4d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      pthreadpool_parallelize_4d(
          threadpool,
          op->compute.task_4d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          flags);
      break;
    case xnn_parallelization_type_4d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d(
          threadpool,
          op->compute.task_4d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_5d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      pthreadpool_parallelize_5d(
          threadpool,
          op->compute.task_5d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          flags);
      break;
    case xnn_parallelization_type_5d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_5d_tile_2d(
          threadpool,
          op->compute.task_5d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_6d_tile_2d:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.range[4] != 0);
      assert(op->compute.range[5] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_6d_tile_2d(
          threadpool,
          op->compute.task_6d_tile_2d,
          &op->context,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
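    // The *_with_uarch cases are compiled only when more than one
    // microarchitecture is supported. pthreadpool determines a per-thread
    // microarchitecture index in [0, XNN_MAX_UARCH_TYPES - 1] (falling back
    // to the default index, 0 here, when per-thread detection is
    // unavailable) and passes it to the task, which forwards it to the
    // xnn_compute_hmp_* functions above to pick the matching microkernel.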
#if XNN_MAX_UARCH_TYPES > 1
    case xnn_parallelization_type_2d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_2d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_2d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_3d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_3d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_3d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
    case xnn_parallelization_type_4d_tile_2d_with_uarch:
      assert(op->compute.range[0] != 0);
      assert(op->compute.range[1] != 0);
      assert(op->compute.range[2] != 0);
      assert(op->compute.range[3] != 0);
      assert(op->compute.tile[0] != 0);
      assert(op->compute.tile[1] != 0);
      pthreadpool_parallelize_4d_tile_2d_with_uarch(
          threadpool,
          op->compute.task_4d_tile_2d_with_id,
          &op->context,
          0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
          op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
          op->compute.tile[0], op->compute.tile[1],
          flags);
      break;
#endif // XNN_MAX_UARCH_TYPES > 1
    default:
      XNN_UNREACHABLE;
  }
  return xnn_status_success;
}