// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once


#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>


enum xnn_parallelization_type {
  xnn_parallelization_type_invalid = 0,
  xnn_parallelization_type_1d,
  xnn_parallelization_type_1d_tile_1d,
  xnn_parallelization_type_2d,
  xnn_parallelization_type_2d_tile_1d,
  xnn_parallelization_type_2d_tile_2d,
  xnn_parallelization_type_3d,
  xnn_parallelization_type_3d_tile_2d,
  xnn_parallelization_type_4d,
  xnn_parallelization_type_4d_tile_2d,
  xnn_parallelization_type_5d,
  xnn_parallelization_type_5d_tile_2d,
  xnn_parallelization_type_6d_tile_2d,
#if XNN_MAX_UARCH_TYPES > 1
  xnn_parallelization_type_2d_tile_2d_with_uarch,
  xnn_parallelization_type_3d_tile_2d_with_uarch,
  xnn_parallelization_type_4d_tile_2d_with_uarch,
#endif  // XNN_MAX_UARCH_TYPES > 1
};

struct compute_parameters {
  enum xnn_parallelization_type type;
  union {
    pthreadpool_task_1d_t task_1d;
    pthreadpool_task_1d_tile_1d_t task_1d_tile_1d;
    pthreadpool_task_2d_t task_2d;
    pthreadpool_task_2d_tile_1d_t task_2d_tile_1d;
    pthreadpool_task_2d_tile_2d_t task_2d_tile_2d;
    pthreadpool_task_3d_t task_3d;
    pthreadpool_task_3d_tile_2d_t task_3d_tile_2d;
    pthreadpool_task_4d_t task_4d;
    pthreadpool_task_4d_tile_2d_t task_4d_tile_2d;
    pthreadpool_task_5d_t task_5d;
    pthreadpool_task_5d_tile_2d_t task_5d_tile_2d;
    pthreadpool_task_6d_tile_2d_t task_6d_tile_2d;
#if XNN_MAX_UARCH_TYPES > 1
    pthreadpool_task_2d_tile_2d_with_id_t task_2d_tile_2d_with_id;
    pthreadpool_task_3d_tile_2d_with_id_t task_3d_tile_2d_with_id;
    pthreadpool_task_4d_tile_2d_with_id_t task_4d_tile_2d_with_id;
#endif  // XNN_MAX_UARCH_TYPES > 1
  };
  size_t range[6];
  size_t tile[2];
};

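// Illustrative sketch (not part of the XNNPACK API): how a filled-in
// compute_parameters structure might be dispatched to pthreadpool. Only the
// range[]/tile[] entries used by the selected parallelization type are
// meaningful; `compute`, `context`, and `threadpool` are assumptions of the
// example.
/*
  switch (compute->type) {
    case xnn_parallelization_type_1d:
      pthreadpool_parallelize_1d(
          threadpool, compute->task_1d, context,
          compute->range[0], 0);  // trailing 0 = flags
      break;
    case xnn_parallelization_type_2d_tile_2d:
      pthreadpool_parallelize_2d_tile_2d(
          threadpool, compute->task_2d_tile_2d, context,
          compute->range[0], compute->range[1],
          compute->tile[0], compute->tile[1], 0);
      break;
    // ... the remaining cases follow the same pattern ...
    default:
      XNN_UNREACHABLE;
  }
*/
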
struct transpose_context {
  const void* x;
  void* y;
  union {
    xnn_transposec_ukernel_function const_size_ukernel;
    xnn_transposev_ukernel_function variable_size_ukernel;
  };
  union {
    size_t element_size;
    size_t log2_element_size;
  };
  size_t input_stride[XNN_MAX_TENSOR_DIMS];
  size_t output_stride[XNN_MAX_TENSOR_DIMS];
};

XNN_PRIVATE void xnn_compute_transposec_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j);

XNN_PRIVATE void xnn_compute_transposec_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k);

XNN_PRIVATE void xnn_compute_transposec_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l);

XNN_PRIVATE void xnn_compute_transposec_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m);

XNN_PRIVATE void xnn_compute_transposec_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n);

XNN_PRIVATE void xnn_compute_transposev_2d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t tile_i,
    size_t tile_j);

XNN_PRIVATE void xnn_compute_transposev_3d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t tile_j,
    size_t tile_k);

XNN_PRIVATE void xnn_compute_transposev_4d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t tile_k,
    size_t tile_l);

XNN_PRIVATE void xnn_compute_transposev_5d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t tile_l,
    size_t tile_m);

XNN_PRIVATE void xnn_compute_transposev_6d(
    const struct transpose_context* context,
    size_t i,
    size_t j,
    size_t k,
    size_t l,
    size_t m,
    size_t n,
    size_t tile_m,
    size_t tile_n);

struct gemm_context {
  size_t k_scaled;
  const void* a;
  size_t a_stride;
  const void* packed_w;
  size_t w_stride;
  size_t wg_stride;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t cg_stride;
  uint32_t log2_csize;
  struct xnn_hmp_gemm_ukernel ukernel;
  void* fused_params;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  #if XNN_MAX_UARCH_TYPES > 1
    XNN_PRIVATE void xnn_compute_hmp_grouped_gemm(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_gemm(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);
  #endif  // XNN_MAX_UARCH_TYPES > 1
#endif

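// Illustrative sketch (an assumption about the implementation, which lives in
// the corresponding .c file): the pointer arithmetic a GEMM compute function
// performs before invoking the micro-kernel on one mr-by-nr output tile.
/*
  // Inside xnn_compute_gemm(context, mr_block_start, nr_block_start,
  //                         mr_block_size, nr_block_size):
  context->ukernel.function[XNN_UARCH_DEFAULT](
      mr_block_size, nr_block_size, context->k_scaled,
      (const void*) ((uintptr_t) context->a + mr_block_start * context->a_stride),
      context->a_stride,
      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
      (void*) ((uintptr_t) context->c + mr_block_start * context->cm_stride +
               (nr_block_start << context->log2_csize)),
      context->cm_stride, context->cn_stride,
      &context->params);
*/
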
// Context for Sparse Matrix-Dense Matrix Multiplication.
// C [MxN] := A [MxK] * B [KxN] + bias [N]
// A and C are dense matrices with row-major storage; B is a sparse matrix.
struct spmm_context {
  // N dimension of the B and C matrices.
  // Corresponds to the number of output channels in 1x1 convolution.
  size_t n;
  // M dimension of the A and C matrices, pre-scaled by the element size, in bytes.
  // Corresponds to the stride, in bytes, between adjacent rows of the C matrix.
  size_t scaled_m;
  // Input matrix A.
  const void* input;
  // Packed bias elements and non-zero filter elements.
  const void* nonzero_weights;
  // Input pointer increments, in bytes, after each processed non-zero weight.
  const int32_t* input_increments;
  // Number of non-zero filter elements per output channel (N dimension).
  const uint32_t* output_channel_nonzeros;
  // Output matrix C.
  void* output;
  // Stride, in bytes, between matrices A corresponding to different images in batched 1x1 convolution.
  size_t batched_input_stride;
  // Stride, in bytes, between matrices C corresponding to different images in batched 1x1 convolution.
  size_t batched_output_stride;
  // Micro-kernel function pointer.
  xnn_spmm_ukernel_function ukernel;
  // Output activation parameters.
  union {
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_spmm(
      const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t mr_block_start,
      size_t mr_block_size);
#endif

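// Illustrative scalar reference for the packed sparse format above (a sketch,
// not the actual micro-kernel; f32, M == 1, activation clamping omitted):
// each output channel stores its bias first in nonzero_weights, followed by
// that channel's non-zero weights, while input_increments chains the input
// pointer from one non-zero weight to the next.
/*
  const float* w = (const float*) nonzero_weights;
  const int32_t* dmap = input_increments;
  const uint32_t* nnzmap = output_channel_nonzeros;
  const float* a = (const float*) input;   // one row of matrix A
  char* c = (char*) output;
  for (size_t channel = 0; channel < n; channel++) {
    float acc = *w++;  // per-channel bias
    for (uint32_t nnz = *nnzmap++; nnz != 0; nnz--) {
      acc += (*a) * (*w++);
      a = (const float*) ((const char*) a + *dmap++);  // chained increment
    }
    *(float*) c = acc;
    c += scaled_m;  // next output channel's row of C
  }
*/
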
struct igemm_context {
  size_t ks;
  size_t ks_scaled;
  size_t kc;
  size_t w_stride;
  const void** indirect_a;
  size_t a_offset;
  void* zero;
  const void* packed_w;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  struct xnn_hmp_igemm_ukernel ukernel;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_grouped_batch_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_batch_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  #if XNN_MAX_UARCH_TYPES > 1
    XNN_PRIVATE void xnn_compute_hmp_grouped_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_grouped_batch_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t batch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_batch_hmp_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t batch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);
  #endif  // XNN_MAX_UARCH_TYPES > 1
#endif

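// Illustrative note (a sketch of how the indirection fields are consumed; the
// real logic lives in the IGEMM micro-kernels): indirect_a stores ks input
// pointers per output row, one per kernel tap. Taps that fall into padding
// point at the shared `zero` buffer; every other pointer is biased by
// a_offset at run time, which lets a single indirection buffer serve all
// images of a batch (ba_stride advances a_offset between images).
/*
  for (size_t k = 0; k < ks; k++) {
    const void* a = indirect_a[output_row * ks + k];
    if (a != zero) {
      a = (const void*) ((uintptr_t) a + a_offset);  // per-batch adjustment
    }
    // ... accumulate kc channels of input from `a` ...
  }
*/
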
struct subgemm_context {
  const struct subconvolution_params* subconvolution_params;
  size_t kc;
  const void* a;
  size_t ax_stride;
  size_t ay_stride;
  size_t cx_stride;
  size_t cy_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  struct xnn_hmp_gemm_ukernel ukernel;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);
#endif

struct subconv_context {
  const struct subconvolution_params* subconvolution_params;
  size_t kc;
  size_t a_offset;
  void* zero;
  size_t cx_stride;
  size_t cy_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  struct xnn_hmp_igemm_ukernel ukernel;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);
#endif

struct conv2d_context {
  size_t input_height;
  size_t input_width;
  const void* input;
  size_t input_batch_stride;
  const void* zero;
  const void* packed_weights;
  void* output;
  size_t output_batch_stride;
  size_t input_padding_top;
  size_t output_channels;
  size_t output_height_stride;
  size_t output_channel_stride;
  union {
    xnn_conv_hwc2chw_ukernel_function hwc2chw_ukernel;
  };
  union {
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_conv2d_hwc2chw(
      const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y_start,
      size_t output_y_slice);
#endif

struct dwconv_context {
  const void** indirect_input;
  size_t indirect_input_width_stride;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  const void* packed_weights;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t groups;
  const void* zero;
  size_t output_increment;
  union {
    union xnn_qs8_conv_minmax_params qs8;
    union xnn_qu8_conv_minmax_params qu8;
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
  union {
    xnn_dwconv_unipass_ukernel_function unipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv_unipass(
      const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct dwconv2d_context {
  size_t input_height;
  size_t input_width;
  const void* input;
  const void* zero;
  uint32_t input_padding_top;
  size_t input_channel_stride;
  size_t input_batch_stride;
  const void* packed_weights;
  size_t weights_channel_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  union {
    union xnn_f32_chw_params f32;
  } params;
  union {
    xnn_dwconv2d_chw_ukernel_function chw_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv2d_chw(
      const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t channel);
#endif

struct max_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_u8_minmax_params u8;
    union xnn_f32_minmax_params f32;
  } params;
  xnn_maxpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_max_pooling(
      const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

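// Illustrative sketch (an assumption about the .c implementation; the exact
// micro-kernel argument order is not specified by this header): pooling
// compute functions resolve one output row's slice of the indirection buffer
// and the per-batch offsets before calling the micro-kernel.
/*
  // Inside xnn_compute_max_pooling(context, batch_index, output_y):
  const void** indirect_input = (const void**)
      ((uintptr_t) context->indirect_input +
       output_y * context->indirect_input_height_stride);
  const size_t input_offset =
      context->input_offset + batch_index * context->input_batch_stride;
  void* output = (void*)
      ((uintptr_t) context->output +
       batch_index * context->output_batch_stride +
       output_y * context->output_height_stride);
  context->ukernel(
      context->output_width, context->pooling_size, context->channels,
      indirect_input, input_offset, output,
      context->input_increment, context->output_increment,
      &context->params);
*/
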
struct unpooling_context {
  const void* input;
  size_t input_height_stride;
  size_t input_width_stride;
  const uint32_t* index;
  size_t index_height_stride;
  size_t index_width_stride;
  const void** indirect_output;
  size_t indirect_output_height_stride;
  size_t indirect_output_width_stride;
  size_t pooling_size;
  size_t channels;
  uint32_t fill_value;
  xnn_unpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_unpooling(
      const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t input_y,
      size_t input_x);
#endif

struct argmax_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  uint32_t* index;
  size_t index_batch_stride;
  size_t index_height_stride;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    xnn_argmaxpool_unipass_ukernel_function unipass_ukernel;
    xnn_argmaxpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_argmax_pooling_unipass(
      const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_argmax_pooling_multipass(
      const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_scaleminmax_params f32;
    union xnn_qu8_avgpool_minmax_params qu8;
  } params;
  union {
    xnn_avgpool_unipass_ukernel_function unipass_ukernel;
    xnn_avgpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_average_pooling_unipass(
      const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_average_pooling_multipass(
      const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct pixelwise_average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  const void* pixelwise_buffer;
  size_t pixelwise_buffer_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
    union xnn_u8_minmax_params u8;
  } params;
  union {
    xnn_pavgpool_unipass_ukernel_function unipass_ukernel;
    xnn_pavgpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_unipass(
      const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass(
      const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct global_average_pooling_nwc_context {
  const void* input;
  const void* zero;
  size_t input_pixel_stride;
  size_t input_batch_stride;
  size_t input_elements;
  size_t channels;
  void* output;
  size_t output_batch_stride;
  union {
    union xnn_qs8_avgpool_minmax_params qs8;
    union xnn_qu8_avgpool_minmax_params qu8;
    union xnn_f16_scaleminmax_params f16;
    union xnn_f32_scaleminmax_params f32;
  } params;
  union {
    xnn_gavgpool_unipass_ukernel_function unipass_ukernel;
    xnn_gavgpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_unipass(
      const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);

  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_multipass(
      const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

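// Illustrative note (an assumption about operator setup, which happens
// outside this header): the unipass/multipass unions above are resolved when
// the operator is created. A pooling window that fits within the selected
// micro-kernel's primary tile is reduced in a single pass; a larger window
// uses the multipass kernel, which carries partial results across passes in
// a scratch accumulator. `primary_tile` below is a hypothetical name for
// that kernel property.
/*
  if (pooling_size <= primary_tile) {
    context.unipass_ukernel = chosen_unipass_kernel;
  } else {
    context.multipass_ukernel = chosen_multipass_kernel;
  }
*/
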
struct global_average_pooling_ncw_context {
  size_t input_elements;
  const void* input;
  size_t input_channel_stride;
  size_t input_batch_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  xnn_gavgpool_cw_ukernel_function ukernel;
  union {
    union xnn_f32_gavgpool_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_ncw(
      const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t channels_start,
      size_t channels_slice);
#endif

struct resize_bilinear_context {
  // Number of channels multiplied by sizeof(input element).
  size_t scaled_channels;
  // Indirection buffer with pointers to rows of input pixels.
  const void** indirect_input;
  // Offset, in bytes, to be added to the pointers in the indirection buffer.
  size_t input_offset;
  // Stride, in bytes, between images of consecutive batches in the input.
  size_t input_batch_stride;
  // Packed pairs of (x, y) linear interpolation coefficients.
  const void* packed_weights;
  // Pointer to the output tensor.
  void* output;
  // Stride, in bytes, between adjacent pixels in the output.
  size_t output_pixel_stride;
  // Stride, in bytes, between images of consecutive batches in the output.
  size_t output_batch_stride;
  // log2(sizeof(weight element)).
  uint32_t log2_wsize;
  // Pointer to the BILINEAR micro-kernel function.
  xnn_ibilinear_ukernel_function ukernel;
};

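// Illustrative scalar reference for the indirection layout (an assumption
// about the packing, which is set up outside this header): each output pixel
// contributes four pointers (top-left, top-right, bottom-left, bottom-right)
// to indirect_input and one (alpha_horizontal, alpha_vertical) pair to
// packed_weights.
/*
  // f32, one channel of one output pixel p:
  const float* tl = (const float*) ((uintptr_t) indirect_input[4 * p + 0] + input_offset);
  const float* tr = (const float*) ((uintptr_t) indirect_input[4 * p + 1] + input_offset);
  const float* bl = (const float*) ((uintptr_t) indirect_input[4 * p + 2] + input_offset);
  const float* br = (const float*) ((uintptr_t) indirect_input[4 * p + 3] + input_offset);
  const float alphah = ((const float*) packed_weights)[2 * p + 0];
  const float alphav = ((const float*) packed_weights)[2 * p + 1];
  const float top    = *tl + alphah * (*tr - *tl);
  const float bottom = *bl + alphah * (*br - *bl);
  output[p] = top + alphav * (bottom - top);
*/
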
struct resize_bilinear_chw_context {
  // Number of pixels per output image plane.
  size_t output_pixels;
  // Number of channels multiplied by sizeof(input element).
  size_t channels;
  // Stride, in bytes, between adjacent channels in the input.
  size_t input_channel_stride;
  // Indirection buffer with pointers to rows of input pixels.
  const void** indirect_input;
  // Offset, in bytes, to be added to the pointers in the indirection buffer.
  size_t input_offset;
  // Stride, in bytes, between images of consecutive batches in the input.
  size_t input_batch_stride;
  // Packed pairs of (x, y) linear interpolation coefficients.
  const void* packed_weights;
  // Pointer to the output tensor.
  void* output;
  // Stride, in bytes, between images of consecutive batches in the output.
  size_t output_batch_stride;
  // Stride, in bytes, between consecutive channels of an output image.
  size_t output_channel_stride;
  // Pointer to the BILINEAR micro-kernel function.
  xnn_ibilinear_chw_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_resize_bilinear(
      const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t pixel_start,
      size_t pixel_range);
  XNN_PRIVATE void xnn_compute_resize_bilinear_chw(
      const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t pixel_start,
      size_t pixel_range);
#endif

struct elementwise_binary_context {
  const void* a;
  size_t a_stride[XNN_MAX_TENSOR_DIMS - 1];
  const void* b;
  size_t b_stride[XNN_MAX_TENSOR_DIMS - 1];
  void* y;
  size_t y_stride[XNN_MAX_TENSOR_DIMS - 1];
  size_t elements;
  union {
    union xnn_qs8_add_minmax_params qs8_addsub;
    union xnn_qu8_add_minmax_params qu8_addsub;
    union xnn_qs8_mul_minmax_params qs8_mul;
    union xnn_qu8_mul_minmax_params qu8_mul;
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
  xnn_vbinary_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_elementwise_binary_1d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i);
  XNN_PRIVATE void xnn_compute_elementwise_binary_2d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j);
  XNN_PRIVATE void xnn_compute_elementwise_binary_3d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k);
  XNN_PRIVATE void xnn_compute_elementwise_binary_4d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l);
  XNN_PRIVATE void xnn_compute_elementwise_binary_5d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l, size_t m);
#endif

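// Illustrative sketch (an assumption about the .c implementation): the N-d
// compute functions turn loop indices over the outer dimensions into byte
// offsets via the per-dimension strides; a dimension broadcast from a single
// element simply carries a stride of zero, and the innermost dimension is
// handled inside the micro-kernel.
/*
  // Inside xnn_compute_elementwise_binary_5d(context, i, j, k, l, m):
  const void* a = (const void*) ((uintptr_t) context->a +
      i * context->a_stride[0] + j * context->a_stride[1] +
      k * context->a_stride[2] + l * context->a_stride[3] +
      m * context->a_stride[4]);
  const void* b = (const void*) ((uintptr_t) context->b +
      i * context->b_stride[0] + j * context->b_stride[1] +
      k * context->b_stride[2] + l * context->b_stride[3] +
      m * context->b_stride[4]);
  void* y = (void*) ((uintptr_t) context->y +
      i * context->y_stride[0] + j * context->y_stride[1] +
      k * context->y_stride[2] + l * context->y_stride[3] +
      m * context->y_stride[4]);
  context->ukernel(context->elements, a, b, y, &context->params);
*/
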
struct channel_shuffle_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  size_t n;
  size_t m;
  union {
    xnn_zipc_ukernel_function fixed_ukernel;
    xnn_zipv_ukernel_function variable_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_channel_shuffle_fixed(
      const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t index);

  XNN_PRIVATE void xnn_compute_channel_shuffle_variable(
      const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t index);
#endif

struct lut_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_strided(
      const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

struct lut_contiguous_context {
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_contiguous(
      const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t offset,
      size_t size);
#endif

struct univector_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_vunary_ukernel_function ukernel;
  union {
    union xnn_f16_abs_params f16_abs;
    union xnn_f16_default_params f16_default;
    union xnn_f16_f32_cvt_params f16_f32_cvt;
    union xnn_f16_hswish_params f16_hswish;
    union xnn_f16_lrelu_params f16_lrelu;
    union xnn_f16_minmax_params f16_minmax;
    union xnn_f16_neg_params f16_neg;
    union xnn_f16_sigmoid_params f16_sigmoid;
    union xnn_f32_abs_params f32_abs;
    union xnn_f32_default_params f32_default;
    union xnn_f32_elu_params f32_elu;
    union xnn_f32_f16_cvt_params f32_f16_cvt;
    union xnn_f32_hswish_params f32_hswish;
    union xnn_f32_lrelu_params f32_lrelu;
    union xnn_f32_minmax_params f32_minmax;
    union xnn_f32_neg_params f32_neg;
    union xnn_f32_qs8_cvt_params f32_qs8_cvt;
    union xnn_f32_qu8_cvt_params f32_qu8_cvt;
    union xnn_f32_rnd_params f32_rnd;
    union xnn_f32_sigmoid_params f32_sigmoid;
    union xnn_f32_sqrt_params f32_sqrt;
    union xnn_qs8_cvt_params qs8_cvt;
    union xnn_qs8_f32_cvt_params qs8_f32_cvt;
    union xnn_qs8_lrelu_params qs8_lrelu;
    union xnn_qu8_cvt_params qu8_cvt;
    union xnn_qu8_f32_cvt_params qu8_f32_cvt;
    union xnn_qu8_lrelu_params qu8_lrelu;
    union xnn_s8_minmax_params s8_minmax;
    union xnn_u8_minmax_params u8_minmax;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_strided(
      const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t batch_range);
#endif

struct univector_contiguous_context {
  const void* x;
  void* y;
  uint16_t log2_xsize;
  uint16_t log2_ysize;
  xnn_vunary_ukernel_function ukernel;
  union {
    union xnn_f16_abs_params f16_abs;
    union xnn_f16_default_params f16_default;
    union xnn_f16_f32_cvt_params f16_f32_cvt;
    union xnn_f16_hswish_params f16_hswish;
    union xnn_f16_lrelu_params f16_lrelu;
    union xnn_f16_minmax_params f16_minmax;
    union xnn_f16_neg_params f16_neg;
    union xnn_f16_sigmoid_params f16_sigmoid;
    union xnn_f32_abs_params f32_abs;
    union xnn_f32_default_params f32_default;
    union xnn_f32_elu_params f32_elu;
    union xnn_f32_f16_cvt_params f32_f16_cvt;
    union xnn_f32_hswish_params f32_hswish;
    union xnn_f32_lrelu_params f32_lrelu;
    union xnn_f32_minmax_params f32_minmax;
    union xnn_f32_neg_params f32_neg;
    union xnn_f32_qs8_cvt_params f32_qs8_cvt;
    union xnn_f32_qu8_cvt_params f32_qu8_cvt;
    union xnn_f32_rnd_params f32_rnd;
    union xnn_f32_sigmoid_params f32_sigmoid;
    union xnn_f32_sqrt_params f32_sqrt;
    union xnn_qs8_cvt_params qs8_cvt;
    union xnn_qs8_f32_cvt_params qs8_f32_cvt;
    union xnn_qs8_lrelu_params qs8_lrelu;
    union xnn_qu8_cvt_params qu8_cvt;
    union xnn_qu8_f32_cvt_params qu8_f32_cvt;
    union xnn_qu8_lrelu_params qu8_lrelu;
    union xnn_s8_minmax_params s8_minmax;
    union xnn_u8_minmax_params u8_minmax;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_contiguous(
      const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t offset,
      size_t size);
#endif

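// Illustrative sketch (an assumption about the .c implementation): the
// contiguous variant receives a byte offset and size relative to the input,
// and rescales the offset through the log2 element sizes so that it also
// addresses the output correctly when the two element types differ (e.g. in
// f32 -> qs8 conversion).
/*
  // Inside xnn_compute_univector_contiguous(context, offset, size):
  const void* x = (const void*) ((uintptr_t) context->x + offset);
  void* y = (void*) ((uintptr_t) context->y +
      ((offset >> context->log2_xsize) << context->log2_ysize));
  context->ukernel(size, x, y, &context->params);
*/
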
struct prelu_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_prelu_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_prelu(
      const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_start,
      size_t batch_range);
#endif

struct vmulcaddc_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_vmulcaddc_ukernel_function ukernel;
  union {
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_vmulcaddc(
      const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_start,
      size_t batch_size);
#endif

struct pad_context {
  const void* input;
  size_t input_stride[XNN_MAX_TENSOR_DIMS - 1];
  void* output;
  size_t output_stride[XNN_MAX_TENSOR_DIMS - 1];
  size_t pre_paddings[XNN_MAX_TENSOR_DIMS];
  size_t post_paddings[1];
  size_t input_size[XNN_MAX_TENSOR_DIMS];
  size_t output_size[1];
  uint32_t padding_value;
  xnn_pad_ukernel_function pad_ukernel;
  xnn_fill_ukernel_function fill_ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pad_5d(
      const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l, size_t m);
#endif

struct u8_softmax_context {
  size_t n;
  const uint8_t* x;
  size_t x_stride;
  const uint32_t* t;
  uint8_t* y;
  size_t y_stride;
  xnn_u8_rmax_ukernel_function rmax_ukernel;
  xnn_u8_lut32norm_ukernel_function lut_norm_ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_u8_softmax(
      const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

typedef void (*xnn_compute_reciprocal_function)(const void* input, void* output);

struct floating_point_softmax_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_rmax_ukernel_function rmax_ukernel;
  xnn_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax_ukernel;
  xnn_compute_reciprocal_function compute_reciprocal;
  xnn_vbinary_ukernel_function vmulc_ukernel;
  union {
    union xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } minmax_params;
  union {
    union xnn_f16_expminus_params f16;
    union xnn_f32_expminus_params f32;
  } expminus_params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_floating_point_softmax(
      const struct floating_point_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

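// Illustrative sketch (an assumption about the .c implementation): the
// micro-kernel pointers stored above spell out softmax as three passes over
// each row of length n. `x_max`, `y_sum`, and `y_scale` are hypothetical
// scalar scratch values of the example.
/*
  // 1. x_max = max(x)                      -- rmax_ukernel
  // 2. y = exp(x - x_max); y_sum = sum(y)  -- raddstoreexpminusmax_ukernel
  // 3. y *= 1 / y_sum                      -- compute_reciprocal + vmulc_ukernel
  context->rmax_ukernel(context->n, x, &x_max);
  context->raddstoreexpminusmax_ukernel(
      context->n, x, &x_max, y, &y_sum, &context->expminus_params);
  context->compute_reciprocal(&y_sum, &y_scale);
  context->vmulc_ukernel(context->n, y, &y_scale, y, &context->minmax_params);
*/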