// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once


#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>


enum xnn_parallelization_type {
  xnn_parallelization_type_invalid = 0,
  xnn_parallelization_type_1d,
  xnn_parallelization_type_1d_tile_1d,
  xnn_parallelization_type_2d,
  xnn_parallelization_type_2d_tile_1d,
  xnn_parallelization_type_2d_tile_2d,
  xnn_parallelization_type_3d,
  xnn_parallelization_type_3d_tile_2d,
  xnn_parallelization_type_4d,
  xnn_parallelization_type_4d_tile_2d,
  xnn_parallelization_type_5d,
  xnn_parallelization_type_5d_tile_2d,
  xnn_parallelization_type_6d_tile_2d,
#if XNN_MAX_UARCH_TYPES > 1
  xnn_parallelization_type_2d_tile_2d_with_uarch,
  xnn_parallelization_type_3d_tile_2d_with_uarch,
  xnn_parallelization_type_4d_tile_2d_with_uarch,
#endif // XNN_MAX_UARCH_TYPES > 1
};

struct compute_parameters {
  enum xnn_parallelization_type type;
  union {
    pthreadpool_task_1d_t task_1d;
    pthreadpool_task_1d_tile_1d_t task_1d_tile_1d;
    pthreadpool_task_2d_t task_2d;
    pthreadpool_task_2d_tile_1d_t task_2d_tile_1d;
    pthreadpool_task_2d_tile_2d_t task_2d_tile_2d;
    pthreadpool_task_3d_t task_3d;
    pthreadpool_task_3d_tile_2d_t task_3d_tile_2d;
    pthreadpool_task_4d_t task_4d;
    pthreadpool_task_4d_tile_2d_t task_4d_tile_2d;
    pthreadpool_task_5d_t task_5d;
    pthreadpool_task_5d_tile_2d_t task_5d_tile_2d;
    pthreadpool_task_6d_tile_2d_t task_6d_tile_2d;
#if XNN_MAX_UARCH_TYPES > 1
    pthreadpool_task_2d_tile_2d_with_id_t task_2d_tile_2d_with_id;
    pthreadpool_task_3d_tile_2d_with_id_t task_3d_tile_2d_with_id;
    pthreadpool_task_4d_tile_2d_with_id_t task_4d_tile_2d_with_id;
#endif // XNN_MAX_UARCH_TYPES > 1
  };
  size_t range[6];
  size_t tile[2];
};

struct gemm_context {
  size_t k_scaled;
  const void* a;
  size_t a_stride;
  const void* packed_w;
  size_t w_stride;
  size_t wg_stride;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t cg_stride;
  uint32_t log2_csize;
  struct xnn_hmp_gemm_ukernel ukernel;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  #if XNN_MAX_UARCH_TYPES > 1
    XNN_PRIVATE void xnn_compute_hmp_grouped_gemm(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_gemm(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);
  #endif // XNN_MAX_UARCH_TYPES > 1
#endif
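
// Illustrative sketch (not part of the API): an operator that runs a dense GEMM
// typically fills a gemm_context, selects 2D tiled parallelization in its
// compute_parameters, and lets the runtime dispatch xnn_compute_gemm over the
// thread pool. The variable names and field values below are hypothetical
// placeholders, not prescribed by this header.
//
//   struct gemm_context context = { /* operator-specific fields */ };
//   struct compute_parameters compute = {
//     .type = xnn_parallelization_type_2d_tile_2d,
//     .task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm,
//     .range = { batch_size, output_channels },
//     .tile = { mr, nr },
//   };
//   pthreadpool_parallelize_2d_tile_2d(
//       threadpool, compute.task_2d_tile_2d, &context,
//       compute.range[0], compute.range[1], compute.tile[0], compute.tile[1],
//       0 /* flags */);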

// Context for Sparse Matrix-Dense Matrix Multiplication.
// C [MxN] := A [MxK] * B [KxN] + bias [N]
// A and C are dense matrices with row-major storage, B is a sparse matrix.
struct spmm_context {
  // N dimension of the B and C matrices.
  // Corresponds to the number of output channels in 1x1 convolution.
  size_t n;
  // M dimension of the A and C matrices, pre-scaled by sizeof(element).
  // Corresponds to the stride, in bytes, between adjacent rows of the C matrix.
  size_t scaled_m;
  // Input matrix A.
  const void* input;
  // Packed bias elements and non-zero filter elements.
  const void* nonzero_weights;
  // Input pointer increments, in bytes, after each processed non-zero weight.
  const int32_t* input_increments;
  // Number of non-zero filter elements per output channel (N dimension).
  const uint32_t* output_channel_nonzeros;
  // Output matrix C.
  void* output;
  // Stride, in bytes, between matrices A corresponding to different images in batched 1x1 convolution.
  size_t batched_input_stride;
  // Stride, in bytes, between matrices C corresponding to different images in batched 1x1 convolution.
  size_t batched_output_stride;
  // Micro-kernel function pointer.
  xnn_spmm_ukernel_function ukernel;
  // Output activation parameters.
  union {
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_spmm(
      const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t mr_block_start,
      size_t mr_block_size);
#endif
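
// Illustrative reference (hypothetical helper, not part of XNNPACK): the dense
// equivalent of the computation dispatched through xnn_compute_spmm, following
// the equation documented above (C [MxN] := A [MxK] * B [KxN] + bias [N]) with
// row-major A and C.
//
//   void dense_reference_mm(size_t M, size_t N, size_t K,
//                           const float* a, const float* b,
//                           const float* bias, float* c) {
//     for (size_t m = 0; m < M; m++) {
//       for (size_t n = 0; n < N; n++) {
//         float acc = bias[n];
//         for (size_t k = 0; k < K; k++) {
//           acc += a[m * K + k] * b[k * N + n];
//         }
//         c[m * N + n] = acc;
//       }
//     }
//   }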

struct igemm_context {
  size_t ks;
  size_t ks_scaled;
  size_t kc;
  size_t w_stride;
  const void** indirect_a;
  size_t a_offset;
  void* zero;
  const void* packed_w;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  struct xnn_hmp_igemm_ukernel ukernel;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_grouped_batch_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_batch_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  #if XNN_MAX_UARCH_TYPES > 1
    XNN_PRIVATE void xnn_compute_hmp_grouped_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_grouped_batch_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t batch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_batch_hmp_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t batch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);
  #endif // XNN_MAX_UARCH_TYPES > 1
#endif

struct subgemm_context {
  const struct subconvolution_params* subconvolution_params;
  size_t kc;
  const void* a;
  size_t ax_stride;
  size_t ay_stride;
  size_t cx_stride;
  size_t cy_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  struct xnn_hmp_gemm_ukernel ukernel;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);
#endif

struct subconv_context {
  const struct subconvolution_params* subconvolution_params;
  size_t kc;
  size_t a_offset;
  void* zero;
  size_t cx_stride;
  size_t cy_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  struct xnn_hmp_igemm_ukernel ukernel;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);
#endif

struct conv2d_context {
  size_t input_height;
  size_t input_width;
  const void* input;
  size_t input_batch_stride;
  const void* zero;
  const void* packed_weights;
  void* output;
  size_t output_batch_stride;
  size_t input_padding_top;
  size_t output_channels;
  size_t output_height_stride;
  size_t output_channel_stride;
  union {
    xnn_conv_hwc2chw_ukernel_function hwc2chw_ukernel;
  };
  union {
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_conv2d_hwc2chw(
      const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y_start,
      size_t output_y_slice);
#endif

struct dwconv_context {
  const void** indirect_input;
  size_t indirect_input_width_stride;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  const void* packed_weights;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t groups;
  const void* zero;
  size_t output_increment;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
  union {
    xnn_dwconv_unipass_ukernel_function unipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv_unipass(
      const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct dwconv2d_context {
  size_t input_height;
  size_t input_width;
  const void* input;
  const void* zero;
  uint32_t input_padding_top;
  size_t input_channel_stride;
  size_t input_batch_stride;
  const void* packed_weights;
  size_t weights_channel_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  union {
    union xnn_f32_chw_params f32;
  } params;
  union {
    xnn_dwconv2d_chw_ukernel_function chw_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv2d_chw(
      const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t channel);
#endif

struct depthtospace2d_hwc_context {
  size_t elements;
  size_t input_width;
  size_t block_size;
  const void* input;
  void* output;
  size_t input_height_stride;
  size_t input_width_stride;
  size_t output_height_stride;
  size_t output_width_stride;
  xnn_univector_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_depthtospace2d_hwc_contiguous(
      const struct depthtospace2d_hwc_context* context,
      size_t batch_input_y,
      size_t input_x,
      size_t block_y);

  XNN_PRIVATE void xnn_compute_depthtospace2d_hwc_strided(
      const struct depthtospace2d_hwc_context* context,
      size_t batch_input_y,
      size_t input_x,
      size_t block_y,
      size_t block_x);
#endif

struct depthtospace2d_chw2hwc_context {
  size_t output_channels;
  size_t input_height;
  size_t input_width;
  uint32_t block_size;
  const void* input;
  void* output;
  size_t input_batch_stride;
  size_t output_batch_stride;
  size_t output_channel_stride;
  xnn_depthtospace2d_chw2hwc_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_depthtospace2d_chw2hwc(
      const struct depthtospace2d_chw2hwc_context* context,
      size_t batch_index);
#endif

struct max_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_u8_minmax_params u8;
    union xnn_f32_minmax_params f32;
  } params;
  xnn_maxpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_max_pooling(
      const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct unpooling_context {
  const void* input;
  size_t input_height_stride;
  size_t input_width_stride;
  const uint32_t* index;
  size_t index_height_stride;
  size_t index_width_stride;
  const void** indirect_output;
  size_t indirect_output_height_stride;
  size_t indirect_output_width_stride;
  size_t pooling_size;
  size_t channels;
  uint32_t fill_value;
  xnn_unpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_unpooling(
      const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t input_y,
      size_t input_x);
#endif

struct argmax_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  uint32_t* index;
  size_t index_batch_stride;
  size_t index_height_stride;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    xnn_argmaxpool_unipass_ukernel_function unipass_ukernel;
    xnn_argmaxpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_argmax_pooling_unipass(
      const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_argmax_pooling_multipass(
      const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_qu8_avgpool_minmax_params qu8;
    union xnn_f32_scaleminmax_params f32;
  } params;
  union {
    xnn_avgpool_unipass_ukernel_function unipass_ukernel;
    xnn_avgpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_average_pooling_unipass(
      const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_average_pooling_multipass(
      const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct pixelwise_average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  const void* pixelwise_buffer;
  size_t pixelwise_buffer_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_u8_minmax_params u8;
    union xnn_f32_minmax_params f32;
  } params;
  union {
    xnn_pavgpool_unipass_ukernel_function unipass_ukernel;
    xnn_pavgpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_unipass(
      const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass(
      const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct global_average_pooling_nwc_context {
  const void* input;
  const void* zero;
  size_t input_pixel_stride;
  size_t input_batch_stride;
  size_t input_elements;
  size_t channels;
  void* output;
  size_t output_batch_stride;
  union {
    union xnn_qs8_avgpool_minmax_params qs8;
    union xnn_qu8_avgpool_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_scaleminmax_params f32;
  } params;
  union {
    xnn_gavgpool_unipass_ukernel_function unipass_ukernel;
    xnn_gavgpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_unipass(
      const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);

  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_multipass(
      const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

struct global_average_pooling_ncw_context {
  size_t input_elements;
  const void* input;
  size_t input_channel_stride;
  size_t input_batch_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  xnn_gavgpool_cw_ukernel_function ukernel;
  union {
    union xnn_f32_gavgpool_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_ncw(
      const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t channels_start,
      size_t channels_slice);
#endif

struct resize_bilinear_context {
  // Number of channels multiplied by sizeof(input element).
  size_t scaled_channels;
  // Indirection buffer with pointers related to rows of input pixels.
  const void** indirect_input;
  // Offset, in bytes, to be added to pointers in indirection buffer.
  size_t input_offset;
  // Stride, in bytes, between images of consecutive batches in the input.
  size_t input_batch_stride;
  // Packed pairs of (x, y) linear interpolation coefficients.
  const void* packed_weights;
  // Pointer to the output tensor.
  void* output;
  // Stride, in bytes, between adjacent pixels in the output.
  size_t output_pixel_stride;
  // Stride, in bytes, between images of consecutive batches in the output.
  size_t output_batch_stride;
  // log2(sizeof(weight element)).
  uint32_t log2_wsize;
  // Pointer to BILINEAR micro-kernel function.
  xnn_ibilinear_ukernel_function ukernel;
};

struct resize_bilinear_chw_context {
  // Number of pixels per output image plane.
  size_t output_pixels;
  // Number of channels multiplied by sizeof(input element).
  size_t channels;
  // Stride, in bytes, between adjacent channels in the input.
  size_t input_channel_stride;
  // Indirection buffer with pointers related to rows of input pixels.
  const void** indirect_input;
  // Offset, in bytes, to be added to pointers in indirection buffer.
  size_t input_offset;
  // Stride, in bytes, between images of consecutive batches in the input.
  size_t input_batch_stride;
  // Packed pairs of (x, y) linear interpolation coefficients.
  const void* packed_weights;
  // Pointer to the output tensor.
  void* output;
  // Stride, in bytes, between images of consecutive batches in the output.
  size_t output_batch_stride;
  // Stride, in bytes, between consecutive channels of an output image.
  size_t output_channel_stride;
  // Pointer to BILINEAR micro-kernel function.
  xnn_ibilinear_chw_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_resize_bilinear(
      const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t pixel_start,
      size_t pixel_range);

  XNN_PRIVATE void xnn_compute_resize_bilinear_chw(
      const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t pixel_start,
      size_t pixel_range);
#endif
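
// Illustrative sketch of the per-pixel interpolation a BILINEAR micro-kernel
// performs (hypothetical scalar form; the exact packed-coefficient layout and
// indirection-buffer contents are assumptions here, not API guarantees). Given
// the four neighboring input values and the packed (x, y) coefficients
// alpha_x and alpha_y for one output pixel:
//
//   const float top    = top_left    + alpha_x * (top_right    - top_left);
//   const float bottom = bottom_left + alpha_x * (bottom_right - bottom_left);
//   const float result = top         + alpha_y * (bottom       - top);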

struct elementwise_binary_context {
  const void* a;
  size_t a_stride[XNN_MAX_TENSOR_DIMS - 1];
  const void* b;
  size_t b_stride[XNN_MAX_TENSOR_DIMS - 1];
  void* y;
  size_t y_stride[XNN_MAX_TENSOR_DIMS - 1];
  size_t elements;
  union {
    union xnn_qs8_addsub_minmax_params qs8_addsub;
    union xnn_qu8_addsub_minmax_params qu8_addsub;
    union xnn_qs8_mul_minmax_params qs8_mul;
    union xnn_qu8_mul_minmax_params qu8_mul;
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
  xnn_vbinary_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_elementwise_binary_5d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l, size_t m);
#endif
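
// Illustrative sketch (an assumption about what an implementation of
// xnn_compute_elementwise_binary_5d does, not a guarantee): the five indices
// select a slice through the outer strides, and the micro-kernel then
// processes the contiguous innermost dimension of `elements`.
//
//   const void* a = (const void*) ((uintptr_t) context->a +
//       i * context->a_stride[0] + j * context->a_stride[1] + k * context->a_stride[2] +
//       l * context->a_stride[3] + m * context->a_stride[4]);
//   const void* b = (const void*) ((uintptr_t) context->b +
//       i * context->b_stride[0] + j * context->b_stride[1] + k * context->b_stride[2] +
//       l * context->b_stride[3] + m * context->b_stride[4]);
//   void* y = (void*) ((uintptr_t) context->y +
//       i * context->y_stride[0] + j * context->y_stride[1] + k * context->y_stride[2] +
//       l * context->y_stride[3] + m * context->y_stride[4]);
//   context->ukernel(context->elements, a, b, y, &context->params);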

struct channel_shuffle_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  size_t n;
  size_t m;
  union {
    xnn_zipc_ukernel_function fixed_ukernel;
    xnn_zipv_ukernel_function variable_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_channel_shuffle_fixed(
      const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t index);

  XNN_PRIVATE void xnn_compute_channel_shuffle_variable(
      const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t index);
#endif

struct lut_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_strided(
      const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

struct lut_contiguous_context {
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_contiguous(
      const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t offset,
      size_t size);
#endif

struct univector_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_univector_ukernel_function ukernel;
  union {
    union xnn_f16_f32_cvt_params f16_f32_cvt;
    union xnn_f16_hswish_params f16_hswish;
    union xnn_f32_abs_params f32_abs;
    union xnn_f32_default_params f32_default;
    union xnn_f32_elu_params f32_elu;
    union xnn_f32_f16_cvt_params f32_f16_cvt;
    union xnn_f32_hswish_params f32_hswish;
    union xnn_f32_lrelu_params f32_lrelu;
    union xnn_f32_minmax_params f32_minmax;
    union xnn_f32_neg_params f32_neg;
    union xnn_f32_qs8_cvt_params f32_qs8_cvt;
    union xnn_f32_qu8_cvt_params f32_qu8_cvt;
    union xnn_f32_rnd_params f32_rnd;
    union xnn_f32_sigmoid_params f32_sigmoid;
    union xnn_f32_sqrt_params f32_sqrt;
    union xnn_qs8_f32_cvt_params qs8_f32_cvt;
    union xnn_qu8_f32_cvt_params qu8_f32_cvt;
    union xnn_s8_minmax_params s8_minmax;
    union xnn_u8_minmax_params u8_minmax;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_strided(
      const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t batch_range);
#endif

struct univector_contiguous_context {
  const void* x;
  void* y;
  uint16_t log2_xsize;
  uint16_t log2_ysize;
  xnn_univector_ukernel_function ukernel;
  union {
    union xnn_f16_f32_cvt_params f16_f32_cvt;
    union xnn_f16_hswish_params f16_hswish;
    union xnn_f32_abs_params f32_abs;
    union xnn_f32_default_params f32_default;
    union xnn_f32_elu_params f32_elu;
    union xnn_f32_f16_cvt_params f32_f16_cvt;
    union xnn_f32_hswish_params f32_hswish;
    union xnn_f32_lrelu_params f32_lrelu;
    union xnn_f32_minmax_params f32_minmax;
    union xnn_f32_neg_params f32_neg;
    union xnn_f32_qs8_cvt_params f32_qs8_cvt;
    union xnn_f32_qu8_cvt_params f32_qu8_cvt;
    union xnn_f32_rnd_params f32_rnd;
    union xnn_f32_sigmoid_params f32_sigmoid;
    union xnn_f32_sqrt_params f32_sqrt;
    union xnn_qs8_f32_cvt_params qs8_f32_cvt;
    union xnn_qu8_f32_cvt_params qu8_f32_cvt;
    union xnn_s8_minmax_params s8_minmax;
    union xnn_u8_minmax_params u8_minmax;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_contiguous(
      const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t offset,
      size_t size);
#endif
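
// Illustrative sketch (an assumption about the implementation, not a
// guarantee): xnn_compute_univector_contiguous receives a byte offset and size
// within the input, and log2_xsize/log2_ysize translate that input byte offset
// into the matching output byte offset when element sizes differ (e.g. an
// F32 -> F16 conversion).
//
//   const void* x = (const void*) ((uintptr_t) context->x + offset);
//   void* y = (void*) ((uintptr_t) context->y +
//       ((offset >> context->log2_xsize) << context->log2_ysize));
//   context->ukernel(size, x, y, &context->params);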

struct prelu_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_prelu_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_prelu(
      const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_start,
      size_t batch_range);
#endif

struct vmulcaddc_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_vmulcaddc_ukernel_function ukernel;
  union {
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_vmulcaddc(
      const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_start,
      size_t batch_size);
#endif

struct pad_context {
  const void* input;
  size_t input_stride[XNN_MAX_TENSOR_DIMS - 1];
  void* output;
  size_t output_stride[XNN_MAX_TENSOR_DIMS - 1];
  size_t pre_paddings[XNN_MAX_TENSOR_DIMS];
  size_t post_paddings[1];
  size_t input_size[XNN_MAX_TENSOR_DIMS];
  size_t output_size[1];
  uint32_t padding_value;
  xnn_pad_ukernel_function pad_ukernel;
  xnn_fill_ukernel_function fill_ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pad_5d(
      const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l, size_t m);
#endif

struct u8_softmax_context {
  size_t n;
  const uint8_t* x;
  size_t x_stride;
  const uint32_t* t;
  uint8_t* y;
  size_t y_stride;
  xnn_u8_rmax_ukernel_function rmax_ukernel;
  xnn_u8_lut32norm_ukernel_function lut_norm_ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_u8_softmax(
      const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

struct f32_three_pass_softmax_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_f32_rmax_ukernel_function rmax_ukernel;
  xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax_ukernel;
  xnn_vbinary_ukernel_function vmulc_ukernel;
  union xnn_f32_minmax_params minmax_params;
  union xnn_f32_expminus_params expminus_params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_f32_three_pass_softmax(
      const struct f32_three_pass_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif
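
// Illustrative sketch of the three passes implied by the ukernel pointers in
// f32_three_pass_softmax_context, for one batch row x -> y of n elements
// (assumed implementation shape; exact ukernel signatures may differ):
//
//   // Pass 1: reduce to the row maximum.
//   float x_max;
//   context->rmax_ukernel(n, x, &x_max);
//   // Pass 2: store y = exp(x - x_max) and accumulate its sum.
//   float y_sum;
//   context->raddstoreexpminusmax_ukernel(n, x, &x_max, y, &y_sum, &context->expminus_params);
//   // Pass 3: scale y by 1 / y_sum.
//   const float y_scale = 1.0f / y_sum;
//   context->vmulc_ukernel(n, y, &y_scale, y, &context->minmax_params);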