1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #include <assert.h>
10 #include <stddef.h>
11 #include <stdint.h>
12 #include <string.h>
13
14 #include <xnnpack.h>
15 #include <xnnpack/allocator.h>
16 #include <xnnpack/operator.h>
17 #include <xnnpack/log.h>
18 #include <xnnpack/common.h>
19 #include <xnnpack/math.h>
20 #include <xnnpack/params.h>
21 #include <xnnpack/compute.h>
22
23
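// Computes one mr_block_size x nr_block_size tile of a grouped GEMM: offsets A, the packed
// weights, and C by the group's strides, then calls the default-microarchitecture GEMM microkernel.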
24 void xnn_compute_grouped_gemm(
25 const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
26 size_t group_index,
27 size_t mr_block_start,
28 size_t nr_block_start,
29 size_t mr_block_size,
30 size_t nr_block_size)
31 {
32 const size_t k_scaled = context->k_scaled;
33 const size_t a_stride = context->a_stride;
34 const size_t cm_stride = context->cm_stride;
35
36 context->ukernel.function[XNN_UARCH_DEFAULT](
37 mr_block_size,
38 nr_block_size,
39 k_scaled,
40 (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
41 a_stride,
42 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
43 (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
44 cm_stride,
45 context->cn_stride,
46 &context->params);
47 }
48
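// Computes one mr x nr output tile of a plain (single-group) GEMM with the default-uarch microkernel.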
49 void xnn_compute_gemm(
50 const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
51 size_t mr_block_start,
52 size_t nr_block_start,
53 size_t mr_block_size,
54 size_t nr_block_size)
55 {
56 const size_t a_stride = context->a_stride;
57 const size_t cm_stride = context->cm_stride;
58
59 context->ukernel.function[XNN_UARCH_DEFAULT](
60 mr_block_size,
61 nr_block_size,
62 context->k_scaled,
63 (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
64 a_stride,
65 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
66 (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
67 cm_stride,
68 context->cn_stride,
69 &context->params);
70 }
71
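// Sparse matrix-dense matrix multiplication: processes a block of mr_block_size elements of the
// M dimension for one batch element, using the compressed non-zero weights and per-output-channel
// non-zero counts.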
72 void xnn_compute_spmm(
73 const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
74 size_t batch_index,
75 size_t mr_block_start,
76 size_t mr_block_size)
77 {
78 context->ukernel(
79 mr_block_size,
80 context->n,
81 (const void*) ((uintptr_t) context->input + batch_index * context->batched_input_stride + mr_block_start),
82 context->nonzero_weights,
83 context->input_increments,
84 context->output_channel_nonzeros,
85 (void*) ((uintptr_t) context->output + batch_index * context->batched_output_stride + mr_block_start),
86 context->scaled_m,
87 &context->params);
88 }
89
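// Indirect GEMM (convolution expressed through an indirection buffer) for one (batch, group) pair;
// computes one mr x nr tile of the output. The variants below drop the batch and/or group dimensions.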
90 void xnn_compute_grouped_batch_igemm(
91 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
92 size_t batch_index,
93 size_t group_index,
94 size_t mr_block_start,
95 size_t nr_block_start,
96 size_t mr_block_size,
97 size_t nr_block_size)
98 {
99 const size_t ks = context->ks;
100 const size_t cm_stride = context->cm_stride;
101
102 context->ukernel.function[XNN_UARCH_DEFAULT](
103 mr_block_size,
104 nr_block_size,
105 context->kc,
106 context->ks_scaled,
107 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
108 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
109 (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
110 cm_stride,
111 context->cn_stride,
112 context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
113 context->zero,
114 &context->params);
115 }
116
117 void xnn_compute_grouped_igemm(
118 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
119 size_t group_index,
120 size_t mr_block_start,
121 size_t nr_block_start,
122 size_t mr_block_size,
123 size_t nr_block_size)
124 {
125 const size_t ks = context->ks;
126 const size_t cm_stride = context->cm_stride;
127
128 context->ukernel.function[XNN_UARCH_DEFAULT](
129 mr_block_size,
130 nr_block_size,
131 context->kc,
132 context->ks_scaled,
133 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
134 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
135 (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
136 cm_stride,
137 context->cn_stride,
138 context->a_offset + group_index * context->ga_stride,
139 context->zero,
140 &context->params);
141 }
142
143 void xnn_compute_batch_igemm(
144 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
145 size_t batch_index,
146 size_t mr_block_start,
147 size_t nr_block_start,
148 size_t mr_block_size,
149 size_t nr_block_size)
150 {
151 const size_t ks = context->ks;
152 const size_t cm_stride = context->cm_stride;
153
154 context->ukernel.function[XNN_UARCH_DEFAULT](
155 mr_block_size,
156 nr_block_size,
157 context->kc,
158 context->ks_scaled,
159 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
160 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
161 (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
162 cm_stride,
163 context->cn_stride,
164 context->a_offset + batch_index * context->ba_stride,
165 context->zero,
166 &context->params);
167 }
168
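// Indirect GEMM without batch or group dimensions: one mr x nr output tile.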
169 void xnn_compute_igemm(
170 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
171 size_t mr_block_start,
172 size_t nr_block_start,
173 size_t mr_block_size,
174 size_t nr_block_size)
175 {
176 const size_t ks = context->ks;
177 const size_t cm_stride = context->cm_stride;
178
179 context->ukernel.function[XNN_UARCH_DEFAULT](
180 mr_block_size,
181 nr_block_size,
182 context->kc,
183 context->ks_scaled,
184 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
185 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
186 (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
187 cm_stride,
188 context->cn_stride,
189 context->a_offset,
190 context->zero,
191 &context->params);
192 }
193
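// Sub-GEMM over subconvolution slices (used for deconvolution-style operators): handles one output
// slice of one subkernel, clamping the x-range to the slice width and returning early for slices
// that fall outside the output.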
194 void xnn_compute_grouped_subgemm2d(
195 const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
196 size_t batch_index,
197 size_t group_index,
198 size_t subkernel_index,
199 size_t slice_y,
200 size_t slice_x_start,
201 size_t nc_block_start,
202 size_t slice_x_max,
203 size_t nc_block_size)
204 {
205 const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
206
207 if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
208 return;
209 }
210
211 const size_t slice_width = subconvolution_params->slice_width;
212 if XNN_UNLIKELY(slice_x_start >= slice_width) {
213 return;
214 }
215 const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
216
217 const size_t ax_stride = context->ax_stride;
218 const size_t cx_stride = context->cx_stride;
219 context->ukernel.function[XNN_UARCH_DEFAULT](
220 slice_x_size,
221 nc_block_size,
222 context->kc,
223 (const void*) ((uintptr_t) context->a + group_index * context->ga_stride + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
224 ax_stride,
225 (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
226 (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
227 cx_stride,
228 context->cn_stride,
229 &context->params);
230 }
231
232 void xnn_compute_subgemm2d(
233 const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
234 size_t batch_index,
235 size_t subkernel_index,
236 size_t slice_y,
237 size_t slice_x_start,
238 size_t nc_block_start,
239 size_t slice_x_max,
240 size_t nc_block_size)
241 {
242 const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
243
244 if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
245 return;
246 }
247
248 const size_t slice_width = subconvolution_params->slice_width;
249 if XNN_UNLIKELY(slice_x_start >= slice_width) {
250 return;
251 }
252 const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
253
254 const size_t ax_stride = context->ax_stride;
255 const size_t cx_stride = context->cx_stride;
256 context->ukernel.function[XNN_UARCH_DEFAULT](
257 slice_x_size,
258 nc_block_size,
259 context->kc,
260 (const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
261 ax_stride,
262 (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
263 (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
264 cx_stride,
265 context->cn_stride,
266 &context->params);
267 }
268
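// Indirect variant of the sub-GEMM above: reads the input through the subconvolution's indirection
// buffer instead of a dense A matrix.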
269 void xnn_compute_grouped_subconv2d(
270 const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
271 size_t batch_index,
272 size_t group_index,
273 size_t subkernel_index,
274 size_t slice_y,
275 size_t slice_x_start,
276 size_t nc_block_start,
277 size_t slice_x_max,
278 size_t nc_block_size)
279 {
280 const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
281
282 if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
283 return;
284 }
285
286 const size_t slice_width = subconvolution_params->slice_width;
287 if XNN_UNLIKELY(slice_x_start >= slice_width) {
288 return;
289 }
290 const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
291
292 const size_t cx_stride = context->cx_stride;
293 context->ukernel.function[XNN_UARCH_DEFAULT](
294 slice_x_size,
295 nc_block_size,
296 context->kc,
297 subconvolution_params->scaled_kernel_size,
298 (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
299 (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
300 (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
301 cx_stride,
302 context->cn_stride,
303 context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
304 context->zero,
305 &context->params);
306 }
307
308 void xnn_compute_subconv2d(
309 const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
310 size_t batch_index,
311 size_t subkernel_index,
312 size_t slice_y,
313 size_t slice_x_start,
314 size_t nc_block_start,
315 size_t slice_x_max,
316 size_t nc_block_size)
317 {
318 const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
319
320 if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
321 return;
322 }
323
324 const size_t slice_width = subconvolution_params->slice_width;
325 if XNN_UNLIKELY(slice_x_start >= slice_width) {
326 return;
327 }
328 const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
329
330 const size_t cx_stride = context->cx_stride;
331 context->ukernel.function[XNN_UARCH_DEFAULT](
332 slice_x_size,
333 nc_block_size,
334 context->kc,
335 subconvolution_params->scaled_kernel_size,
336 (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
337 (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
338 (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
339 cx_stride,
340 context->cn_stride,
341 context->a_offset + batch_index * context->ba_stride,
342 context->zero,
343 &context->params);
344 }
345
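// Direct 2D convolution that consumes HWC input and produces CHW output for a range of output rows.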
346 void xnn_compute_conv2d_hwc2chw(
347 const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
348 size_t batch_index,
349 size_t output_y_start,
350 size_t output_y_slice)
351 {
352 context->hwc2chw_ukernel(
353 context->input_height,
354 context->input_width,
355 output_y_start,
356 output_y_start + output_y_slice,
357 (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
358 context->zero,
359 context->packed_weights,
360 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
361 context->input_padding_top,
362 context->output_channels,
363 context->output_height_stride,
364 context->output_channel_stride,
365 &context->params);
366 }
367
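// Depthwise convolution (single-pass microkernel) for one output row of one batch element.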
368 void xnn_compute_dwconv_unipass(
369 const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
370 size_t batch_index,
371 size_t output_y)
372 {
373 const void** indirect_input =
374 (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
375 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
376 void* output = (void*) ((uintptr_t) context->output +
377 batch_index * context->output_batch_stride + output_y * context->output_height_stride);
378
379 context->unipass_ukernel(
380 context->groups, context->output_width,
381 indirect_input, context->packed_weights, output,
382 context->indirect_input_width_stride, context->output_increment,
383 input_offset, context->zero,
384 &context->params);
385 }
386
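// Depthwise convolution in CHW layout, processed one (batch, channel) plane at a time.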
387 void xnn_compute_dwconv2d_chw(
388 const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
389 size_t batch_index,
390 size_t channel)
391 {
392 context->chw_ukernel(
393 context->input_height,
394 context->input_width,
395 (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
396 (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
397 context->zero,
398 (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
399 context->input_padding_top,
400 &context->params);
401 }
402
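// Depth-to-space (pixel shuffle) in HWC layout for the case where output rows are contiguous; the
// strided and CHW-to-HWC variants below handle non-contiguous output and CHW input, respectively.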
403 void xnn_compute_depthtospace2d_hwc_contiguous(
404 const struct depthtospace2d_hwc_context* context,
405 size_t batch_input_y,
406 size_t input_x,
407 size_t block_y)
408 {
409 const size_t input_width = context->input_width;
410 const size_t elements = context->elements;
411 const void* input = (const void*) ((uintptr_t) context->input +
412 (batch_input_y * input_width + input_x) * context->input_width_stride + block_y * elements);
413 void* output = (void*) ((uintptr_t) context->output +
414 ((batch_input_y * context->block_size + block_y) * input_width + input_x) * elements);
415
416 context->ukernel(
417 elements,
418 input,
419 output,
420 NULL);
421 }
422
423 void xnn_compute_depthtospace2d_hwc_strided(
424 const struct depthtospace2d_hwc_context* context,
425 size_t batch_input_y,
426 size_t input_x,
427 size_t block_y,
428 size_t block_x)
429 {
430 const size_t block_size = context->block_size;
431 const size_t elements = context->elements;
432 const void* input = (const void*) ((uintptr_t) context->input +
433 batch_input_y * context->input_height_stride + input_x * context->input_width_stride + (block_y * block_size + block_x) * elements);
434 void* output = (void*) ((uintptr_t) context->output +
435 (batch_input_y * block_size + block_y) * context->output_height_stride +
436 (input_x * block_size + block_x) * context->output_width_stride);
437
438 context->ukernel(
439 elements,
440 input,
441 output,
442 NULL);
443 }
444
445 void xnn_compute_depthtospace2d_chw2hwc(
446 const struct depthtospace2d_chw2hwc_context* context,
447 size_t batch_index)
448 {
449 context->ukernel(
450 context->output_channels,
451 context->input_height,
452 context->input_width,
453 context->block_size,
454 (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
455 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
456 context->output_channel_stride);
457 }
458
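// ArgMax pooling over one output row: writes both the pooled values and the per-element argmax indices.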
459 void xnn_compute_argmax_pooling_unipass(
460 const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
461 size_t batch_index,
462 size_t output_y)
463 {
464 const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
465 output_y * context->indirect_input_height_stride);
466 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
467 void* output = (void*) ((uintptr_t) context->output +
468 batch_index * context->output_batch_stride + output_y * context->output_height_stride);
469 uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
470 batch_index * context->index_batch_stride + output_y * context->index_height_stride);
471
472 context->unipass_ukernel(
473 context->output_width, context->pooling_size, context->channels,
474 indirect_input, input_offset, output, index,
475 context->input_increment, context->output_increment);
476 }
477
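// Multi-pass variant of ArgMax pooling: uses stack-allocated accumulation and index buffers for
// pooling windows that do not fit in a single microkernel pass.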
478 void xnn_compute_argmax_pooling_multipass(
479 const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
480 size_t batch_index,
481 size_t output_y)
482 {
483 const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
484 output_y * context->indirect_input_height_stride);
485 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
486 void* output = (void*) ((uintptr_t) context->output +
487 batch_index * context->output_batch_stride + output_y * context->output_height_stride);
488 uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
489 batch_index * context->index_batch_stride + output_y * context->index_height_stride);
490
491 void* multipass_accumulation_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(float) + XNN_EXTRA_BYTES);
492 void* multipass_index_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(uint32_t) + XNN_EXTRA_BYTES);
493
494 context->multipass_ukernel(
495 context->output_width, context->pooling_size, context->channels,
496 indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
497 context->input_increment, context->output_increment);
498 }
499
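// Max pooling over one output row of one batch element.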
500 void xnn_compute_max_pooling(
501 const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
502 size_t batch_index,
503 size_t output_y)
504 {
505 const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
506 output_y * context->indirect_input_height_stride);
507 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
508 void* output = (void*) ((uintptr_t) context->output +
509 batch_index * context->output_batch_stride + output_y * context->output_height_stride);
510
511 context->ukernel(
512 context->output_width, context->pooling_size, context->channels,
513 indirect_input, input_offset, output,
514 context->input_increment, context->output_increment,
515 &context->params);
516 }
517
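// Max unpooling: scatters each input pixel to the output positions recorded in the index buffer
// (via the indirect output pointers), filling the remaining positions with fill_value.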
518 void xnn_compute_unpooling(
519 const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
520 size_t input_y,
521 size_t input_x)
522 {
523 const void* input = (const void*) ((uintptr_t) context->input +
524 input_y * context->input_height_stride + input_x * context->input_width_stride);
525 const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
526 input_y * context->index_height_stride + input_x * context->index_width_stride);
527 void** indirect_output =
528 (void**) ((uintptr_t) context->indirect_output +
529 input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);
530
531 context->ukernel(
532 context->pooling_size,
533 context->channels,
534 context->fill_value,
535 input, index, indirect_output);
536 }
537
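// Average pooling (single-pass microkernel) over one output row of one batch element.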
538 void xnn_compute_average_pooling_unipass(
539 const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
540 size_t batch_index,
541 size_t output_y)
542 {
543 const void** indirect_input =
544 (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
545 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
546 void* output = (void*) ((uintptr_t) context->output +
547 batch_index * context->output_batch_stride + output_y * context->output_height_stride);
548
549 context->unipass_ukernel(
550 context->output_width, context->pooling_size, context->channels,
551 indirect_input, input_offset, context->zero, output,
552 context->input_increment, context->output_increment,
553 &context->params);
554 }
555
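// Multi-pass average pooling: accumulates into a stack-allocated scratch buffer sized for 32-bit
// accumulation plus the extra bytes microkernels may over-read.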
556 void xnn_compute_average_pooling_multipass(
557 const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
558 size_t batch_index,
559 size_t output_y)
560 {
561 const void** indirect_input =
562 (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
563 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
564 void* output = (void*) ((uintptr_t) context->output +
565 batch_index * context->output_batch_stride + output_y * context->output_height_stride);
566
567 void* multipass_buffer =
568 XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));
569
570 context->multipass_ukernel(
571 context->output_width, context->pooling_size, context->channels,
572 indirect_input, input_offset, context->zero, multipass_buffer, output,
573 context->input_increment, context->output_increment,
574 &context->params);
575 }
576
577 void xnn_compute_pixelwise_average_pooling_unipass(
578 const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
579 size_t batch_index,
580 size_t output_y)
581 {
582 const void** indirect_input =
583 (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
584 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
585 const void* pixelwise_buffer =
586 (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
587 void* output = (void*) ((uintptr_t) context->output +
588 batch_index * context->output_batch_stride + output_y * context->output_height_stride);
589
590 context->unipass_ukernel(
591 context->output_width, context->pooling_size, context->channels,
592 indirect_input, input_offset, context->zero, pixelwise_buffer, output,
593 context->input_increment, context->output_increment,
594 &context->params);
595 }
596
597 void xnn_compute_pixelwise_average_pooling_multipass(
598 const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
599 size_t batch_index,
600 size_t output_y)
601 {
602 const void** indirect_input =
603 (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
604 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
605 const void* pixelwise_buffer =
606 (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
607 void* output = (void*) ((uintptr_t) context->output +
608 batch_index * context->output_batch_stride + output_y * context->output_height_stride);
609
610 void* multipass_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));
611
612 context->multipass_ukernel(
613 context->output_width, context->pooling_size, context->channels,
614 indirect_input, input_offset, context->zero, pixelwise_buffer, multipass_buffer, output,
615 context->input_increment, context->output_increment,
616 &context->params);
617 }
618
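// Global average pooling over the spatial (width) dimension in NWC layout, one batch element per call.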
619 void xnn_compute_global_average_pooling_nwc_unipass(
620 const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
621 size_t batch_index)
622 {
623 const void* input =
624 (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
625 void* output =
626 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
627
628 context->unipass_ukernel(
629 context->input_elements,
630 context->channels,
631 input,
632 context->input_pixel_stride,
633 context->zero,
634 output,
635 &context->params);
636 }
637
638 void xnn_compute_global_average_pooling_nwc_multipass(
639 const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
640 size_t batch_index)
641 {
642 const void* input =
643 (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
644 void* output =
645 (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
646
647 void* multipass_buffer =
648 XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));
649
650 context->multipass_ukernel(
651 context->input_elements,
652 context->channels,
653 input,
654 context->input_pixel_stride,
655 context->zero,
656 multipass_buffer,
657 output,
658 &context->params);
659 }
660
661 void xnn_compute_global_average_pooling_ncw(
662 const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
663 size_t batch_index,
664 size_t channels_start,
665 size_t channels_slice)
666 {
667 const void* input = (const void*) ((uintptr_t) context->input +
668 channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
669 void* output = (void*) ((uintptr_t) context->output +
670 channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);
671
672 context->ukernel(
673 context->input_elements,
674 channels_slice,
675 input,
676 output,
677 &context->params);
678 }
679
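// Bilinear resize in NHWC layout: each output pixel reads four input pointers from the indirection
// buffer and blends them with the packed per-pixel interpolation weights.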
680 void xnn_compute_resize_bilinear(
681 const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
682 size_t batch_index,
683 size_t pixel_start,
684 size_t pixel_range)
685 {
686 void* output =
687 (void*) ((uintptr_t) context->output + pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride);
688
689 context->ukernel(
690 pixel_range,
691 context->scaled_channels,
692 context->indirect_input + pixel_start * 4,
693 context->input_offset + batch_index * context->input_batch_stride,
694 (const void*) ((uintptr_t) context->packed_weights + (pixel_start << context->log2_wsize)),
695 output,
696 context->output_pixel_stride - context->scaled_channels);
697 }
698
699 void xnn_compute_resize_bilinear_chw(
700 const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
701 size_t batch_index,
702 size_t channel_start,
703 size_t channel_range)
704 {
705 void* output =
706 (void*) ((uintptr_t) context->output + channel_start * context->output_channel_stride + batch_index * context->output_batch_stride);
707 const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride + channel_start * context->input_channel_stride;
708
709 context->ukernel(
710 context->output_pixels,
711 channel_range,
712 context->indirect_input,
713 input_offset,
714 context->packed_weights,
715 output,
716 context->input_channel_stride);
717 }
718
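// PReLU over a range of batch rows: y = x for positive x and y = w * x otherwise, with per-channel
// slopes w.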
719 void xnn_compute_prelu(
720 const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
721 size_t batch_start,
722 size_t batch_range)
723 {
724 const size_t x_stride = context->x_stride;
725 const size_t y_stride = context->y_stride;
726 const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
727 void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);
728
729 context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride);
730 }
731
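// Constant padding of a 5D tensor: the innermost dimension is padded by the pad microkernel when the
// coordinate lies inside the unpadded interior, and filled entirely with the padding value otherwise.
// The unsigned subtract-and-compare checks below implement the "within the interior" tests.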
732 void xnn_compute_pad_5d(
733 const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
734 size_t i, size_t j, size_t k, size_t l, size_t m)
735 {
736 const void* input = (const void*) ((uintptr_t) context->input +
737 i * context->input_stride[4] + j * context->input_stride[3] + k * context->input_stride[2] + l * context->input_stride[1] + m * context->input_stride[0]);
738 void* output = (void*) ((uintptr_t) context->output +
739 i * context->output_stride[4] + j * context->output_stride[3] + k * context->output_stride[2] + l * context->output_stride[1] + m * context->output_stride[0]);
740
741 const size_t i_padding = context->pre_paddings[5];
742 const size_t j_padding = context->pre_paddings[4];
743 const size_t k_padding = context->pre_paddings[3];
744 const size_t l_padding = context->pre_paddings[2];
745 const size_t m_padding = context->pre_paddings[1];
746
747 const size_t i_size = context->input_size[5];
748 const size_t j_size = context->input_size[4];
749 const size_t k_size = context->input_size[3];
750 const size_t l_size = context->input_size[2];
751 const size_t m_size = context->input_size[1];
752
753 if XNN_LIKELY(i - i_padding < i_size && j - j_padding < j_size && k - k_padding < k_size &&
754 l - l_padding < l_size && m - m_padding < m_size)
755 {
756 context->pad_ukernel(
757 1 /* rows */,
758 context->input_size[0], context->pre_paddings[0], context->post_paddings[0],
759 &context->padding_value,
760 input, 0 /* input stride */, output, 0 /* output stride */);
761 } else {
762 context->fill_ukernel(1 /* rows */, context->output_size[0], output, 0 /* output stride */, &context->padding_value);
763 }
764 }
765
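// Element-wise binary operation on the innermost dimension of two broadcast-compatible 5D tensors;
// a zero stride in a_stride/b_stride reuses the same data along that dimension (broadcasting).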
766 void xnn_compute_elementwise_binary_5d(
767 const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
768 size_t i, size_t j, size_t k, size_t l, size_t m)
769 {
770 const void* a = (const void*) ((uintptr_t) context->a +
771 i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] + l * context->a_stride[3] + m * context->a_stride[4]);
772 const void* b = (const void*) ((uintptr_t) context->b +
773 i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] + l * context->b_stride[3] + m * context->b_stride[4]);
774 void* y = (void*) ((uintptr_t) context->y +
775 i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] + l * context->y_stride[3] + m * context->y_stride[4]);
776 context->ukernel(context->elements, a, b, y, &context->params);
777 }
778
779 void xnn_compute_channel_shuffle_fixed(
780 const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
781 size_t index)
782 {
783 const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
784 void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);
785
786 context->fixed_ukernel(context->n, x, y);
787 }
788
789 void xnn_compute_channel_shuffle_variable(
790 const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
791 size_t index)
792 {
793 const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
794 void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);
795
796 context->variable_ukernel(context->n, context->m, x, y);
797 }
798
799 void xnn_compute_lut_strided(
800 const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
801 size_t batch_index)
802 {
803 const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
804 void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
805
806 context->ukernel(context->n, x, context->t, y);
807 }
808
809 void xnn_compute_lut_contiguous(
810 const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
811 size_t offset,
812 size_t size)
813 {
814 const void* x = (const void*) ((uintptr_t) context->x + offset);
815 void* y = (void*) ((uintptr_t) context->y + offset);
816
817 context->ukernel(size, x, context->t, y);
818 }
819
820 void xnn_compute_univector_strided(
821 const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
822 size_t batch_index,
823 size_t batch_range /* always 1 */)
824 {
825 assert(batch_range == 1);
826
827 const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
828 void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
829 context->ukernel(context->n, x, y, &context->params);
830 }
831
832 void xnn_compute_univector_contiguous(
833 const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
834 size_t offset,
835 size_t size)
836 {
837 const void* x = (const void*) ((uintptr_t) context->x + offset);
838 void* y = (void*) ((uintptr_t) context->y + offset);
839 context->ukernel(size, x, y, &context->params);
840 }
841
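// Quantized (uint8) SoftMax: finds the row maximum with the rmax microkernel, then applies a
// lookup-table exponentiation shifted by (255 - x_max) with output normalization.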
842 void xnn_compute_u8_softmax(
843 const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
844 size_t batch_index)
845 {
846 const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
847 uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
848 const size_t n = context->n;
849
850 uint8_t x_max = 0;
851 context->rmax_ukernel(n, x, &x_max);
852 const size_t adjustment = x_max ^ 255;
853 const uint32_t* t = (const uint32_t*) context->t + adjustment;
854 context->lut_norm_ukernel(n, x, t, y);
855 }
856
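// Three-pass float SoftMax: reduce-max, exp-and-accumulate, then scale by the reciprocal of the sum.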
857 void xnn_compute_f32_three_pass_softmax(
858 const struct f32_three_pass_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
859 size_t batch_index)
860 {
861 const float* x = (const float*) ((uintptr_t) context->x + context->x_stride * batch_index);
862 float* y = (float*) ((uintptr_t) context->y + context->y_stride * batch_index);
863 const size_t n = context->n;
864
865 // First pass: reduce-max
866 float x_max;
867 context->rmax_ukernel(n, x, &x_max);
868
869 // Second pass: reduce-add & store exp(x-x_max)
870 float y_sum;
871 context->raddstoreexpminusmax_ukernel(n, x, y, &y_sum, x_max);
872
873 // Third pass: scale y
874 const float y_scale = 1.0f / y_sum;
875 context->vmulc_ukernel(n, y, &y_scale, y, &context->params);
876 }
877
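// Fused per-channel multiply-add (vmulcaddc): scales and shifts each channel of a range of batch
// rows by the packed constants in w.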
878 void xnn_compute_vmulcaddc(
879 const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
880 size_t batch_start,
881 size_t batch_size)
882 {
883 const size_t x_stride = context->x_stride;
884 const size_t y_stride = context->y_stride;
885
886 const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
887 void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);
888
889 context->ukernel(
890 batch_size,
891 context->n,
892 x, x_stride,
893 context->w,
894 y, y_stride,
895 &context->params);
896 }
897
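// Heterogeneous-multi-processing (HMP) variants: identical to the GEMM/IGEMM functions above, but
// select the microkernel by the micro-architecture index of the calling thread.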
898 #if XNN_MAX_UARCH_TYPES > 1
899 void xnn_compute_hmp_grouped_gemm(
900 const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
901 uint32_t uarch_index,
902 size_t group_index,
903 size_t mr_block_start,
904 size_t nr_block_start,
905 size_t mr_block_size,
906 size_t nr_block_size)
907 {
908 const size_t k_scaled = context->k_scaled;
909 const size_t a_stride = context->a_stride;
910 const size_t cm_stride = context->cm_stride;
911
912 context->ukernel.function[uarch_index](
913 mr_block_size,
914 nr_block_size,
915 k_scaled,
916 (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
917 a_stride,
918 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
919 (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
920 cm_stride,
921 context->cn_stride,
922 &context->params);
923 }
924
925 void xnn_compute_hmp_gemm(
926 const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
927 uint32_t uarch_index,
928 size_t mr_block_start,
929 size_t nr_block_start,
930 size_t mr_block_size,
931 size_t nr_block_size)
932 {
933 const size_t a_stride = context->a_stride;
934 const size_t cm_stride = context->cm_stride;
935
936 context->ukernel.function[uarch_index](
937 mr_block_size,
938 nr_block_size,
939 context->k_scaled,
940 (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
941 a_stride,
942 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
943 (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
944 cm_stride,
945 context->cn_stride,
946 &context->params);
947 }
948
949 void xnn_compute_hmp_grouped_batch_igemm(
950 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
951 uint32_t uarch_index,
952 size_t batch_index,
953 size_t group_index,
954 size_t mr_block_start,
955 size_t nr_block_start,
956 size_t mr_block_size,
957 size_t nr_block_size)
958 {
959 const size_t ks = context->ks;
960 const size_t cm_stride = context->cm_stride;
961
962 context->ukernel.function[uarch_index](
963 mr_block_size,
964 nr_block_size,
965 context->kc,
966 context->ks_scaled,
967 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
968 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
969 (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
970 cm_stride,
971 context->cn_stride,
972 context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
973 context->zero,
974 &context->params);
975 }
976
977 void xnn_compute_hmp_grouped_igemm(
978 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
979 uint32_t uarch_index,
980 size_t group_index,
981 size_t mr_block_start,
982 size_t nr_block_start,
983 size_t mr_block_size,
984 size_t nr_block_size)
985 {
986 const size_t ks = context->ks;
987 const size_t cm_stride = context->cm_stride;
988
989 context->ukernel.function[uarch_index](
990 mr_block_size,
991 nr_block_size,
992 context->kc,
993 context->ks_scaled,
994 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
995 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
996 (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
997 cm_stride,
998 context->cn_stride,
999 context->a_offset + group_index * context->ga_stride,
1000 context->zero,
1001 &context->params);
1002 }
1003
1004 void xnn_compute_batch_hmp_igemm(
1005 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
1006 uint32_t uarch_index,
1007 size_t batch_index,
1008 size_t mr_block_start,
1009 size_t nr_block_start,
1010 size_t mr_block_size,
1011 size_t nr_block_size)
1012 {
1013 const size_t ks = context->ks;
1014 const size_t cm_stride = context->cm_stride;
1015
1016 context->ukernel.function[uarch_index](
1017 mr_block_size,
1018 nr_block_size,
1019 context->kc,
1020 context->ks_scaled,
1021 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
1022 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
1023 (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
1024 cm_stride,
1025 context->cn_stride,
1026 context->a_offset + batch_index * context->ba_stride,
1027 context->zero,
1028 &context->params);
1029 }
1030
1031 void xnn_compute_hmp_igemm(
1032 const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
1033 uint32_t uarch_index,
1034 size_t mr_block_start,
1035 size_t nr_block_start,
1036 size_t mr_block_size,
1037 size_t nr_block_size)
1038 {
1039 const size_t ks = context->ks;
1040 const size_t cm_stride = context->cm_stride;
1041
1042 context->ukernel.function[uarch_index](
1043 mr_block_size,
1044 nr_block_size,
1045 context->kc,
1046 context->ks_scaled,
1047 (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
1048 (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
1049 (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
1050 cm_stride,
1051 context->cn_stride,
1052 context->a_offset,
1053 context->zero,
1054 &context->params);
1055 }
1056 #endif // XNN_MAX_UARCH_TYPES > 1
1057
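// Runs a previously set-up operator on the given threadpool: validates library and operator state,
// then dispatches the operator's compute function through the pthreadpool parallelization pattern
// recorded in op->compute. A typical caller-side sequence (sketch of the public operator API) is
// xnn_initialize(), xnn_create_*(), xnn_setup_*(), xnn_run_operator(), xnn_delete_operator().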
1058 enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
1059 {
1060 if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
1061 xnn_log_error("failed to run operator: XNNPACK is not initialized");
1062 return xnn_status_uninitialized;
1063 }
1064 switch (op->state) {
1065 case xnn_run_state_invalid:
1066 xnn_log_error("failed to run operator: operator was not successfully setup");
1067 return xnn_status_invalid_state;
1068 case xnn_run_state_ready:
1069 break;
1070 case xnn_run_state_skip:
1071 return xnn_status_success;
1072 }
1073
1074 switch (op->compute.type) {
1075 case xnn_parallelization_type_invalid:
1076 break;
1077 case xnn_parallelization_type_1d:
1078 assert(op->compute.range[0] != 0);
1079 pthreadpool_parallelize_1d(
1080 threadpool,
1081 op->compute.task_1d,
1082 &op->context,
1083 op->compute.range[0],
1084 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1085 break;
1086 case xnn_parallelization_type_1d_tile_1d:
1087 assert(op->compute.range[0] != 0);
1088 assert(op->compute.tile[0] != 0);
1089 pthreadpool_parallelize_1d_tile_1d(
1090 threadpool,
1091 op->compute.task_1d_tile_1d,
1092 &op->context,
1093 op->compute.range[0],
1094 op->compute.tile[0],
1095 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1096 break;
1097 case xnn_parallelization_type_2d:
1098 assert(op->compute.range[0] != 0);
1099 assert(op->compute.range[1] != 0);
1100 pthreadpool_parallelize_2d(
1101 threadpool,
1102 op->compute.task_2d,
1103 &op->context,
1104 op->compute.range[0], op->compute.range[1],
1105 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1106 break;
1107 case xnn_parallelization_type_2d_tile_1d:
1108 assert(op->compute.range[0] != 0);
1109 assert(op->compute.range[1] != 0);
1110 assert(op->compute.tile[0] != 0);
1111 pthreadpool_parallelize_2d_tile_1d(
1112 threadpool,
1113 op->compute.task_2d_tile_1d,
1114 &op->context,
1115 op->compute.range[0], op->compute.range[1],
1116 op->compute.tile[0],
1117 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1118 break;
1119 case xnn_parallelization_type_2d_tile_2d:
1120 assert(op->compute.range[0] != 0);
1121 assert(op->compute.range[1] != 0);
1122 assert(op->compute.tile[0] != 0);
1123 assert(op->compute.tile[1] != 0);
1124 pthreadpool_parallelize_2d_tile_2d(
1125 threadpool,
1126 op->compute.task_2d_tile_2d,
1127 &op->context,
1128 op->compute.range[0], op->compute.range[1],
1129 op->compute.tile[0], op->compute.tile[1],
1130 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1131 break;
1132 case xnn_parallelization_type_3d:
1133 assert(op->compute.range[0] != 0);
1134 assert(op->compute.range[1] != 0);
1135 assert(op->compute.range[2] != 0);
1136 pthreadpool_parallelize_3d(
1137 threadpool,
1138 op->compute.task_3d,
1139 &op->context,
1140 op->compute.range[0], op->compute.range[1], op->compute.range[2],
1141 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1142 break;
1143 case xnn_parallelization_type_3d_tile_2d:
1144 assert(op->compute.range[0] != 0);
1145 assert(op->compute.range[1] != 0);
1146 assert(op->compute.range[2] != 0);
1147 assert(op->compute.tile[0] != 0);
1148 assert(op->compute.tile[1] != 0);
1149 pthreadpool_parallelize_3d_tile_2d(
1150 threadpool,
1151 op->compute.task_3d_tile_2d,
1152 &op->context,
1153 op->compute.range[0], op->compute.range[1], op->compute.range[2],
1154 op->compute.tile[0], op->compute.tile[1],
1155 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1156 break;
1157 case xnn_parallelization_type_4d:
1158 assert(op->compute.range[0] != 0);
1159 assert(op->compute.range[1] != 0);
1160 assert(op->compute.range[2] != 0);
1161 assert(op->compute.range[3] != 0);
1162 pthreadpool_parallelize_4d(
1163 threadpool,
1164 op->compute.task_4d,
1165 &op->context,
1166 op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
1167 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1168 break;
1169 case xnn_parallelization_type_4d_tile_2d:
1170 assert(op->compute.range[0] != 0);
1171 assert(op->compute.range[1] != 0);
1172 assert(op->compute.range[2] != 0);
1173 assert(op->compute.range[3] != 0);
1174 assert(op->compute.tile[0] != 0);
1175 assert(op->compute.tile[1] != 0);
1176 pthreadpool_parallelize_4d_tile_2d(
1177 threadpool,
1178 op->compute.task_4d_tile_2d,
1179 &op->context,
1180 op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
1181 op->compute.tile[0], op->compute.tile[1],
1182 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1183 break;
1184 case xnn_parallelization_type_5d:
1185 assert(op->compute.range[0] != 0);
1186 assert(op->compute.range[1] != 0);
1187 assert(op->compute.range[2] != 0);
1188 assert(op->compute.range[3] != 0);
1189 assert(op->compute.range[4] != 0);
1190 pthreadpool_parallelize_5d(
1191 threadpool,
1192 op->compute.task_5d,
1193 &op->context,
1194 op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
1195 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1196 break;
1197 case xnn_parallelization_type_5d_tile_2d:
1198 assert(op->compute.range[0] != 0);
1199 assert(op->compute.range[1] != 0);
1200 assert(op->compute.range[2] != 0);
1201 assert(op->compute.range[3] != 0);
1202 assert(op->compute.range[4] != 0);
1203 assert(op->compute.tile[0] != 0);
1204 assert(op->compute.tile[1] != 0);
1205 pthreadpool_parallelize_5d_tile_2d(
1206 threadpool,
1207 op->compute.task_5d_tile_2d,
1208 &op->context,
1209 op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
1210 op->compute.tile[0], op->compute.tile[1],
1211 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1212 break;
1213 case xnn_parallelization_type_6d_tile_2d:
1214 assert(op->compute.range[0] != 0);
1215 assert(op->compute.range[1] != 0);
1216 assert(op->compute.range[2] != 0);
1217 assert(op->compute.range[3] != 0);
1218 assert(op->compute.range[4] != 0);
1219 assert(op->compute.range[5] != 0);
1220 assert(op->compute.tile[0] != 0);
1221 assert(op->compute.tile[1] != 0);
1222 pthreadpool_parallelize_6d_tile_2d(
1223 threadpool,
1224 op->compute.task_6d_tile_2d,
1225 &op->context,
1226 op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
1227 op->compute.tile[0], op->compute.tile[1],
1228 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1229 break;
1230 #if XNN_MAX_UARCH_TYPES > 1
1231 case xnn_parallelization_type_2d_tile_2d_with_uarch:
1232 assert(op->compute.range[0] != 0);
1233 assert(op->compute.range[1] != 0);
1234 assert(op->compute.tile[0] != 0);
1235 assert(op->compute.tile[1] != 0);
1236 pthreadpool_parallelize_2d_tile_2d_with_uarch(
1237 threadpool,
1238 op->compute.task_2d_tile_2d_with_id,
1239 &op->context,
1240 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
1241 op->compute.range[0], op->compute.range[1],
1242 op->compute.tile[0], op->compute.tile[1],
1243 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1244 break;
1245 case xnn_parallelization_type_3d_tile_2d_with_uarch:
1246 assert(op->compute.range[0] != 0);
1247 assert(op->compute.range[1] != 0);
1248 assert(op->compute.range[2] != 0);
1249 assert(op->compute.tile[0] != 0);
1250 assert(op->compute.tile[1] != 0);
1251 pthreadpool_parallelize_3d_tile_2d_with_uarch(
1252 threadpool,
1253 op->compute.task_3d_tile_2d_with_id,
1254 &op->context,
1255 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
1256 op->compute.range[0], op->compute.range[1], op->compute.range[2],
1257 op->compute.tile[0], op->compute.tile[1],
1258 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1259 break;
1260 case xnn_parallelization_type_4d_tile_2d_with_uarch:
1261 assert(op->compute.range[0] != 0);
1262 assert(op->compute.range[1] != 0);
1263 assert(op->compute.range[2] != 0);
1264 assert(op->compute.range[3] != 0);
1265 assert(op->compute.tile[0] != 0);
1266 assert(op->compute.tile[1] != 0);
1267 pthreadpool_parallelize_4d_tile_2d_with_uarch(
1268 threadpool,
1269 op->compute.task_4d_tile_2d_with_id,
1270 &op->context,
1271 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
1272 op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
1273 op->compute.tile[0], op->compute.tile[1],
1274 PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
1275 break;
1276 #endif // XNN_MAX_UARCH_TYPES > 1
1277 default:
1278 XNN_UNREACHABLE;
1279 }
1280 return xnn_status_success;
1281 }
1282