// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <stddef.h>
#include <stdint.h>

#include <pthreadpool.h>

#include <xnnpack.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>


enum xnn_parallelization_type {
  xnn_parallelization_type_invalid = 0,
  xnn_parallelization_type_1d,
  xnn_parallelization_type_1d_tile_1d,
  xnn_parallelization_type_2d,
  xnn_parallelization_type_2d_tile_1d,
  xnn_parallelization_type_2d_tile_2d,
  xnn_parallelization_type_3d_tile_2d,
  xnn_parallelization_type_4d_tile_2d,
  xnn_parallelization_type_5d_tile_2d,
  xnn_parallelization_type_6d_tile_2d,
};

struct compute_parameters {
  enum xnn_parallelization_type type;
  union {
    pthreadpool_task_1d_t task_1d;
    pthreadpool_task_1d_tile_1d_t task_1d_tile_1d;
    pthreadpool_task_2d_t task_2d;
    pthreadpool_task_2d_tile_1d_t task_2d_tile_1d;
    pthreadpool_task_2d_tile_2d_t task_2d_tile_2d;
    pthreadpool_task_3d_tile_2d_t task_3d_tile_2d;
    pthreadpool_task_4d_tile_2d_t task_4d_tile_2d;
    pthreadpool_task_5d_tile_2d_t task_5d_tile_2d;
    pthreadpool_task_6d_tile_2d_t task_6d_tile_2d;
  };
  size_t range[6];
  size_t tile[2];
};
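
// Illustrative sketch of how these two types fit together: an operator's
// setup code fills in a compute_parameters record, and the runtime switches
// on `type` to pick the matching pthreadpool_parallelize_* entry point. The
// variable names below (compute, gemm_ctx, threadpool, batch_m, channels_n,
// mr, nr) are hypothetical and not defined in this header; the pthreadpool
// call and its signature are real.
//
//   compute.type = xnn_parallelization_type_2d_tile_2d;
//   compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
//   compute.range[0] = batch_m;     // iteration space along M
//   compute.range[1] = channels_n;  // iteration space along N
//   compute.tile[0] = mr;           // rows of C per micro-kernel call
//   compute.tile[1] = nr;           // columns of C per micro-kernel call
//
//   pthreadpool_parallelize_2d_tile_2d(
//       threadpool,
//       compute.task_2d_tile_2d,
//       &gemm_ctx,
//       compute.range[0], compute.range[1],
//       compute.tile[0], compute.tile[1],
//       0 /* flags */);
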
struct gemm_context {
  size_t k_scaled;
  const void* a;
  size_t a_stride;
  const void* packed_w;
  size_t w_stride;
  size_t wg_stride;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t cg_stride;
  uint32_t log2_csize;
  xnn_gemm_ukernel_function ukernel;
  union {
    union xnn_q8_gemm_params q8;
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_ggemm(
      const struct gemm_context context[restrict static 1],
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_gemm(
      const struct gemm_context context[restrict static 1],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);
#endif

// Context for Sparse Matrix-Dense Matrix Multiplication.
// C [MxN] := A [MxK] * B [KxN] + bias [N]
// A and C are dense matrices with row-major storage; B is a sparse matrix.
struct spmm_context {
  // N dimension of the B and C matrices.
  // Corresponds to the number of output channels in 1x1 convolution.
  size_t n;
  // Input matrix A.
  const void* a;
  // Packed bias elements and non-zero filter elements.
  const void* packed_weights;
  // Input pointer increments, in bytes, after each processed non-zero weight.
  const int32_t* input_increments;
  // Number of non-zero filter elements for each output channel (N dimension).
  const uint32_t* output_channel_nonzeros;
  // Output matrix C.
  void* c;
  // Stride, in bytes, between matrices A corresponding to different images in batched 1x1 convolution.
  size_t batched_a_stride;
  // Stride, in bytes, between matrices C corresponding to different images in batched 1x1 convolution.
  size_t batched_c_stride;
  // Micro-kernel function pointer.
  xnn_spmm_ukernel_function ukernel;
  // Output activation parameters.
  union {
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_spmm(
      const struct spmm_context context[restrict static 1],
      size_t batch_index,
      size_t mr_block_start,
      size_t mr_block_size);
#endif
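
// In scalar form, the computation dispatched through spmm_context is, for
// each row m of A and each output channel n:
//
//   C[m][n] = bias[n] + sum over { k : B[k][n] != 0 } of A[m][k] * B[k][n]
//
// output_channel_nonzeros[n] gives the number of terms in the sum for
// channel n, and input_increments supplies the byte step through A between
// consecutive non-zero weights, so the micro-kernel never touches zeros.
// (This restates the field comments above; the exact layout of
// packed_weights is fixed by the packing code, not by this header.)
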
struct igemm_context {
  size_t ks;
  size_t ks_scaled;
  size_t kc;
  size_t w_stride;
  const void** indirect_a;
  size_t a_offset;
  void* zero;
  const void* packed_w;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  xnn_igemm_ukernel_function ukernel;
  union {
    union xnn_q8_gemm_params q8;
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_gigemm(
      const struct igemm_context context[restrict static 1],
      size_t batch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_igemm(
      const struct igemm_context context[restrict static 1],
      size_t batch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);
#endif
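
// Background note (a descriptive summary of the fields above, not additional
// API surface): IGEMM ("indirect GEMM") is the convolution-as-GEMM path that
// avoids materializing an im2col buffer. Instead of a dense A matrix,
// indirect_a holds one input-row pointer per (output pixel, kernel tap)
// pair; taps that fall into implicit padding point at the `zero` buffer.
// Because a_offset is added to the pointers at run time, one indirection
// buffer can be reused across the whole batch (stepping by ba_stride) rather
// than rebuilt per image.
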
struct subconv_context {
  const struct subconvolution_params* subconvolution_params;
  size_t kc;
  size_t a_offset;
  void* zero;
  size_t cx_stride;
  size_t cy_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  xnn_igemm_ukernel_function ukernel;
  union {
    union xnn_q8_gemm_params q8;
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_gsubconv2d(
      const struct subconv_context context[restrict static 1],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_subconv2d(
      const struct subconv_context context[restrict static 1],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);
#endif

struct dconv2d_context {
  size_t input_height;
  size_t input_width;
  const void* input;
  size_t input_batch_stride;
  const void* zero;
  const void* packed_weights;
  void* output;
  size_t output_batch_stride;
  size_t input_padding_top;
  size_t output_channels;
  size_t output_height_stride;
  size_t output_channel_stride;
  union {
    xnn_conv_hwc2spchw_ukernel_function hwc2spchw_ukernel;
  };
  union {
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dconv2d_hwc2spchw(
      const struct dconv2d_context context[restrict static 1],
      size_t batch_index,
      size_t output_y_start,
      size_t output_y_slice);
#endif

struct dwconv_context {
  size_t groups;
  const void** indirection_buffer;
  size_t indirection_buffer_row_stride;
  size_t indirection_buffer_col_stride;
  const void* packed_weights;
  void* output;
  size_t output_width;
  size_t output_row_stride;
  size_t output_col_increment;
  union {
    union xnn_q8_gemm_params q8;
    union xnn_f32_output_params f32;
  } params;
  union {
    xnn_dwconv_up_ukernel_function unipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv_unipass(
      const struct dwconv_context context[restrict static 1],
      size_t output_y);
#endif

struct dwconv2d_context {
  size_t output_height;
  size_t input_width;
  const void* input;
  size_t input_channel_stride;
  size_t input_batch_stride;
  const void* packed_weights;
  size_t weights_channel_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  size_t input_tuple_stride;
  size_t output_tuple_stride;
  size_t input_pixel_stride;
  size_t output_pixel_stride;
  union {
    union xnn_f32_spchw_params f32;
  } params;
  union {
    xnn_dwconv_spchw_ukernel_function spchw_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv2d_spchw(
      const struct dwconv2d_context context[restrict static 1],
      size_t batch_index,
      size_t channel);
#endif

struct max_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_u8_output_params u8;
    union xnn_f32_output_params f32;
  } params;
  xnn_maxpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_max_pooling(
      const struct max_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);
#endif

struct unpooling_context {
  const void* input;
  size_t input_height_stride;
  size_t input_width_stride;
  const uint32_t* index;
  size_t index_height_stride;
  size_t index_width_stride;
  void** indirect_output;
  size_t indirect_output_height_stride;
  size_t indirect_output_width_stride;
  size_t pooling_size;
  size_t channels;
  uint32_t fill_value;
  xnn_unpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_unpooling(
      const struct unpooling_context context[restrict static 1],
      size_t input_y,
      size_t input_x);
#endif

struct argmax_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  uint32_t* index;
  size_t index_batch_stride;
  size_t index_height_stride;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_f32_output_params f32;
  } params;
  union {
    xnn_argmaxpool_up_ukernel_function unipass_ukernel;
    xnn_argmaxpool_mp_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_argmax_pooling_unipass(
      const struct argmax_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_argmax_pooling_multipass(
      const struct argmax_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);
#endif

struct average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_batch_stride;
  size_t indirect_input_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_q8_avgpool_params q8;
    union xnn_f32_avgpool_params f32;
  } params;
  union {
    xnn_avgpool_up_ukernel_function unipass_ukernel;
    xnn_avgpool_mp_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_average_pooling_unipass(
      const struct average_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_average_pooling_multipass(
      const struct average_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);
#endif
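
// Note on the unipass/multipass split used by the pooling contexts above and
// below (a convention implied by the micro-kernel names, not spelled out in
// this header): the unipass micro-kernel consumes the entire pooling window
// in a single sweep and is used when pooling_size fits the micro-kernel's
// primary tile; the multipass micro-kernel accumulates partial results into
// a scratch buffer and revisits them when the window is larger.
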
struct pixelwise_average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_batch_stride;
  size_t indirect_input_height_stride;
  const void* pixelwise_buffer;
  size_t pixelwise_buffer_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_u8_output_params u8;
    union xnn_f32_output_params f32;
  } params;
  union {
    xnn_pavgpool_up_ukernel_function unipass_ukernel;
    xnn_pavgpool_mp_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_unipass(
      const struct pixelwise_average_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass(
      const struct pixelwise_average_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);
#endif

struct global_average_pooling_nwc_context {
  const void* input;
  const void* zero;
  size_t input_pixel_stride;
  size_t input_batch_stride;
  size_t input_elements;
  size_t channels;
  void* output;
  size_t output_batch_stride;
  union {
    union xnn_q8_avgpool_params q8;
    union xnn_f32_avgpool_params f32;
  } params;
  union {
    xnn_gavgpool_up_ukernel_function unipass_ukernel;
    xnn_gavgpool_mp_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_unipass(
      const struct global_average_pooling_nwc_context context[restrict static 1],
      size_t batch_index);

  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_multipass(
      const struct global_average_pooling_nwc_context context[restrict static 1],
      size_t batch_index);
#endif

struct global_average_pooling_ncw_context {
  size_t input_elements;
  const void* input;
  size_t input_channel_stride;
  size_t input_batch_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  xnn_gavgpool_spchw_ukernel_function ukernel;
  union {
    union xnn_f32_gavgpool_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_ncw(
      const struct global_average_pooling_ncw_context context[restrict static 1],
      size_t batch_index,
      size_t channels_start,
      size_t channels_slice);
#endif

struct resize_bilinear_context {
  // Number of channels multiplied by sizeof(input element).
  size_t scaled_channels;
  // Indirection buffer with pointers to rows of input pixels.
  const void** indirect_input;
  // Offset, in bytes, to be added to the pointers in the indirection buffer.
  size_t input_offset;
  // Stride, in bytes, between images of consecutive batches in the input.
  size_t input_batch_stride;
  // Packed pairs of (x, y) linear interpolation coefficients.
  const void* packed_weights;
  // Pointer to the output tensor.
  void* output;
  // Stride, in bytes, between adjacent pixels in the output.
  size_t output_pixel_stride;
  // Stride, in bytes, between images of consecutive batches in the output.
  size_t output_batch_stride;
  // log2(sizeof(weight element)).
  uint32_t log2_wsize;
  // Pointer to the BILINEAR micro-kernel function.
  xnn_bilinear_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_resize_bilinear(
      const struct resize_bilinear_context context[restrict static 1],
      size_t batch_index,
      size_t pixel_start,
      size_t pixel_range);
#endif
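
// For reference, each output pixel is the standard bilinear blend of the
// four input neighbors addressed through the indirection buffer, using one
// packed (alpha_x, alpha_y) coefficient pair per pixel:
//
//   top    = top_left + alpha_x * (top_right - top_left)
//   bottom = bot_left + alpha_x * (bot_right - bot_left)
//   output = top      + alpha_y * (bottom    - top)
//
// (The neighbor and coefficient names here are illustrative; this header
// only fixes that the coefficients are packed (x, y) pairs.)
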
struct add_strided_context {
  size_t n;
  const void* a;
  size_t a_stride;
  const void* b;
  size_t b_stride;
  void* y;
  size_t y_stride;
  union {
    union xnn_q8_add_params q8;
    union xnn_f32_output_params f32;
  } params;
  xnn_vadd_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_add_strided(
      const struct add_strided_context context[restrict static 1],
      size_t batch_index,
      size_t batch_range);
#endif

struct add_contiguous_context {
  const void* a;
  const void* b;
  void* y;
  union {
    union xnn_q8_add_params q8;
    union xnn_f32_output_params f32;
  } params;
  xnn_vadd_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_add_contiguous(
      const struct add_contiguous_context context[restrict static 1],
      size_t offset,
      size_t size);
#endif

struct elementwise_binary_context {
  const void* a;
  size_t a_stride[XNN_MAX_TENSOR_DIMS - 1];
  const void* b;
  size_t b_stride[XNN_MAX_TENSOR_DIMS - 1];
  void* y;
  size_t y_stride[XNN_MAX_TENSOR_DIMS - 1];
  size_t elements;
  union {
    union xnn_q8_add_params q8;
    union xnn_f32_output_params f32;
  } params;
  xnn_vbinary_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_elementwise_binary_5d(
      const struct elementwise_binary_context context[restrict static 1],
      size_t i, size_t j, size_t k, size_t l, size_t m,
      size_t l_range, size_t m_range);
#endif
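
// Usage note (an assumption based on the stride arrays above, not a
// documented contract): broadcasting along a dimension is expressed with a
// zero stride. For example, adding a per-channel vector b[N] to a tensor
// a[M][N] viewed as [1][1][1][M][N] sets b's stride for the M axis to 0, so
// every output row re-reads the same N values of b, while a and y keep
// their real row pitches.
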
struct channel_shuffle_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  size_t n;
  size_t m;
  union {
    xnn_zipc_ukernel_function fixed_ukernel;
    xnn_zipv_ukernel_function variable_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_channel_shuffle_fixed(
      const struct channel_shuffle_context context[restrict static 1],
      size_t index);

  XNN_PRIVATE void xnn_compute_channel_shuffle_variable(
      const struct channel_shuffle_context context[restrict static 1],
      size_t index);
#endif

struct lut_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_strided(
      const struct lut_strided_context context[restrict static 1],
      size_t batch_index);
#endif

struct lut_contiguous_context {
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_contiguous(
      const struct lut_contiguous_context context[restrict static 1],
      size_t offset,
      size_t size);
#endif

struct univector_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_univector_ukernel_function ukernel;
  union {
    union xnn_u8_output_params u8_output;
    union xnn_f32_output_params f32_output;
    union xnn_f32_hswish_params f32_hswish;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_strided(
      const struct univector_strided_context context[restrict static 1],
      size_t batch_index,
      size_t batch_range);
#endif

struct univector_contiguous_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_univector_ukernel_function ukernel;
  union {
    union xnn_u8_output_params u8_output;
    union xnn_f32_output_params f32_output;
    union xnn_f32_hswish_params f32_hswish;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_contiguous(
      const struct univector_contiguous_context context[restrict static 1],
      size_t offset,
      size_t size);
#endif

struct prelu_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_prelu_ukernel_function ukernel;
  union xnn_f32_output_params params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_prelu(
      const struct prelu_context context[restrict static 1],
      size_t batch_start,
      size_t batch_range);
#endif

struct vmulcaddc_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_vmulcaddc_ukernel_function ukernel;
  union {
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_vmulcaddc(
      const struct vmulcaddc_context context[restrict static 1],
      size_t batch_start,
      size_t batch_size);
#endif

struct channel_pad_context {
  size_t n;
  size_t l;
  size_t r;
  uint32_t c;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_pad_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_channel_pad(
      const struct channel_pad_context context[restrict static 1],
      size_t batch_start,
      size_t batch_range);
#endif

struct u8_softmax_context {
  size_t n;
  const uint8_t* x;
  size_t x_stride;
  const uint32_t* t;
  uint8_t* y;
  size_t y_stride;
  xnn_u8_rmax_ukernel_function rmax_ukernel;
  xnn_u8_lut32norm_ukernel_function lut_norm_ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_u8_softmax(
      const struct u8_softmax_context context[restrict static 1],
      size_t batch_index);
#endif

struct f32_three_pass_softmax_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_f32_rmax_ukernel_function rmax_ukernel;
  xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax_ukernel;
  xnn_vbinary_ukernel_function vmulc_ukernel;
  union xnn_f32_output_params params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_f32_three_pass_softmax(
      const struct f32_three_pass_softmax_context context[restrict static 1],
      size_t batch_index);
#endif
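
// The three passes, as implied by the micro-kernel fields above:
//   1. rmax_ukernel scans x for its maximum, max_x (for numerical stability);
//   2. raddstoreexpminusmax_ukernel stores y[i] = exp(x[i] - max_x) and
//      accumulates sum = the total of all y[i];
//   3. vmulc_ukernel scales y in place by 1 / sum, yielding
//      softmax(x)[i] = exp(x[i] - max_x) / sum over j of exp(x[j] - max_x).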