// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once


#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>


enum xnn_parallelization_type {
  xnn_parallelization_type_invalid = 0,
  xnn_parallelization_type_1d,
  xnn_parallelization_type_1d_tile_1d,
  xnn_parallelization_type_2d,
  xnn_parallelization_type_2d_tile_1d,
  xnn_parallelization_type_2d_tile_2d,
  xnn_parallelization_type_3d,
  xnn_parallelization_type_3d_tile_2d,
  xnn_parallelization_type_4d,
  xnn_parallelization_type_4d_tile_2d,
  xnn_parallelization_type_5d,
  xnn_parallelization_type_5d_tile_2d,
  xnn_parallelization_type_6d_tile_2d,
#if XNN_MAX_UARCH_TYPES > 1
  xnn_parallelization_type_2d_tile_2d_with_uarch,
  xnn_parallelization_type_3d_tile_2d_with_uarch,
  xnn_parallelization_type_4d_tile_2d_with_uarch,
#endif // XNN_MAX_UARCH_TYPES > 1
};

struct compute_parameters {
  enum xnn_parallelization_type type;
  union {
    pthreadpool_task_1d_t task_1d;
    pthreadpool_task_1d_tile_1d_t task_1d_tile_1d;
    pthreadpool_task_2d_t task_2d;
    pthreadpool_task_2d_tile_1d_t task_2d_tile_1d;
    pthreadpool_task_2d_tile_2d_t task_2d_tile_2d;
    pthreadpool_task_3d_t task_3d;
    pthreadpool_task_3d_tile_2d_t task_3d_tile_2d;
    pthreadpool_task_4d_t task_4d;
    pthreadpool_task_4d_tile_2d_t task_4d_tile_2d;
    pthreadpool_task_5d_t task_5d;
    pthreadpool_task_5d_tile_2d_t task_5d_tile_2d;
    pthreadpool_task_6d_tile_2d_t task_6d_tile_2d;
#if XNN_MAX_UARCH_TYPES > 1
    pthreadpool_task_2d_tile_2d_with_id_t task_2d_tile_2d_with_id;
    pthreadpool_task_3d_tile_2d_with_id_t task_3d_tile_2d_with_id;
    pthreadpool_task_4d_tile_2d_with_id_t task_4d_tile_2d_with_id;
#endif // XNN_MAX_UARCH_TYPES > 1
  };
  size_t range[6];
  size_t tile[2];
};

struct gemm_context {
  size_t k_scaled;
  const void* a;
  size_t a_stride;
  const void* packed_w;
  size_t w_stride;
  size_t wg_stride;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t cg_stride;
  uint32_t log2_csize;
  struct xnn_hmp_gemm_ukernel ukernel;
  union {
    union xnn_qs8_gemm_params qs8;
    union xnn_qu8_gemm_params qu8;
    struct xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_gemm(
      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  #if XNN_MAX_UARCH_TYPES > 1
    XNN_PRIVATE void xnn_compute_hmp_grouped_gemm(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_gemm(
        const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);
  #endif // XNN_MAX_UARCH_TYPES > 1
#endif
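
// Illustrative sketch (not part of this header): xnn_compute_gemm has the shape of a
// pthreadpool 2D task with a 2D tile, so a dense GEMM can be parallelized over
// MR x NR blocks of the output. The threadpool handle, range values, and tile sizes
// below are assumed placeholders, not values defined by XNNPACK.
//
//   struct gemm_context context = { /* pointers, strides, ukernel, params */ };
//   pthreadpool_parallelize_2d_tile_2d(
//       threadpool,
//       (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm,
//       &context,
//       batch_size /* range of mr_block_start */,
//       output_channels /* range of nr_block_start */,
//       mr /* mr_block_size */,
//       nr /* nr_block_size */,
//       0 /* flags */);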

// Context for Sparse Matrix-Dense Matrix Multiplication.
// C [MxN] := A [MxK] * B [KxN] + bias [N]
// A and C are dense matrices with row-major storage, B is a sparse matrix.
struct spmm_context {
  // N dimension of the B and C matrices.
  // Corresponds to number of output channels in 1x1 convolution.
  size_t n;
  // M dimension of the A and C matrices, pre-scaled by sizeof(element size).
  // Corresponds to the stride, in bytes, between adjacent rows of C matrix.
  size_t scaled_m;
  // Input matrix A.
  const void* input;
  // Packed bias elements and non-zero filter elements.
  const void* nonzero_weights;
  // Input pointer increments, in bytes, after each processed non-zero weight.
  const int32_t* input_increments;
  // Number of non-zero filter elements per each N (output channel) dimension.
  const uint32_t* output_channel_nonzeros;
  // Output matrix C.
  void* output;
  // Stride, in bytes, between matrices A corresponding to different images in batched 1x1 Convolution.
  size_t batched_input_stride;
  // Stride, in bytes, between matrices C corresponding to different images in batched 1x1 Convolution.
  size_t batched_output_stride;
  // Micro-kernel function pointer.
  xnn_spmm_ukernel_function ukernel;
  // Output activation parameters.
  union {
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_spmm(
      const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t mr_block_start,
      size_t mr_block_size);
#endif
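
// Illustrative sketch (not part of this header): the SpMM task takes a batch index
// plus a tiled range over output rows, matching a pthreadpool 2D task with a 1D tile.
// Consistent with the scaled_m field above, the row range and tile presumably carry
// the element-size scaling. All variables below are placeholders.
//
//   pthreadpool_parallelize_2d_tile_1d(
//       threadpool,
//       (pthreadpool_task_2d_tile_1d_t) xnn_compute_spmm,
//       &spmm_context,
//       batch_size /* range of batch_index */,
//       scaled_m /* range of mr_block_start */,
//       scaled_mr /* mr_block_size tile */,
//       0 /* flags */);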

struct igemm_context {
  size_t ks;
  size_t ks_scaled;
  size_t kc;
  size_t w_stride;
  const void** indirect_a;
  size_t a_offset;
  void* zero;
  const void* packed_w;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  struct xnn_hmp_igemm_ukernel ukernel;
  union {
    union xnn_qs8_gemm_params qs8;
    union xnn_qu8_gemm_params qu8;
    struct xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_grouped_batch_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_batch_igemm(
      const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  #if XNN_MAX_UARCH_TYPES > 1
    XNN_PRIVATE void xnn_compute_hmp_grouped_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_grouped_batch_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t batch_index,
        size_t group_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_hmp_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);

    XNN_PRIVATE void xnn_compute_batch_hmp_igemm(
        const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
        uint32_t uarch_index,
        size_t batch_index,
        size_t mr_block_start,
        size_t nr_block_start,
        size_t mr_block_size,
        size_t nr_block_size);
  #endif // XNN_MAX_UARCH_TYPES > 1
#endif
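
// Illustrative sketch (not part of this header): the batched, grouped IGEMM task maps
// onto a pthreadpool 4D task with a 2D tile over MR x NR output blocks; the
// indirection buffer (indirect_a) stands in for direct input pointers. All variables
// below are placeholders.
//
//   pthreadpool_parallelize_4d_tile_2d(
//       threadpool,
//       (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_batch_igemm,
//       &igemm_context,
//       batch_size /* range of batch_index */,
//       groups /* range of group_index */,
//       output_size /* range of mr_block_start */,
//       group_output_channels /* range of nr_block_start */,
//       mr, nr,
//       0 /* flags */);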

struct subgemm_context {
  const struct subconvolution_params* subconvolution_params;
  size_t kc;
  const void* a;
  size_t ax_stride;
  size_t ay_stride;
  size_t cx_stride;
  size_t cy_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  struct xnn_hmp_gemm_ukernel ukernel;
  union {
    union xnn_qs8_gemm_params qs8;
    union xnn_qu8_gemm_params qu8;
    struct xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_subgemm2d(
      const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);
#endif

struct subconv_context {
  const struct subconvolution_params* subconvolution_params;
  size_t kc;
  size_t a_offset;
  void* zero;
  size_t cx_stride;
  size_t cy_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  struct xnn_hmp_igemm_ukernel ukernel;
  union {
    union xnn_qs8_gemm_params qs8;
    union xnn_qu8_gemm_params qu8;
    struct xnn_f16_scaleminmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_grouped_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_subconv2d(
      const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);
#endif

struct conv2d_context {
  size_t input_height;
  size_t input_width;
  const void* input;
  size_t input_batch_stride;
  const void* zero;
  const void* packed_weights;
  void* output;
  size_t output_batch_stride;
  size_t input_padding_top;
  size_t output_channels;
  size_t output_height_stride;
  size_t output_channel_stride;
  union {
    xnn_conv_hwc2chw_ukernel_function hwc2chw_ukernel;
  };
  union {
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_conv2d_hwc2chw(
      const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y_start,
      size_t output_y_slice);
#endif

struct dwconv_context {
  const void** indirect_input;
  size_t indirect_input_width_stride;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  const void* packed_weights;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t groups;
  const void* zero;
  size_t output_increment;
  union {
    union xnn_qs8_gemm_params qs8;
    union xnn_qu8_gemm_params qu8;
    struct xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
  union {
    xnn_dwconv_unipass_ukernel_function unipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv_unipass(
      const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif
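
// Illustrative sketch (not part of this header): the unipass depthwise-convolution
// task is indexed by (batch, output row), i.e. a plain pthreadpool 2D task with no
// tiling. All variables below are placeholders.
//
//   pthreadpool_parallelize_2d(
//       threadpool,
//       (pthreadpool_task_2d_t) xnn_compute_dwconv_unipass,
//       &dwconv_context,
//       batch_size /* range of batch_index */,
//       output_height /* range of output_y */,
//       0 /* flags */);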

struct dwconv2d_context {
  size_t input_height;
  size_t input_width;
  const void* input;
  const void* zero;
  uint32_t input_padding_top;
  size_t input_channel_stride;
  size_t input_batch_stride;
  const void* packed_weights;
  size_t weights_channel_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  union {
    union xnn_f32_chw_params f32;
  } params;
  union {
    xnn_dwconv2d_chw_ukernel_function chw_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv2d_chw(
      const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t channel);
#endif

struct depthtospace2d_hwc_context {
  size_t elements;
  size_t input_width;
  size_t block_size;
  const void* input;
  void* output;
  size_t input_height_stride;
  size_t input_width_stride;
  size_t output_height_stride;
  size_t output_width_stride;
  xnn_univector_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_depthtospace2d_hwc_contiguous(
      const struct depthtospace2d_hwc_context* context,
      size_t batch_input_y,
      size_t input_x,
      size_t block_y);

  XNN_PRIVATE void xnn_compute_depthtospace2d_hwc_strided(
      const struct depthtospace2d_hwc_context* context,
      size_t batch_input_y,
      size_t input_x,
      size_t block_y,
      size_t block_x);
#endif

struct depthtospace2d_chw2hwc_context {
  size_t output_channels;
  size_t input_height;
  size_t input_width;
  uint32_t block_size;
  const void* input;
  void* output;
  size_t input_batch_stride;
  size_t output_batch_stride;
  size_t output_channel_stride;
  xnn_depthtospace2d_chw2hwc_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_depthtospace2d_chw2hwc(
      const struct depthtospace2d_chw2hwc_context* context,
      size_t batch_index);
#endif

struct max_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_u8_minmax_params u8;
    union xnn_f32_minmax_params f32;
  } params;
  xnn_maxpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_max_pooling(
      const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct unpooling_context {
  const void* input;
  size_t input_height_stride;
  size_t input_width_stride;
  const uint32_t* index;
  size_t index_height_stride;
  size_t index_width_stride;
  void** indirect_output;
  size_t indirect_output_height_stride;
  size_t indirect_output_width_stride;
  size_t pooling_size;
  size_t channels;
  uint32_t fill_value;
  xnn_unpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_unpooling(
      const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t input_y,
      size_t input_x);
#endif

struct argmax_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  uint32_t* index;
  size_t index_batch_stride;
  size_t index_height_stride;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    xnn_argmaxpool_unipass_ukernel_function unipass_ukernel;
    xnn_argmaxpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_argmax_pooling_unipass(
      const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_argmax_pooling_multipass(
      const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_qu8_avgpool_params qu8;
    union xnn_f32_scaleminmax_params f32;
  } params;
  union {
    xnn_avgpool_unipass_ukernel_function unipass_ukernel;
    xnn_avgpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_average_pooling_unipass(
      const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_average_pooling_multipass(
      const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif

struct pixelwise_average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  const void* pixelwise_buffer;
  size_t pixelwise_buffer_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_u8_minmax_params u8;
    union xnn_f32_minmax_params f32;
  } params;
  union {
    xnn_pavgpool_unipass_ukernel_function unipass_ukernel;
    xnn_pavgpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_unipass(
      const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass(
      const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t output_y);
#endif
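
// Illustrative sketch (not part of this header): the max/average/argmax pooling tasks
// above are all indexed by (batch, output row) and therefore dispatch as plain
// pthreadpool 2D tasks; roughly speaking, the unipass variant applies when the
// pooling window fits in a single pass of the micro-kernel, multipass otherwise.
// All variables below are placeholders.
//
//   pthreadpool_parallelize_2d(
//       threadpool,
//       (pthreadpool_task_2d_t) xnn_compute_max_pooling,
//       &max_pooling_context,
//       batch_size /* range of batch_index */,
//       output_height /* range of output_y */,
//       0 /* flags */);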

struct global_average_pooling_nwc_context {
  const void* input;
  const void* zero;
  size_t input_pixel_stride;
  size_t input_batch_stride;
  size_t input_elements;
  size_t channels;
  void* output;
  size_t output_batch_stride;
  union {
    union xnn_qs8_avgpool_params qs8;
    union xnn_qu8_avgpool_params qu8;
    struct xnn_f16_scaleminmax_params f16;
    union xnn_f32_scaleminmax_params f32;
  } params;
  union {
    xnn_gavgpool_unipass_ukernel_function unipass_ukernel;
    xnn_gavgpool_multipass_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_unipass(
      const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);

  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_multipass(
      const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

struct global_average_pooling_ncw_context {
  size_t input_elements;
  const void* input;
  size_t input_channel_stride;
  size_t input_batch_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  xnn_gavgpool_cw_ukernel_function ukernel;
  union {
    union xnn_f32_gavgpool_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_ncw(
      const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t channels_start,
      size_t channels_slice);
#endif

struct resize_bilinear_context {
  // Number of channels multiplied by sizeof(input element).
  size_t scaled_channels;
  // Indirection buffer with pointers related to rows of input pixels.
  const void** indirect_input;
  // Offset, in bytes, to be added to pointers in indirection buffer.
  size_t input_offset;
  // Stride, in bytes, between images of consecutive batches in the input.
  size_t input_batch_stride;
  // Packed pairs of (x, y) linear interpolation coefficients.
  const void* packed_weights;
  // Pointer to the output tensor.
  void* output;
  // Stride, in bytes, between adjacent pixels in the output.
  size_t output_pixel_stride;
  // Stride, in bytes, between images of consecutive batches in the output.
  size_t output_batch_stride;
  // log2(sizeof(weight element)).
  uint32_t log2_wsize;
  // Pointer to BILINEAR micro-kernel function.
  xnn_ibilinear_ukernel_function ukernel;
};

struct resize_bilinear_chw_context {
  // Number of pixels per output image plane.
  size_t output_pixels;
  // Number of channels multiplied by sizeof(input element).
  size_t channels;
  // Stride, in bytes, between adjacent channels in the input.
  size_t input_channel_stride;
  // Indirection buffer with pointers related to rows of input pixels.
  const void** indirect_input;
  // Offset, in bytes, to be added to pointers in indirection buffer.
  size_t input_offset;
  // Stride, in bytes, between images of consecutive batches in the input.
  size_t input_batch_stride;
  // Packed pairs of (x, y) linear interpolation coefficients.
  const void* packed_weights;
  // Pointer to the output tensor.
  void* output;
  // Stride, in bytes, between images of consecutive batches in the output.
  size_t output_batch_stride;
  // Stride, in bytes, between consecutive channels of an output image.
  size_t output_channel_stride;
  // Pointer to BILINEAR micro-kernel function.
  xnn_ibilinear_chw_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_resize_bilinear(
      const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t pixel_start,
      size_t pixel_range);
  XNN_PRIVATE void xnn_compute_resize_bilinear_chw(
      const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t pixel_start,
      size_t pixel_range);
#endif

struct elementwise_binary_context {
  const void* a;
  size_t a_stride[XNN_MAX_TENSOR_DIMS - 1];
  const void* b;
  size_t b_stride[XNN_MAX_TENSOR_DIMS - 1];
  void* y;
  size_t y_stride[XNN_MAX_TENSOR_DIMS - 1];
  size_t elements;
  union {
    union xnn_qs8_add_params qs8;
    union xnn_qu8_add_params qu8;
    struct xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
  xnn_vbinary_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_elementwise_binary_5d(
      const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l, size_t m);
#endif
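
// Illustrative sketch (not part of this header): xnn_compute_elementwise_binary_5d
// receives one index per outer tensor dimension, while the innermost dimension is
// handled inside the micro-kernel via the elements field. Broadcasting can be
// expressed by setting the corresponding a_stride/b_stride entry to zero, so the
// same slice is re-read for every index along that dimension. Placeholder variables
// only.
//
//   pthreadpool_parallelize_5d(
//       threadpool,
//       (pthreadpool_task_5d_t) xnn_compute_elementwise_binary_5d,
//       &binary_context,
//       shape[0], shape[1], shape[2], shape[3], shape[4],
//       0 /* flags */);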

struct channel_shuffle_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  size_t n;
  size_t m;
  union {
    xnn_zipc_ukernel_function fixed_ukernel;
    xnn_zipv_ukernel_function variable_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_channel_shuffle_fixed(
      const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t index);

  XNN_PRIVATE void xnn_compute_channel_shuffle_variable(
      const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t index);
#endif

struct lut_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_strided(
      const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

struct lut_contiguous_context {
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_contiguous(
      const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t offset,
      size_t size);
#endif

struct univector_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_univector_ukernel_function ukernel;
  union {
    union xnn_u8_minmax_params u8_output;
    union xnn_f32_minmax_params f32_output;
    union xnn_f32_hswish_params f32_hswish;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_strided(
      const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index,
      size_t batch_range);
#endif

struct univector_contiguous_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_univector_ukernel_function ukernel;
  union {
    union xnn_u8_minmax_params u8_output;
    union xnn_f32_minmax_params f32_output;
    union xnn_f32_hswish_params f32_hswish;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_contiguous(
      const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t offset,
      size_t size);
#endif
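
// Illustrative sketch (not part of this header): when input and output are densely
// packed, a unary elementwise operator can use the contiguous variant and simply
// split a flattened range into 1D tiles. All variables below are placeholders.
//
//   pthreadpool_parallelize_1d_tile_1d(
//       threadpool,
//       (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous,
//       &univector_context,
//       total_range /* range of offset */,
//       block_range /* size tile */,
//       0 /* flags */);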

struct prelu_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_prelu_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_prelu(
      const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_start,
      size_t batch_range);
#endif

struct vmulcaddc_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_vmulcaddc_ukernel_function ukernel;
  union {
    struct xnn_f16_minmax_params f16;
    union xnn_f32_minmax_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_vmulcaddc(
      const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_start,
      size_t batch_size);
#endif

struct pad_context {
  const void* input;
  size_t input_stride[XNN_MAX_TENSOR_DIMS - 1];
  void* output;
  size_t output_stride[XNN_MAX_TENSOR_DIMS - 1];
  size_t pre_paddings[XNN_MAX_TENSOR_DIMS];
  size_t post_paddings[1];
  size_t input_size[XNN_MAX_TENSOR_DIMS];
  size_t output_size[1];
  uint32_t padding_value;
  xnn_pad_ukernel_function pad_ukernel;
  xnn_fill_ukernel_function fill_ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pad_5d(
      const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t i, size_t j, size_t k, size_t l, size_t m);
#endif

struct u8_softmax_context {
  size_t n;
  const uint8_t* x;
  size_t x_stride;
  const uint32_t* t;
  uint8_t* y;
  size_t y_stride;
  xnn_u8_rmax_ukernel_function rmax_ukernel;
  xnn_u8_lut32norm_ukernel_function lut_norm_ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_u8_softmax(
      const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif

struct f32_three_pass_softmax_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_f32_rmax_ukernel_function rmax_ukernel;
  xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax_ukernel;
  xnn_vbinary_ukernel_function vmulc_ukernel;
  union xnn_f32_minmax_params params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_f32_three_pass_softmax(
      const struct f32_three_pass_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
      size_t batch_index);
#endif
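
// Illustrative sketch (not part of this header): an operator "run" step can reduce to
// a switch on compute_parameters.type that forwards the stored task pointer, ranges,
// and tiles to the matching pthreadpool_parallelize_* call. The snippet shows two
// cases only and uses placeholder names (op, threadpool, flags).
//
//   switch (op->compute.type) {
//     case xnn_parallelization_type_1d:
//       pthreadpool_parallelize_1d(
//           threadpool, op->compute.task_1d, &op->context,
//           op->compute.range[0], flags);
//       break;
//     case xnn_parallelization_type_2d_tile_2d:
//       pthreadpool_parallelize_2d_tile_2d(
//           threadpool, op->compute.task_2d_tile_2d, &op->context,
//           op->compute.range[0], op->compute.range[1],
//           op->compute.tile[0], op->compute.tile[1], flags);
//       break;
//     // ... remaining parallelization types follow the same pattern.
//   }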