1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #include <assert.h>
10 #include <stddef.h>
11 #include <stdint.h>
12 #include <string.h>
13 
14 #include <xnnpack.h>
15 #include <xnnpack/operator.h>
16 #include <xnnpack/log.h>
17 #include <xnnpack/common.h>
18 #include <xnnpack/math.h>
19 #include <xnnpack/params.h>
20 #include <xnnpack/compute.h>
21 
22 
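// Computes one mr-by-nr output tile of one group of a grouped GEMM: the input, packed weights,
// and output pointers are offset by the group index and tile coordinates before the micro-kernel runs.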
23 void xnn_compute_ggemm(
24     const struct gemm_context context[restrict static 1],
25     size_t group_index,
26     size_t mr_block_start,
27     size_t nr_block_start,
28     size_t mr_block_size,
29     size_t nr_block_size)
30 {
31   const size_t k_scaled  = context->k_scaled;
32   const size_t a_stride  = context->a_stride;
33   const size_t cm_stride = context->cm_stride;
34 
35   context->ukernel(
36       mr_block_size,
37       nr_block_size,
38       k_scaled,
39       (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
40       a_stride,
41       (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
42       (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
43       cm_stride,
44       context->cn_stride,
45       &context->params);
46 }
47 
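// Computes one mr-by-nr output tile of a dense GEMM; identical to xnn_compute_ggemm without the group dimension.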
48 void xnn_compute_gemm(
49     const struct gemm_context context[restrict static 1],
50     size_t mr_block_start,
51     size_t nr_block_start,
52     size_t mr_block_size,
53     size_t nr_block_size)
54 {
55   const size_t a_stride  = context->a_stride;
56   const size_t cm_stride = context->cm_stride;
57 
58   context->ukernel(
59       mr_block_size,
60       nr_block_size,
61       context->k_scaled,
62       (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
63       a_stride,
64       (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
65       (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
66       cm_stride,
67       context->cn_stride,
68       &context->params);
69 }
70 
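// Runs the sparse-matrix (SpMM) micro-kernel on a block of mr rows for one batch element,
// offsetting the dense input and the output by the batch stride and the block start.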
71 void xnn_compute_spmm(
72     const struct spmm_context context[restrict static 1],
73     size_t batch_index,
74     size_t mr_block_start,
75     size_t mr_block_size)
76 {
77   context->ukernel(
78       mr_block_size,
79       context->n,
80       (const void*) ((uintptr_t) context->a + batch_index * context->batched_a_stride + mr_block_start * sizeof(float)),
81       context->packed_weights,
82       context->input_increments,
83       context->output_channel_nonzeros,
84       (void*) ((uintptr_t) context->c + batch_index * context->batched_c_stride + mr_block_start * sizeof(float)),
85       &context->params);
86 }
87 
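// Computes one mr-by-nr output tile of a grouped indirect GEMM for one batch element and group;
// input rows are reached through the indirection buffer rather than a contiguous A matrix.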
88 void xnn_compute_gigemm(
89     const struct igemm_context context[restrict static 1],
90     size_t batch_index,
91     size_t group_index,
92     size_t mr_block_start,
93     size_t nr_block_start,
94     size_t mr_block_size,
95     size_t nr_block_size)
96 {
97   const size_t ks        = context->ks;
98   const size_t cm_stride = context->cm_stride;
99 
100   context->ukernel(
101       mr_block_size,
102       nr_block_size,
103       context->kc,
104       context->ks_scaled,
105       (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
106       (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
107       (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
108       cm_stride,
109       context->cn_stride,
110       context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
111       context->zero,
112       &context->params);
113 }
114 
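// Computes one mr-by-nr output tile of an indirect GEMM for one batch element.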
115 void xnn_compute_igemm(
116     const struct igemm_context context[restrict static 1],
117     size_t batch_index,
118     size_t mr_block_start,
119     size_t nr_block_start,
120     size_t mr_block_size,
121     size_t nr_block_size)
122 {
123   const size_t ks        = context->ks;
124   const size_t cm_stride = context->cm_stride;
125 
126   context->ukernel(
127       mr_block_size,
128       nr_block_size,
129       context->kc,
130       context->ks_scaled,
131       (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
132       (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
133       (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
134       cm_stride,
135       context->cn_stride,
136       context->a_offset + batch_index * context->ba_stride,
137       context->zero,
138       &context->params);
139 }
140 
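// Computes one output tile of a grouped sub-convolution; slices that fall outside the
// sub-kernel's output region (slice_height/slice_width) return early without work.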
141 void xnn_compute_gsubconv2d(
142       const struct subconv_context context[restrict static 1],
143       size_t batch_index,
144       size_t group_index,
145       size_t subkernel_index,
146       size_t slice_y,
147       size_t slice_x_start,
148       size_t nc_block_start,
149       size_t slice_x_max,
150       size_t nc_block_size)
151 {
152   const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
153 
154   if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
155     return;
156   }
157 
158   const size_t slice_width = subconvolution_params->slice_width;
159   if XNN_UNLIKELY(slice_x_start >= slice_width) {
160     return;
161   }
162   const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
163 
164   const size_t cx_stride = context->cx_stride;
165   context->ukernel(
166       slice_x_size,
167       nc_block_size,
168       context->kc,
169       subconvolution_params->scaled_kernel_size,
170       (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
171       (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
172       (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
173       cx_stride,
174       context->cn_stride,
175       context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
176       context->zero,
177       &context->params);
178 }
179 
180 void xnn_compute_subconv2d(
181       const struct subconv_context context[restrict static 1],
182       size_t batch_index,
183       size_t subkernel_index,
184       size_t slice_y,
185       size_t slice_x_start,
186       size_t nc_block_start,
187       size_t slice_x_max,
188       size_t nc_block_size)
189 {
190   const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
191 
192   if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
193     return;
194   }
195 
196   const size_t slice_width = subconvolution_params->slice_width;
197   if XNN_UNLIKELY(slice_x_start >= slice_width) {
198     return;
199   }
200   const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
201 
202   const size_t cx_stride = context->cx_stride;
203   context->ukernel(
204       slice_x_size,
205       nc_block_size,
206       context->kc,
207       subconvolution_params->scaled_kernel_size,
208       (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
209       (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
210       (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
211       cx_stride,
212       context->cn_stride,
213       context->a_offset + batch_index * context->ba_stride,
214       context->zero,
215       &context->params);
216 }
217 
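// Runs the direct-convolution micro-kernel that reads HWC input and writes the slice of output rows
// [output_y_start, output_y_start + output_y_slice) in SpCHW layout for one batch element.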
218 void xnn_compute_dconv2d_hwc2spchw(
219       const struct dconv2d_context context[restrict static 1],
220       size_t batch_index,
221       size_t output_y_start,
222       size_t output_y_slice)
223 {
224   context->hwc2spchw_ukernel(
225       context->input_height,
226       context->input_width,
227       output_y_start,
228       output_y_start + output_y_slice,
229       (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride),
230       context->zero,
231       context->packed_weights,
232       (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride),
233       context->input_padding_top,
234       context->output_channels,
235       context->output_height_stride,
236       context->output_channel_stride,
237       &context->params);
238 }
239 
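// Computes one output row of a depthwise convolution in a single micro-kernel pass,
// reading input pixels through the row of the indirection buffer selected by output_y.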
240 void xnn_compute_dwconv_unipass(
241     const struct dwconv_context context[restrict static 1],
242     size_t output_y)
243 {
244   context->unipass_ukernel(
245     context->groups,
246     context->output_width,
247     context->indirection_buffer + output_y * context->indirection_buffer_row_stride,
248     context->packed_weights,
249     context->output + output_y * context->output_row_stride,
250     context->indirection_buffer_col_stride,
251     context->output_col_increment,
252     &context->params);
253 }
254 
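// Computes one channel of a depthwise convolution in SpCHW layout for one batch element.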
255 void xnn_compute_dwconv2d_spchw(
256     const struct dwconv2d_context context[restrict static 1],
257     size_t batch_index,
258     size_t channel)
259 {
260   context->spchw_ukernel(
261     context->output_height,
262     context->input_width,
263     (const void*) ((uintptr_t) context->input + channel * context->input_channel_stride + batch_index * context->input_batch_stride),
264     (const void*) ((uintptr_t) context->packed_weights + channel * context->weights_channel_stride),
265     (void*) ((uintptr_t) context->output + channel * context->output_channel_stride + batch_index * context->output_batch_stride),
266     context->input_tuple_stride,
267     context->output_tuple_stride,
268     context->input_pixel_stride,
269     context->output_pixel_stride,
270     &context->params);
271 }
272 
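// Computes one output row of argmax pooling (values and indices) in a single micro-kernel pass.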
273 void xnn_compute_argmax_pooling_unipass(
274     const struct argmax_pooling_context context[restrict static 1],
275     size_t batch_index,
276     size_t output_y)
277 {
278   const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
279     output_y * context->indirect_input_height_stride);
280   const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
281   void* output = (void*) ((uintptr_t) context->output +
282     batch_index * context->output_batch_stride + output_y * context->output_height_stride);
283   uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
284     batch_index * context->index_batch_stride + output_y * context->index_height_stride);
285 
286   context->unipass_ukernel(
287     context->output_width, context->pooling_size, context->channels,
288     indirect_input, input_offset, output, index,
289     context->input_increment, context->output_increment,
290     &context->params);
291 }
292 
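// Multi-pass variant of argmax pooling: values and indices are accumulated in on-stack
// scratch buffers (VLAs sized by the channel count) before the final row is written.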
293 void xnn_compute_argmax_pooling_multipass(
294     const struct argmax_pooling_context context[restrict static 1],
295     size_t batch_index,
296     size_t output_y)
297 {
298   const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
299     output_y * context->indirect_input_height_stride);
300   const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
301   void* output = (void*) ((uintptr_t) context->output +
302     batch_index * context->output_batch_stride + output_y * context->output_height_stride);
303   uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
304     batch_index * context->index_batch_stride + output_y * context->index_height_stride);
305 
306   XNN_ALIGN(16) float multipass_accumulation_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(float)];
307   XNN_ALIGN(16) uint32_t multipass_index_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint32_t)];
308 
309   context->multipass_ukernel(
310     context->output_width, context->pooling_size, context->channels,
311     indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
312     context->input_increment, context->output_increment,
313     &context->params);
314 }
315 
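// Computes one output row of max pooling for one batch element.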
316 void xnn_compute_max_pooling(
317     const struct max_pooling_context context[restrict static 1],
318     size_t batch_index,
319     size_t output_y)
320 {
321   const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
322     output_y * context->indirect_input_height_stride);
323   const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
324   void* output = (void*) ((uintptr_t) context->output +
325     batch_index * context->output_batch_stride + output_y * context->output_height_stride);
326 
327   context->ukernel(
328     context->output_width, context->pooling_size, context->channels,
329     indirect_input, input_offset, output,
330     context->input_increment, context->output_increment,
331     &context->params);
332 }
333 
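// Processes one input pixel of max unpooling, passing the saved pooling indices, the fill value,
// and the indirect output pointers for that pixel to the micro-kernel.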
334 void xnn_compute_unpooling(
335     const struct unpooling_context context[restrict static 1],
336     size_t input_y,
337     size_t input_x)
338 {
339   const void* input = (const void*) ((uintptr_t) context->input +
340       input_y * context->input_height_stride + input_x * context->input_width_stride);
341   const uint32_t* index = (const uint32_t*) ((uintptr_t) context->index +
342       input_y * context->index_height_stride + input_x * context->index_width_stride);
343   void** indirect_output =
344     (void**) ((uintptr_t) context->indirect_output +
345       input_y * context->indirect_output_height_stride + input_x * context->indirect_output_width_stride);
346 
347   context->ukernel(
348     context->pooling_size,
349     context->channels,
350     context->fill_value,
351     input, index, indirect_output);
352 }
353 
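// Computes one output row of average pooling in a single micro-kernel pass for one batch element.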
354 void xnn_compute_average_pooling_unipass(
355     const struct average_pooling_context context[restrict static 1],
356     size_t batch_index,
357     size_t output_y)
358 {
359   const void** indirect_input =
360     (const void**) ((uintptr_t) context->indirect_input +
361       batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
362   void* output =
363     (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
364 
365   context->unipass_ukernel(
366     context->output_width, context->pooling_size, context->channels,
367     indirect_input, context->zero, output,
368     context->input_increment, context->output_increment,
369     &context->params);
370 }
371 
372 void xnn_compute_average_pooling_multipass(
373     const struct average_pooling_context context[restrict static 1],
374     size_t batch_index,
375     size_t output_y)
376 {
377   const void** indirect_input =
378     (const void**) ((uintptr_t) context->indirect_input +
379       batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
380   void* output =
381     (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
382   XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];
383 
384   context->multipass_ukernel(
385     context->output_width, context->pooling_size, context->channels,
386     indirect_input, context->zero, multipass_buffer, output,
387     context->input_increment, context->output_increment,
388     &context->params);
389 }
390 
391 void xnn_compute_pixelwise_average_pooling_unipass(
392     const struct pixelwise_average_pooling_context context[restrict static 1],
393     size_t batch_index,
394     size_t output_y)
395 {
396   const void** indirect_input =
397     (const void**) ((uintptr_t) context->indirect_input +
398       batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
399   const void* pixelwise_buffer =
400     (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
401   void* output =
402     (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
403 
404   context->unipass_ukernel(
405     context->output_width, context->pooling_size, context->channels,
406     indirect_input, context->zero, pixelwise_buffer, output,
407     context->input_increment, context->output_increment,
408     &context->params);
409 }
410 
411 void xnn_compute_pixelwise_average_pooling_multipass(
412     const struct pixelwise_average_pooling_context context[restrict static 1],
413     size_t batch_index,
414     size_t output_y)
415 {
416   const void** indirect_input =
417     (const void**) ((uintptr_t) context->indirect_input +
418       batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride);
419   const void* pixelwise_buffer =
420     (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
421   void* output =
422     (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride);
423   XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];
424 
425   context->multipass_ukernel(
426     context->output_width, context->pooling_size, context->channels,
427     indirect_input, context->zero, pixelwise_buffer, multipass_buffer, output,
428     context->input_increment, context->output_increment,
429     &context->params);
430 }
431 
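// Reduces all pixels of one batch element to a single output pixel (global average pooling, NWC layout) in one pass.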
432 void xnn_compute_global_average_pooling_nwc_unipass(
433     const struct global_average_pooling_nwc_context context[restrict static 1],
434     size_t batch_index)
435 {
436   const void* input =
437     (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
438   void* output =
439     (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
440 
441   context->unipass_ukernel(
442     context->input_elements,
443     context->channels,
444     input,
445     context->input_pixel_stride,
446     context->zero,
447     output,
448     &context->params);
449 }
450 
451 void xnn_compute_global_average_pooling_nwc_multipass(
452     const struct global_average_pooling_nwc_context context[restrict static 1],
453     size_t batch_index)
454 {
455   const void* input =
456     (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
457   void* output =
458     (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
459   XNN_ALIGN(16) int32_t multipass_buffer[context->channels + XNN_EXTRA_BYTES / sizeof(uint8_t)];
460 
461   context->multipass_ukernel(
462     context->input_elements,
463     context->channels,
464     input,
465     context->input_pixel_stride,
466     context->zero,
467     multipass_buffer,
468     output,
469     &context->params);
470 }
471 
472 void xnn_compute_global_average_pooling_ncw(
473     const struct global_average_pooling_ncw_context context[restrict static 1],
474     size_t batch_index,
475     size_t channels_start,
476     size_t channels_slice)
477 {
478   const void* input = (const void*) ((uintptr_t) context->input +
479     channels_start * context->input_channel_stride + batch_index * context->input_batch_stride);
480   void* output = (void*) ((uintptr_t) context->output +
481     channels_start * context->output_channel_stride + batch_index * context->output_batch_stride);
482 
483   context->ukernel(
484     context->input_elements,
485     channels_slice,
486     input,
487     output,
488     &context->params);
489 }
490 
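// Computes a contiguous range of output pixels of bilinear resize: each output pixel uses
// 4 input pointers from the indirection buffer plus packed per-pixel interpolation weights.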
491 void xnn_compute_resize_bilinear(
492     const struct resize_bilinear_context context[restrict static 1],
493     size_t batch_index,
494     size_t pixel_start,
495     size_t pixel_range)
496 {
497   void* output =
498     (void*) ((uintptr_t) context->output + pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride);
499 
500   context->ukernel(
501     pixel_range,
502     context->scaled_channels,
503     context->indirect_input + pixel_start * 4,
504     context->input_offset + batch_index * context->input_batch_stride,
505     context->packed_weights + (pixel_start << context->log2_wsize),
506     output,
507     context->output_pixel_stride - context->scaled_channels);
508 }
509 
510 void xnn_compute_prelu(
511     const struct prelu_context context[restrict static 1],
512     size_t batch_start,
513     size_t batch_range)
514 {
515   const size_t x_stride = context->x_stride;
516   const size_t y_stride = context->y_stride;
517   const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
518   void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);
519 
520   context->ukernel(batch_range, context->n, x, x_stride, context->w, y, y_stride, &context->params);
521 }
522 
523 void xnn_compute_channel_pad(
524     const struct channel_pad_context context[restrict static 1],
525     size_t batch_start,
526     size_t batch_range)
527 {
528   const size_t x_stride = context->x_stride;
529   const size_t y_stride = context->y_stride;
530   const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
531   void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);
532 
533   context->ukernel(batch_range, context->n, context->l, context->r, context->c, x, x_stride, y, y_stride);
534 }
535 
536 void xnn_compute_add_strided(
537     const struct add_strided_context context[restrict static 1],
538     size_t batch_index,
539     size_t batch_range /* always 1 */)
540 {
541   assert(batch_range == 1);
542 
543   const size_t n = context->n;
544   const size_t a_stride = context->a_stride;
545   const size_t b_stride = context->b_stride;
546   const size_t y_stride = context->y_stride;
547   const void* a = (const void*) ((uintptr_t) context->a + a_stride * batch_index);
548   const void* b = (const void*) ((uintptr_t) context->b + b_stride * batch_index);
549   void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index);
550 
551   context->ukernel(n, a, b, y, &context->params);
552 }
553 
554 void xnn_compute_add_contiguous(
555     const struct add_contiguous_context context[restrict static 1],
556     size_t offset,
557     size_t size)
558 {
559   const void* a = (const void*) ((uintptr_t) context->a + offset);
560   const void* b = (const void*) ((uintptr_t) context->b + offset);
561   void* y = (void*) ((uintptr_t) context->y + offset);
562   context->ukernel(size, a, b, y, &context->params);
563 }
564 
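// Applies a binary element-wise micro-kernel to one innermost row of two 5-D operands;
// the caller parallelizes the outer i/j/k/l/m loops with unit ranges for l and m.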
565 void xnn_compute_elementwise_binary_5d(
566     const struct elementwise_binary_context context[restrict static 1],
567     size_t i, size_t j, size_t k, size_t l, size_t m,
568     size_t l_range, size_t m_range)
569 {
570   assert(l_range == 1);
571   assert(m_range == 1);
572 
573   const void* a = (const void*) ((uintptr_t) context->a +
574     i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] + l * context->a_stride[3] + m * context->a_stride[4]);
575   const void* b = (const void*) ((uintptr_t) context->b +
576     i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] + l * context->b_stride[3] + m * context->b_stride[4]);
577   void* y = (void*) ((uintptr_t) context->y +
578     i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] + l * context->y_stride[3] + m * context->y_stride[4]);
579   context->ukernel(context->elements, a, b, y, &context->params);
580 }
581 
582 void xnn_compute_channel_shuffle_fixed(
583     const struct channel_shuffle_context context[restrict static 1],
584     size_t index)
585 {
586   const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
587   void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);
588 
589   context->fixed_ukernel(context->n, x, y);
590 }
591 
592 void xnn_compute_channel_shuffle_variable(
593     const struct channel_shuffle_context context[restrict static 1],
594     size_t index)
595 {
596   const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride);
597   void* y = (void*) ((uintptr_t) context->y + index * context->y_stride);
598 
599   context->variable_ukernel(context->n, context->m, x, y);
600 }
601 
602 void xnn_compute_lut_strided(
603     const struct lut_strided_context context[restrict static 1],
604     size_t batch_index)
605 {
606   const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
607   void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
608 
609   context->ukernel(context->n, x, context->t, y);
610 }
611 
612 void xnn_compute_lut_contiguous(
613     const struct lut_contiguous_context context[restrict static 1],
614     size_t offset,
615     size_t size)
616 {
617   const void* x = (const void*) ((uintptr_t) context->x + offset);
618   void* y = (void*) ((uintptr_t) context->y + offset);
619 
620   context->ukernel(size, x, context->t, y);
621 }
622 
623 void xnn_compute_univector_strided(
624     const struct univector_strided_context context[restrict static 1],
625     size_t batch_index,
626     size_t batch_range /* always 1 */)
627 {
628   assert(batch_range == 1);
629 
630   const void* x = (const void*) ((uintptr_t) context->x + context->x_stride * batch_index);
631   void* y = (void*) ((uintptr_t) context->y + context->y_stride * batch_index);
632   context->ukernel(context->n, x, y, &context->params);
633 }
634 
635 void xnn_compute_univector_contiguous(
636     const struct univector_contiguous_context context[restrict static 1],
637     size_t offset,
638     size_t size)
639 {
640   const void* x = (const void*) ((uintptr_t) context->x + offset);
641   void* y = (void*) ((uintptr_t) context->y + offset);
642   context->ukernel(size, x, y, &context->params);
643 }
644 
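// Computes softmax over one row of uint8 inputs: a reduce-max pass finds x_max, then a combined
// lookup-table/normalize pass runs with the table shifted by (x_max ^ 255) entries.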
645 void xnn_compute_u8_softmax(
646     const struct u8_softmax_context context[restrict static 1],
647     size_t batch_index)
648 {
649   const uint8_t* x = (const uint8_t*) ((uintptr_t) context->x + context->x_stride * batch_index);
650   uint8_t* y = (uint8_t*) ((uintptr_t) context->y + context->y_stride * batch_index);
651   const size_t n = context->n;
652 
653   uint8_t x_max = 0;
654   context->rmax_ukernel(n, x, &x_max);
655   const size_t adjustment = x_max ^ 255;
656   const uint32_t* t = (const uint32_t*) context->t + adjustment;
657   context->lut_norm_ukernel(n, x, t, y);
658 }
659 
660 void xnn_compute_f32_three_pass_softmax(
661     const struct f32_three_pass_softmax_context context[restrict static 1],
662     size_t batch_index)
663 {
664   const float* x = (const float*) ((uintptr_t) context->x + context->x_stride * batch_index);
665   float* y = (float*) ((uintptr_t) context->y + context->y_stride * batch_index);
666   const size_t n = context->n;
667 
668   // First pass: reduce-max
669   float x_max;
670   context->rmax_ukernel(n, x, &x_max);
671 
672   // Second pass: reduce-add & store exp(x-x_max)
673   float y_sum;
674   context->raddstoreexpminusmax_ukernel(n, x, y, &y_sum, x_max);
675 
676   // Third pass: scale y
677   const float y_scale = 1.0f / y_sum;
678   context->vmulc_ukernel(n, y, &y_scale, y, &context->params);
679 }
680 
681 void xnn_compute_vmulcaddc(
682     const struct vmulcaddc_context context[restrict static 1],
683     size_t batch_start,
684     size_t batch_size)
685 {
686   const size_t x_stride = context->x_stride;
687   const size_t y_stride = context->y_stride;
688 
689   const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_start);
690   void* y = (void*) ((uintptr_t) context->y + y_stride * batch_start);
691 
692   context->ukernel(
693     batch_size,
694     context->n,
695     x, x_stride,
696     context->w,
697     y, y_stride,
698     &context->params);
699 }
700 
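// Dispatches a previously set-up operator to pthreadpool according to the parallelization type,
// loop ranges, and tile sizes recorded in op->compute during setup.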
701 enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
702 {
703   if (!xnn_params.initialized) {
704     xnn_log_error("failed to run operator: XNNPACK is not initialized");
705     return xnn_status_uninitialized;
706   }
707   switch (op->state) {
708     case xnn_run_state_invalid:
709       xnn_log_error("failed to run operator: operator was not successfully setup");
710       return xnn_status_invalid_state;
711     case xnn_run_state_ready:
712       break;
713     case xnn_run_state_skip:
714       return xnn_status_success;
715   }
716 
717   switch (op->compute.type) {
718     case xnn_parallelization_type_invalid:
719       break;
720     case xnn_parallelization_type_1d:
721       assert(op->compute.range[0] != 0);
722       pthreadpool_parallelize_1d(
723           threadpool,
724           op->compute.task_1d,
725           &op->context,
726           op->compute.range[0],
727           PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
728       break;
729     case xnn_parallelization_type_1d_tile_1d:
730       assert(op->compute.range[0] != 0);
731       assert(op->compute.tile[0] != 0);
732       pthreadpool_parallelize_1d_tile_1d(
733           threadpool,
734           op->compute.task_1d_tile_1d,
735           &op->context,
736           op->compute.range[0],
737           op->compute.tile[0],
738           PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
739       break;
740     case xnn_parallelization_type_2d:
741       assert(op->compute.range[0] != 0);
742       assert(op->compute.range[1] != 0);
743       pthreadpool_parallelize_2d(
744           threadpool,
745           op->compute.task_2d,
746           &op->context,
747           op->compute.range[0], op->compute.range[1],
748           PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
749       break;
750     case xnn_parallelization_type_2d_tile_1d:
751       assert(op->compute.range[0] != 0);
752       assert(op->compute.range[1] != 0);
753       assert(op->compute.tile[0] != 0);
754       pthreadpool_parallelize_2d_tile_1d(
755           threadpool,
756           op->compute.task_2d_tile_1d,
757           &op->context,
758           op->compute.range[0], op->compute.range[1],
759           op->compute.tile[0],
760           PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
761       break;
762     case xnn_parallelization_type_2d_tile_2d:
763       assert(op->compute.range[0] != 0);
764       assert(op->compute.range[1] != 0);
765       assert(op->compute.tile[0] != 0);
766       assert(op->compute.tile[1] != 0);
767       pthreadpool_parallelize_2d_tile_2d(
768           threadpool,
769           op->compute.task_2d_tile_2d,
770           &op->context,
771           op->compute.range[0], op->compute.range[1],
772           op->compute.tile[0], op->compute.tile[1],
773           PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
774       break;
775     case xnn_parallelization_type_3d_tile_2d:
776       assert(op->compute.range[0] != 0);
777       assert(op->compute.range[1] != 0);
778       assert(op->compute.range[2] != 0);
779       assert(op->compute.tile[0] != 0);
780       assert(op->compute.tile[1] != 0);
781       pthreadpool_parallelize_3d_tile_2d(
782           threadpool,
783           op->compute.task_3d_tile_2d,
784           &op->context,
785           op->compute.range[0], op->compute.range[1], op->compute.range[2],
786           op->compute.tile[0], op->compute.tile[1],
787           PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
788       break;
789     case xnn_parallelization_type_4d_tile_2d:
790       assert(op->compute.range[0] != 0);
791       assert(op->compute.range[1] != 0);
792       assert(op->compute.range[2] != 0);
793       assert(op->compute.range[3] != 0);
794       assert(op->compute.tile[0] != 0);
795       assert(op->compute.tile[1] != 0);
796       pthreadpool_parallelize_4d_tile_2d(
797           threadpool,
798           op->compute.task_4d_tile_2d,
799           &op->context,
800           op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
801           op->compute.tile[0], op->compute.tile[1],
802           PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
803       break;
804     case xnn_parallelization_type_5d_tile_2d:
805       assert(op->compute.range[0] != 0);
806       assert(op->compute.range[1] != 0);
807       assert(op->compute.range[2] != 0);
808       assert(op->compute.range[3] != 0);
809       assert(op->compute.range[4] != 0);
810       assert(op->compute.tile[0] != 0);
811       assert(op->compute.tile[1] != 0);
812       pthreadpool_parallelize_5d_tile_2d(
813           threadpool,
814           op->compute.task_5d_tile_2d,
815           &op->context,
816           op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
817           op->compute.tile[0], op->compute.tile[1],
818           PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
819       break;
820     case xnn_parallelization_type_6d_tile_2d:
821       assert(op->compute.range[0] != 0);
822       assert(op->compute.range[1] != 0);
823       assert(op->compute.range[2] != 0);
824       assert(op->compute.range[3] != 0);
825       assert(op->compute.range[4] != 0);
826       assert(op->compute.range[5] != 0);
827       assert(op->compute.tile[0] != 0);
828       assert(op->compute.tile[1] != 0);
829       pthreadpool_parallelize_6d_tile_2d(
830           threadpool,
831           op->compute.task_6d_tile_2d,
832           &op->context,
833           op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
834           op->compute.tile[0], op->compute.tile[1],
835           PTHREADPOOL_FLAG_DISABLE_DENORMALS /* flags */);
836       break;
837     default:
838       XNN_UNREACHABLE;
839   }
840   return xnn_status_success;
841 }
842