1 // Copyright 2022 Google LLC 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 8 #include <stddef.h> 9 #include <stdint.h> 10 11 #include <xnnpack/common.h> 12 #include <xnnpack/microparams.h> 13 14 15 /****************** Microkernel pointers for dense inference *****************/ 16 17 // CONV-HWC: direct CONVolution in HWC layout 18 19 typedef void (*xnn_conv_hwc_ukernel_function)( 20 size_t input_height, 21 size_t input_width, 22 size_t output_y_start, 23 size_t output_y_end, 24 const void* input, 25 const void* zero, 26 const void* weights, 27 void* output, 28 size_t input_padding_top, 29 size_t output_channels, 30 size_t output_height_stride, 31 size_t output_width_stride, 32 const void* params); 33 34 typedef void (*xnn_f32_conv_hwc_ukernel_function)( 35 size_t input_height, 36 size_t input_width, 37 size_t output_y_start, 38 size_t output_y_end, 39 const float* input, 40 const float* zero, 41 const float* weights, 42 float* output, 43 size_t input_padding_top, 44 size_t output_channels, 45 size_t output_height_stride, 46 size_t output_width_stride, 47 const union xnn_f32_minmax_params* params); 48 49 // GEMM: GEneral Matrix Multiplication without activations 50 51 typedef void (*xnn_gemm_ukernel_function)( 52 size_t mr, 53 size_t nr, 54 size_t k, 55 const void* a, 56 size_t a_stride, 57 const void* w, 58 void* c, 59 size_t cm_stride, 60 size_t cn_stride, 61 const void* params); 62 63 typedef void (*xnn_f32_gemm_ukernel_function)( 64 size_t mr, 65 size_t nr, 66 size_t k, 67 const float* a, 68 size_t a_stride, 69 const float* w, 70 float* c, 71 size_t cm_stride, 72 size_t cn_stride, 73 const union xnn_f32_default_params* params); 74 75 // GEMM: GEneral Matrix Multiplication with ReLU activation 76 77 typedef void (*xnn_f32_gemm_relu_ukernel_function)( 78 size_t mr, 79 size_t nr, 80 size_t k, 81 const float* a, 82 size_t a_stride, 83 const float* w, 
84 float* c, 85 size_t cm_stride, 86 size_t cn_stride, 87 const union xnn_f32_relu_params* params); 88 89 // GEMM: GEneral Matrix Multiplication with Min+Max activation 90 91 typedef void (*xnn_bf16_gemm_minmax_ukernel_function)( 92 size_t mr, 93 size_t nr, 94 size_t k, 95 const void* a, 96 size_t a_stride, 97 const void* w, 98 void* c, 99 size_t cm_stride, 100 size_t cn_stride, 101 const union xnn_bf16_minmax_params* params); 102 103 typedef void (*xnn_f16_gemm_minmax_ukernel_function)( 104 size_t mr, 105 size_t nr, 106 size_t k, 107 const void* a, 108 size_t a_stride, 109 const void* w, 110 void* c, 111 size_t cm_stride, 112 size_t cn_stride, 113 const union xnn_f16_minmax_params* params); 114 115 typedef void (*xnn_f32_gemm_minmax_ukernel_function)( 116 size_t mr, 117 size_t nr, 118 size_t k, 119 const float* a, 120 size_t a_stride, 121 const float* w, 122 float* c, 123 size_t cm_stride, 124 size_t cn_stride, 125 const union xnn_f32_minmax_params* params); 126 127 typedef void (*xnn_qc8_gemm_minmax_ukernel_function)( 128 size_t mr, 129 size_t nr, 130 size_t k, 131 const int8_t* a, 132 size_t a_stride, 133 const void* w, 134 int8_t* c, 135 size_t cm_stride, 136 size_t cn_stride, 137 const union xnn_qc8_conv_minmax_params* params); 138 139 typedef void (*xnn_qs8_gemm_minmax_ukernel_function)( 140 size_t mr, 141 size_t nr, 142 size_t k, 143 const int8_t* a, 144 size_t a_stride, 145 const void* w, 146 int8_t* c, 147 size_t cm_stride, 148 size_t cn_stride, 149 const union xnn_qs8_conv_minmax_params* params); 150 151 typedef void (*xnn_qu8_gemm_minmax_ukernel_function)( 152 size_t mr, 153 size_t nr, 154 size_t k, 155 const uint8_t* a, 156 size_t a_stride, 157 const void* w, 158 uint8_t* c, 159 size_t cm_stride, 160 size_t cn_stride, 161 const union xnn_qu8_conv_minmax_params* params); 162 163 // GEMMINC: GEMM INCremental with Min+Max activation 164 165 typedef void (*xnn_f32_gemminc_minmax_ukernel_function)( 166 size_t mr, 167 size_t nr, 168 size_t k, 169 const float* 
a, 170 size_t a_stride, 171 const float* w, 172 float* c, 173 size_t cm_stride, 174 size_t cn_stride, 175 const float* acc, 176 const union xnn_f32_minmax_params* params); 177 178 // IGEMM: Indirect GEMM without activation 179 180 typedef void (*xnn_igemm_ukernel_function)( 181 size_t mr, 182 size_t nr, 183 size_t kc, 184 size_t ks, 185 const void** a, 186 const void* w, 187 void* c, 188 size_t cm_stride, 189 size_t cn_stride, 190 size_t a_offset, 191 const void* zero, 192 const void* params); 193 194 typedef void (*xnn_f32_igemm_ukernel_function)( 195 size_t mr, 196 size_t nr, 197 size_t kc, 198 size_t ks, 199 const float** a, 200 const float* w, 201 float* c, 202 size_t cm_stride, 203 size_t cn_stride, 204 size_t a_offset, 205 const float* zero, 206 const union xnn_f32_default_params* params); 207 208 // IGEMM: Indirect GEMM with ReLU activation 209 210 typedef void (*xnn_f32_igemm_relu_ukernel_function)( 211 size_t mr, 212 size_t nr, 213 size_t kc, 214 size_t ks, 215 const float** a, 216 const float* w, 217 float* c, 218 size_t cm_stride, 219 size_t cn_stride, 220 size_t a_offset, 221 const float* zero, 222 const union xnn_f32_relu_params* params); 223 224 // IGEMM: Indirect GEMM with Min+Max activation 225 226 typedef void (*xnn_f16_igemm_minmax_ukernel_function)( 227 size_t mr, 228 size_t nr, 229 size_t kc, 230 size_t ks, 231 const void** a, 232 const void* w, 233 void* c, 234 size_t cm_stride, 235 size_t cn_stride, 236 size_t a_offset, 237 const void* zero, 238 const union xnn_f16_minmax_params* params); 239 240 typedef void (*xnn_f32_igemm_minmax_ukernel_function)( 241 size_t mr, 242 size_t nr, 243 size_t kc, 244 size_t ks, 245 const float** a, 246 const float* w, 247 float* c, 248 size_t cm_stride, 249 size_t cn_stride, 250 size_t a_offset, 251 const float* zero, 252 const union xnn_f32_minmax_params* params); 253 254 typedef void (*xnn_qc8_igemm_minmax_ukernel_function)( 255 size_t mr, 256 size_t nr, 257 size_t kc, 258 size_t ks, 259 const int8_t** a, 260 
const void* w, 261 int8_t* c, 262 size_t cm_stride, 263 size_t cn_stride, 264 size_t a_offset, 265 const int8_t* zero, 266 const union xnn_qc8_conv_minmax_params* params); 267 268 typedef void (*xnn_qs8_igemm_minmax_ukernel_function)( 269 size_t mr, 270 size_t nr, 271 size_t kc, 272 size_t ks, 273 const int8_t** a, 274 const void* w, 275 int8_t* c, 276 size_t cm_stride, 277 size_t cn_stride, 278 size_t a_offset, 279 const int8_t* zero, 280 const union xnn_qs8_conv_minmax_params* params); 281 282 typedef void (*xnn_qu8_igemm_minmax_ukernel_function)( 283 size_t mr, 284 size_t nr, 285 size_t kc, 286 size_t ks, 287 const uint8_t** a, 288 const void* w, 289 uint8_t* c, 290 size_t cm_stride, 291 size_t cn_stride, 292 size_t a_offset, 293 const uint8_t* zero, 294 const union xnn_qu8_conv_minmax_params* params); 295 296 // PPMM: Pre-Packed Matrix Multiplication) 297 298 typedef void (*xnn_ppmm_ukernel_function)( 299 size_t mr, 300 size_t nc, 301 size_t kc, 302 const void* a, 303 const void* w, 304 void* c, 305 size_t cm_stride, 306 size_t cn_stride, 307 const void* params); 308 309 typedef void (*xnn_f16_ppmm_ukernel_function)( 310 size_t mr, 311 size_t nc, 312 size_t kc, 313 const void* a, 314 const void* w, 315 void* c, 316 size_t cm_stride, 317 size_t cn_stride, 318 const union xnn_f16_minmax_params* params); 319 320 typedef void (*xnn_f32_ppmm_minmax_ukernel_function)( 321 size_t mr, 322 size_t nc, 323 size_t kc, 324 const float* a, 325 const float* w, 326 float* c, 327 size_t cm_stride, 328 size_t cn_stride, 329 const union xnn_f32_minmax_params* params); 330 331 // DWCONV: DepthWise CONVolution single-pass without activation 332 333 typedef void (*xnn_dwconv_unipass_ukernel_function)( 334 size_t channels, 335 size_t output_width, 336 const void** input, 337 const void* weights, 338 void* output, 339 size_t input_stride, 340 size_t output_increment, 341 size_t input_offset, 342 const void* zero, 343 const void* params); 344 345 typedef void 
(*xnn_f32_dwconv_unipass_ukernel_function)( 346 size_t channels, 347 size_t output_width, 348 const float** input, 349 const float* weights, 350 float* output, 351 size_t input_stride, 352 size_t output_increment, 353 size_t input_offset, 354 const float* zero, 355 const union xnn_f32_default_params* params); 356 357 // DWCONV: DepthWise CONVolution single-pass with Min+Max activation 358 359 typedef void (*xnn_f16_dwconv_minmax_unipass_ukernel_function)( 360 size_t channels, 361 size_t output_width, 362 const void** input, 363 const void* weights, 364 void* output, 365 size_t input_stride, 366 size_t output_increment, 367 size_t input_offset, 368 const void* zero, 369 const union xnn_f16_minmax_params* params); 370 371 typedef void (*xnn_f32_dwconv_minmax_unipass_ukernel_function)( 372 size_t channels, 373 size_t output_width, 374 const float** input, 375 const float* weights, 376 float* output, 377 size_t input_stride, 378 size_t output_increment, 379 size_t input_offset, 380 const float* zero, 381 const union xnn_f32_minmax_params* params); 382 383 typedef void (*xnn_qc8_dwconv_minmax_unipass_ukernel_function)( 384 size_t channels, 385 size_t output_width, 386 const int8_t** input, 387 const void* weights, 388 int8_t* output, 389 size_t input_stride, 390 size_t output_increment, 391 size_t input_offset, 392 const int8_t* zero, 393 const union xnn_qc8_conv_minmax_params* params); 394 395 typedef void (*xnn_qs8_dwconv_minmax_unipass_ukernel_function)( 396 size_t channels, 397 size_t output_width, 398 const int8_t** input, 399 const void* weights, 400 int8_t* output, 401 size_t input_stride, 402 size_t output_increment, 403 size_t input_offset, 404 const int8_t* zero, 405 const union xnn_qs8_conv_minmax_params* params); 406 407 typedef void (*xnn_qu8_dwconv_minmax_unipass_ukernel_function)( 408 size_t channels, 409 size_t output_width, 410 const uint8_t** input, 411 const void* weights, 412 uint8_t* output, 413 size_t input_stride, 414 size_t output_increment, 415 
size_t input_offset, 416 const uint8_t* zero, 417 const union xnn_qu8_conv_minmax_params* params); 418 419 // DWCONV: DepthWise CONVolution multi-pass without activation 420 421 typedef void (*xnn_dwconv_multipass_ukernel_function)( 422 size_t channels, 423 size_t output_width, 424 const void** input, 425 const void* weights, 426 void* buffer, 427 void* output, 428 size_t input_stride, 429 size_t output_increment, 430 size_t input_offset, 431 const void* zero, 432 const void* params); 433 434 // VMULCADDC: Vector MULtiply-by-Constant, ADD-Constant 435 436 typedef void (*xnn_vmulcaddc_ukernel_function)( 437 size_t batch, 438 size_t channels, 439 const void* input, 440 size_t input_stride, 441 const void* weights, 442 void* output, 443 size_t output_stride, 444 const void* params); 445 446 typedef void (*xnn_f16_vmulcaddc_ukernel_function)( 447 size_t batch, 448 size_t channels, 449 const void* input, 450 size_t input_stride, 451 const void* weights, 452 void* output, 453 size_t output_stride, 454 const union xnn_f16_minmax_params* params); 455 456 typedef void (*xnn_f32_vmulcaddc_ukernel_function)( 457 size_t batch, 458 size_t channels, 459 const float* input, 460 size_t input_stride, 461 const float* weights, 462 float* output, 463 size_t output_stride, 464 const union xnn_f32_minmax_params* params); 465 466 // PRELU: Parametric RELU 467 468 typedef void (*xnn_prelu_ukernel_function)( 469 size_t batch, 470 size_t channels, 471 const void* input, 472 size_t input_stride, 473 const void* weights, 474 void* output, 475 size_t output_stride); 476 477 typedef void (*xnn_f16_prelu_ukernel_function)( 478 size_t batch, 479 size_t channels, 480 const void* input, 481 size_t input_stride, 482 const void* weights, 483 void* output, 484 size_t output_stride); 485 486 typedef void (*xnn_f32_prelu_ukernel_function)( 487 size_t batch, 488 size_t channels, 489 const float* input, 490 size_t input_stride, 491 const float* weights, 492 float* output, 493 size_t output_stride); 494 
// IBILINEAR: Indirect BILINEAR interpolation
//
// `input` is a pointer to pointers in every variant. The type-erased and f16
// signatures use void pointers; f32 uses float for input, weights and output;
// the s8 and u8 signatures use int8_t / uint8_t data with int16_t weights.
// None of these signatures take a params argument.

// Type-erased signature.
typedef void (*xnn_ibilinear_ukernel_function)(
    size_t output_pixels,
    size_t channels,
    const void** input,
    size_t input_offset,
    const void* weights,
    void* output,
    size_t output_increment);

// void-pointer data (same parameter list as the type-erased signature).
typedef void (*xnn_f16_ibilinear_ukernel_function)(
    size_t output_pixels,
    size_t channels,
    const void** input,
    size_t input_offset,
    const void* weights,
    void* output,
    size_t output_increment);

// float data and float weights.
typedef void (*xnn_f32_ibilinear_ukernel_function)(
    size_t output_pixels,
    size_t channels,
    const float** input,
    size_t input_offset,
    const float* weights,
    float* output,
    size_t output_increment);

// int8_t data, int16_t weights.
typedef void (*xnn_s8_ibilinear_ukernel_function)(
    size_t output_pixels,
    size_t channels,
    const int8_t** input,
    size_t input_offset,
    const int16_t* weights,
    int8_t* output,
    size_t output_increment);

// uint8_t data, int16_t weights.
typedef void (*xnn_u8_ibilinear_ukernel_function)(
    size_t output_pixels,
    size_t channels,
    const uint8_t** input,
    size_t input_offset,
    const int16_t* weights,
    uint8_t* output,
    size_t output_increment);

// GAVGPOOL: Global AVeraGe POOLing single-pass
//
// The f16 and f32 signatures below take the corresponding
// xnn_*_scaleminmax_params union.

// Type-erased signature.
typedef void (*xnn_gavgpool_unipass_ukernel_function)(
    size_t rows,
    size_t channels,
    const void* input,
    size_t input_stride,
    const void* zero,
    void* output,
    const void* params);

// void-pointer data, union xnn_f16_scaleminmax_params.
typedef void (*xnn_f16_gavgpool_minmax_unipass_ukernel_function)(
    size_t rows,
    size_t channels,
    const void* input,
    size_t input_stride,
    const void* zero,
    void* output,
    const union xnn_f16_scaleminmax_params* params);

// float data, union xnn_f32_scaleminmax_params.
typedef void (*xnn_f32_gavgpool_minmax_unipass_ukernel_function)(
    size_t rows,
    size_t channels,
    const float* input,
    size_t input_stride,
    const float* zero,
    float* output,
    const union xnn_f32_scaleminmax_params* params);

// GAVGPOOL single-pass, quantized signatures: int8_t / uint8_t data with the
// corresponding xnn_*_avgpool_minmax_params union.

typedef void (*xnn_qs8_gavgpool_minmax_unipass_ukernel_function)(
    size_t rows,
    size_t channels,
    const int8_t* input,
    size_t input_stride,
    const int8_t* zero,
    int8_t* output,
    const union xnn_qs8_avgpool_minmax_params* params);

typedef void (*xnn_qu8_gavgpool_minmax_unipass_ukernel_function)(
    size_t rows,
    size_t channels,
    const uint8_t* input,
    size_t input_stride,
    const uint8_t* zero,
    uint8_t* output,
    const union xnn_qu8_avgpool_minmax_params* params);

// GAVGPOOL: Global AVeraGe POOLing multi-pass
//
// Compared to the single-pass signatures, these take an extra `buffer`
// pointer before `output`. In the quantized signatures the buffer is int32_t*.

// Type-erased signature.
typedef void (*xnn_gavgpool_multipass_ukernel_function)(
    size_t rows,
    size_t channels,
    const void* input,
    size_t input_stride,
    const void* zero,
    void* buffer,
    void* output,
    const void* params);

// void-pointer data, union xnn_f16_scaleminmax_params.
typedef void (*xnn_f16_gavgpool_minmax_multipass_ukernel_function)(
    size_t rows,
    size_t channels,
    const void* input,
    size_t input_stride,
    const void* zero,
    void* buffer,
    void* output,
    const union xnn_f16_scaleminmax_params* params);

// float data and buffer, union xnn_f32_scaleminmax_params.
typedef void (*xnn_f32_gavgpool_minmax_multipass_ukernel_function)(
    size_t rows,
    size_t channels,
    const float* input,
    size_t input_stride,
    const float* zero,
    float* buffer,
    float* output,
    const union xnn_f32_scaleminmax_params* params);

// int8_t data, int32_t buffer, union xnn_qs8_avgpool_minmax_params.
typedef void (*xnn_qs8_gavgpool_minmax_multipass_ukernel_function)(
    size_t rows,
    size_t channels,
    const int8_t* input,
    size_t input_stride,
    const int8_t* zero,
    int32_t* buffer,
    int8_t* output,
    const union xnn_qs8_avgpool_minmax_params* params);

// uint8_t data, int32_t buffer, union xnn_qu8_avgpool_minmax_params.
typedef void (*xnn_qu8_gavgpool_minmax_multipass_ukernel_function)(
    size_t rows,
    size_t channels,
    const uint8_t* input,
    size_t input_stride,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* output,
    const union xnn_qu8_avgpool_minmax_params* params);

// AVGPOOL: AVeraGe POOLing single-pass
//
// `input` is a pointer to pointers, accompanied by an `input_offset` and a
// `zero` pointer.

// Type-erased signature.
typedef void (*xnn_avgpool_unipass_ukernel_function)(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const void** input,
    size_t input_offset,
    const void* zero,
    void* output,
    size_t input_increment,
    size_t output_increment,
    const void* params);

// void-pointer data, union xnn_f16_scaleminmax_params.
typedef void (*xnn_f16_avgpool_minmax_unipass_ukernel_function)(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const void** input,
    size_t input_offset,
    const void* zero,
    void* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_f16_scaleminmax_params* params);

// float data, union xnn_f32_scaleminmax_params.
typedef void (*xnn_f32_avgpool_minmax_unipass_ukernel_function)(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const float** input,
    size_t input_offset,
    const float* zero,
    float* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_f32_scaleminmax_params* params);

// uint8_t data, union xnn_qu8_avgpool_minmax_params.
typedef void (*xnn_qu8_avgpool_minmax_unipass_ukernel_function)(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const uint8_t** input,
    size_t input_offset,
    const uint8_t* zero,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_qu8_avgpool_minmax_params* params);

// AVGPOOL: AVeraGe POOLing multi-pass
//
// Compared to the single-pass signatures, these take an extra `buffer`
// pointer before `output`. In the qu8 signature the buffer is int32_t*.

// Type-erased signature.
typedef void (*xnn_avgpool_multipass_ukernel_function)(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const void** input,
    size_t input_offset,
    const void* zero,
    void* buffer,
    void* output,
    size_t input_increment,
    size_t output_increment,
    const void* params);

// void-pointer data, union xnn_f16_scaleminmax_params.
typedef void (*xnn_f16_avgpool_minmax_multipass_ukernel_function)(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const void** input,
    size_t input_offset,
    const void* zero,
    void* buffer,
    void* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_f16_scaleminmax_params* params);

// float data and buffer, union xnn_f32_scaleminmax_params.
typedef void (*xnn_f32_avgpool_minmax_multipass_ukernel_function)(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const float** input,
    size_t input_offset,
    const float* zero,
    float* buffer,
    float* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_f32_scaleminmax_params* params);

// uint8_t data, int32_t buffer, union xnn_qu8_avgpool_minmax_params.
typedef void (*xnn_qu8_avgpool_minmax_multipass_ukernel_function)(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const uint8_t** input,
    size_t input_offset,
    const uint8_t* zero,
    int32_t* buffer,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_qu8_avgpool_minmax_params* params);

// PAVGPOOL: Pixelwise AVeraGe POOLing single-pass
//
// Same parameter list as AVGPOOL single-pass with an extra `multiplier`
// pointer after `zero`; the typed signatures take xnn_*_minmax_params.

// Type-erased signature.
typedef void (*xnn_pavgpool_unipass_ukernel_function)(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const void** input,
    size_t input_offset,
    const void* zero,
    const void* multiplier,
    void* output,
    size_t input_increment,
    size_t output_increment,
    const void* params);

// void-pointer data, union xnn_f16_minmax_params.
typedef void (*xnn_f16_pavgpool_minmax_unipass_ukernel_function)(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const void** input,
    size_t input_offset,
    const void* zero,
    const void* multiplier,
    void* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_f16_minmax_params* params);

// float data, union xnn_f32_minmax_params.
typedef void (*xnn_f32_pavgpool_minmax_unipass_ukernel_function)(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const float** input,
    size_t input_offset,
    const float* zero,
    const float* multiplier,
    float* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_f32_minmax_params* params);

786 // PAVGPOOL: Pixelwise AVeraGe POOLing multi-pass 787 788 typedef void (*xnn_pavgpool_multipass_ukernel_function)( 789 size_t output_pixels, 790 size_t kernel_elements, 791 size_t channels, 792 const void** input, 793 size_t input_offset, 794 const void* zero, 795 const void* multiplier, 796 void* buffer, 797 void* output, 798 size_t input_increment, 799 size_t output_increment, 800 const void* params); 801 802 typedef void (*xnn_f16_pavgpool_minmax_multipass_ukernel_function)( 803 size_t output_pixels, 804 size_t kernel_elements, 805 size_t channels, 806 const void** input, 807 size_t input_offset, 808 const void* zero, 809 const void* multiplier, 810 void* buffer, 811 void* output, 812 size_t input_increment, 813 size_t output_increment, 814 const union xnn_f16_minmax_params* params); 815 816 typedef void (*xnn_f32_pavgpool_minmax_multipass_ukernel_function)( 817 size_t output_pixels, 818 size_t kernel_elements, 819 size_t channels, 820 const float** input, 821 size_t input_offset, 822 const float* zero, 823 const float* multiplier, 824 float* buffer, 825 float* output, 826 size_t input_increment, 827 size_t output_increment, 828 const union xnn_f32_minmax_params* params); 829 830 // MAXPOOL: MAX POOLing 831 832 typedef void (*xnn_maxpool_ukernel_function)( 833 size_t output_pixels, 834 size_t kernel_elements, 835 size_t channels, 836 const void** input, 837 size_t input_offset, 838 void* output, 839 size_t input_increment, 840 size_t output_increment, 841 const void* params); 842 843 typedef void (*xnn_f16_maxpool_ukernel_function)( 844 size_t output_pixels, 845 size_t kernel_elements, 846 size_t channels, 847 const void** input, 848 size_t input_offset, 849 void* output, 850 size_t input_increment, 851 size_t output_increment, 852 const union xnn_f16_minmax_params* params); 853 854 typedef void (*xnn_f32_maxpool_ukernel_function)( 855 size_t output_pixels, 856 size_t kernel_elements, 857 size_t channels, 858 const float** input, 859 size_t input_offset, 860 
float* output, 861 size_t input_increment, 862 size_t output_increment, 863 const union xnn_f32_minmax_params* params); 864 865 typedef void (*xnn_s8_maxpool_ukernel_function)( 866 size_t output_pixels, 867 size_t kernel_elements, 868 size_t channels, 869 const int8_t** input, 870 size_t input_offset, 871 int8_t* output, 872 size_t input_increment, 873 size_t output_increment, 874 const union xnn_s8_minmax_params* params); 875 876 typedef void (*xnn_u8_maxpool_ukernel_function)( 877 size_t output_pixels, 878 size_t kernel_elements, 879 size_t channels, 880 const uint8_t** input, 881 size_t input_offset, 882 uint8_t* output, 883 size_t input_increment, 884 size_t output_increment, 885 const union xnn_u8_minmax_params* params); 886 887 // ARGMAXPOOL: ARG MAX POOLing single-pass 888 889 typedef void (*xnn_argmaxpool_unipass_ukernel_function)( 890 size_t output_pixels, 891 size_t kernel_elements, 892 size_t channels, 893 const void** input, 894 size_t input_offset, 895 void* output, 896 uint32_t* index, 897 size_t input_increment, 898 size_t output_increment); 899 900 typedef void (*xnn_f32_argmaxpool_unipass_ukernel_function)( 901 size_t output_pixels, 902 size_t kernel_elements, 903 size_t channels, 904 const float** input, 905 size_t input_offset, 906 float* output, 907 uint32_t* index, 908 size_t input_increment, 909 size_t output_increment); 910 911 // ARGMAXPOOL: ARG MAX POOLing multi-pass 912 913 typedef void (*xnn_argmaxpool_multipass_ukernel_function)( 914 size_t output_pixels, 915 size_t kernel_elements, 916 size_t channels, 917 const void** input, 918 size_t input_offset, 919 void* accumulation_buffer, 920 uint32_t* index_buffer, 921 void* output, 922 uint32_t* index, 923 size_t input_increment, 924 size_t output_increment); 925 926 typedef void (*xnn_f32_argmaxpool_multipass_ukernel_function)( 927 size_t output_pixels, 928 size_t kernel_elements, 929 size_t channels, 930 const float** input, 931 size_t input_offset, 932 float* accumulation_buffer, 933 
uint32_t* index_buffer, 934 float* output, 935 uint32_t* index, 936 size_t input_increment, 937 size_t output_increment); 938 939 // UNPOOL: UNPOOLing 940 941 typedef void (*xnn_unpool_ukernel_function)( 942 size_t p, 943 size_t c, 944 uint32_t f, 945 const void* input, 946 const uint32_t* index, 947 void** output); 948 949 typedef void (*xnn_x32_unpool_ukernel_function)( 950 size_t p, 951 size_t c, 952 uint32_t f, 953 const uint32_t* input, 954 const uint32_t* index, 955 uint32_t** output); 956 957 // TRANSPOSEC: TRANSPOSE Constant-size elements 958 959 typedef void (*xnn_x8_transposec_ukernel_function)( 960 const uint8_t* a, 961 uint8_t* b, 962 size_t input_stride, 963 size_t output_stride, 964 size_t block_width, 965 size_t block_height); 966 967 typedef void (*xnn_x16_transposec_ukernel_function)( 968 const uint16_t* a, 969 uint16_t* b, 970 size_t input_stride, 971 size_t output_stride, 972 size_t block_width, 973 size_t block_height); 974 975 typedef void (*xnn_x24_transposec_ukernel_function)( 976 const void* a, 977 void* b, 978 size_t input_stride, 979 size_t output_stride, 980 size_t block_width, 981 size_t block_height); 982 983 typedef void (*xnn_x32_transposec_ukernel_function)( 984 const uint32_t* a, 985 uint32_t* b, 986 size_t input_stride, 987 size_t output_stride, 988 size_t block_width, 989 size_t block_height); 990 991 typedef void (*xnn_x64_transposec_ukernel_function)( 992 const uint64_t* a, 993 uint64_t* b, 994 size_t input_stride, 995 size_t output_stride, 996 size_t block_width, 997 size_t block_height); 998 999 typedef void (*xnn_transposec_ukernel_function)( 1000 const void* input, 1001 void* output, 1002 size_t input_stride, 1003 size_t output_size, 1004 size_t block_width, 1005 size_t block_height); 1006 1007 // TRANSPOSEV: TRANSPOSE Variable-size elements 1008 1009 typedef void (*xnn_transposev_ukernel_function)( 1010 const void* input, 1011 void* output, 1012 size_t input_row_stride, 1013 size_t output_row_stride, 1014 size_t 
input_element_stride, 1015 size_t output_element_stride, 1016 size_t element_size, 1017 size_t block_width, 1018 size_t block_height); 1019 1020 // PACKX: PACK X (input) tensor for pre-packed matrix multiplication 1021 1022 typedef void (*xnn_packx_ukernel_function)( 1023 size_t m, 1024 size_t k, 1025 const void* x, 1026 size_t x_stride, 1027 void* y); 1028 1029 typedef void (*xnn_x32_packx_ukernel_function)( 1030 size_t m, 1031 size_t k, 1032 const uint32_t* x, 1033 size_t x_stride, 1034 uint32_t* y); 1035 1036 // FILL: FILL array with value 1037 1038 typedef void (*xnn_fill_ukernel_function)( 1039 size_t rows, 1040 size_t channels, 1041 void* output, 1042 size_t output_stride, 1043 const uint32_t fill_pattern); 1044 1045 // PAD: PAD array with values (fill before, copy array, fill after) 1046 1047 typedef void (*xnn_pad_ukernel_function)( 1048 size_t rows, 1049 size_t channels, 1050 size_t pre_padding, 1051 size_t post_padding, 1052 const void* input, 1053 size_t input_stride, 1054 void* output, 1055 size_t output_stride, 1056 const uint32_t fill_value); 1057 1058 // RMAX: Reduce-MAX 1059 1060 typedef void (*xnn_rmax_ukernel_function)( 1061 size_t n, 1062 const void* x, 1063 void* y); 1064 1065 typedef void (*xnn_f16_rmax_ukernel_function)( 1066 size_t n, 1067 const void* x, 1068 void* y); 1069 1070 typedef void (*xnn_f32_rmax_ukernel_function)( 1071 size_t n, 1072 const float* x, 1073 float* y); 1074 1075 typedef void (*xnn_u8_rmax_ukernel_function)( 1076 size_t n, 1077 const uint8_t* x, 1078 uint8_t* y); 1079 1080 // RADDSTOREEXPMINUSMAX: Reduce-ADD & STORE EXP(x_i MINUS MAX[x_i]) 1081 1082 typedef void (*xnn_raddstoreexpminusmax_ukernel_function)( 1083 size_t n, 1084 const void* input, 1085 const void* max, 1086 void* output, 1087 void* sum, 1088 const void* params); 1089 1090 typedef void (*xnn_f16_raddstoreexpminusmax_ukernel_function)( 1091 size_t n, 1092 const void* input, 1093 const void* max, 1094 void* output, 1095 void* sum, 1096 const union 
xnn_f16_expminus_params* params); 1097 1098 typedef void (*xnn_f32_raddstoreexpminusmax_ukernel_function)( 1099 size_t n, 1100 const float* input, 1101 const float* max, 1102 float* output, 1103 float* sum, 1104 const union xnn_f32_expminus_params* params); 1105 1106 // VUNARY: Vector UNARY elementwise 1107 1108 typedef void (*xnn_vunary_ukernel_function)( 1109 size_t batch, 1110 const void* input, 1111 void* output, 1112 const void* params); 1113 1114 // VABS: Vector ABSolute value elementwise 1115 1116 typedef void (*xnn_f16_vabs_ukernel_function)( 1117 size_t batch, 1118 const void* input, 1119 void* output, 1120 const union xnn_f16_abs_params* params); 1121 1122 typedef void (*xnn_f32_vabs_ukernel_function)( 1123 size_t batch, 1124 const float* input, 1125 float* output, 1126 const union xnn_f32_abs_params* params); 1127 1128 // VCLAMP: Vector CLAMP elementwise 1129 1130 typedef void (*xnn_f16_vclamp_ukernel_function)( 1131 size_t batch, 1132 const void* input, 1133 void* output, 1134 const union xnn_f16_minmax_params* params); 1135 1136 typedef void (*xnn_f32_vclamp_ukernel_function)( 1137 size_t batch, 1138 const float* input, 1139 float* output, 1140 const union xnn_f32_minmax_params* params); 1141 1142 typedef void (*xnn_s8_vclamp_ukernel_function)( 1143 size_t batch, 1144 const int8_t* input, 1145 int8_t* output, 1146 const union xnn_s8_minmax_params* params); 1147 1148 typedef void (*xnn_u8_vclamp_ukernel_function)( 1149 size_t batch, 1150 const uint8_t* input, 1151 uint8_t* output, 1152 const union xnn_u8_minmax_params* params); 1153 1154 // VCVT: Vector ConVerT elementwise 1155 1156 typedef void (*xnn_f16_f32_vcvt_ukernel_function)( 1157 size_t batch, 1158 const void* input, 1159 float* output, 1160 const union xnn_f16_f32_cvt_params* params); 1161 1162 typedef void (*xnn_f32_f16_vcvt_ukernel_function)( 1163 size_t batch, 1164 const float* input, 1165 void* output, 1166 const union xnn_f32_f16_cvt_params* params); 1167 1168 typedef void 
(*xnn_f32_qs8_vcvt_ukernel_function)( 1169 size_t batch, 1170 const float* input, 1171 int8_t* output, 1172 const union xnn_f32_qs8_cvt_params* params); 1173 1174 typedef void (*xnn_f32_qu8_vcvt_ukernel_function)( 1175 size_t batch, 1176 const float* input, 1177 uint8_t* output, 1178 const union xnn_f32_qu8_cvt_params* params); 1179 1180 typedef void (*xnn_qs8_vcvt_ukernel_function)( 1181 size_t batch, 1182 const int8_t* input, 1183 int8_t* output, 1184 const union xnn_qs8_cvt_params* params); 1185 1186 typedef void (*xnn_qs8_f32_vcvt_ukernel_function)( 1187 size_t batch, 1188 const int8_t* input, 1189 float* output, 1190 const union xnn_qs8_f32_cvt_params* params); 1191 1192 typedef void (*xnn_qu8_vcvt_ukernel_function)( 1193 size_t batch, 1194 const uint8_t* input, 1195 uint8_t* output, 1196 const union xnn_qu8_cvt_params* params); 1197 1198 typedef void (*xnn_qu8_f32_vcvt_ukernel_function)( 1199 size_t batch, 1200 const uint8_t* input, 1201 float* output, 1202 const union xnn_qu8_f32_cvt_params* params); 1203 1204 // VELU: Vector Exponential Linear Unit elementwise 1205 1206 typedef void (*xnn_f16_velu_ukernel_function)( 1207 size_t batch, 1208 const void* input, 1209 void* output, 1210 const union xnn_f16_elu_params* params); 1211 1212 typedef void (*xnn_f32_velu_ukernel_function)( 1213 size_t batch, 1214 const float* input, 1215 float* output, 1216 const union xnn_f32_elu_params* params); 1217 1218 // VHSWISH: Vector Hard SWISH elementwise 1219 1220 typedef void (*xnn_f16_vhswish_ukernel_function)( 1221 size_t batch, 1222 const void* input, 1223 void* output, 1224 const union xnn_f16_hswish_params* params); 1225 1226 typedef void (*xnn_f32_vhswish_ukernel_function)( 1227 size_t batch, 1228 const float* input, 1229 float* output, 1230 const union xnn_f32_hswish_params* params); 1231 1232 // VLRELU: Vector Leaky REctified Linear Unit elementwise 1233 1234 typedef void (*xnn_f16_vlrelu_ukernel_function)( 1235 size_t batch, 1236 const void* input, 1237 void* 
output, 1238 const union xnn_f16_lrelu_params* params); 1239 1240 typedef void (*xnn_f32_vlrelu_ukernel_function)( 1241 size_t batch, 1242 const float* input, 1243 float* output, 1244 const union xnn_f32_lrelu_params* params); 1245 1246 typedef void (*xnn_qs8_vlrelu_ukernel_function)( 1247 size_t batch, 1248 const int8_t* input, 1249 int8_t* output, 1250 const union xnn_qs8_lrelu_params* params); 1251 1252 typedef void (*xnn_qu8_vlrelu_ukernel_function)( 1253 size_t batch, 1254 const uint8_t* input, 1255 uint8_t* output, 1256 const union xnn_qu8_lrelu_params* params); 1257 1258 // VNEG: Vector NEGate elementwise 1259 1260 typedef void (*xnn_f16_vneg_ukernel_function)( 1261 size_t batch, 1262 const void* input, 1263 void* output, 1264 const union xnn_f16_neg_params* params); 1265 1266 typedef void (*xnn_f32_vneg_ukernel_function)( 1267 size_t batch, 1268 const float* input, 1269 float* output, 1270 const union xnn_f32_neg_params* params); 1271 1272 // VRELU: Vector REctified Linear Unit elementwise 1273 1274 typedef void (*xnn_f32_vrelu_ukernel_function)( 1275 size_t batch, 1276 const float* input, 1277 float* output, 1278 const union xnn_f32_relu_params* params); 1279 1280 // VROUND: Vector ROUNDing elementwise 1281 1282 typedef void (*xnn_f16_vround_ukernel_function)( 1283 size_t batch, 1284 const void* input, 1285 void* output, 1286 const union xnn_f16_rnd_params* params); 1287 1288 typedef void (*xnn_f32_vround_ukernel_function)( 1289 size_t batch, 1290 const float* input, 1291 float* output, 1292 const union xnn_f32_rnd_params* params); 1293 1294 // VSIGMOID: Vector SIGMOID elementwise 1295 1296 typedef void (*xnn_f16_vsigmoid_ukernel_function)( 1297 size_t batch, 1298 const void* input, 1299 void* output, 1300 const union xnn_f16_sigmoid_params* params); 1301 1302 typedef void (*xnn_f32_vsigmoid_ukernel_function)( 1303 size_t batch, 1304 const float* input, 1305 float* output, 1306 const union xnn_f32_sigmoid_params* params); 1307 1308 // VSQR: Vector SQuaRe 
elementwise 1309 1310 typedef void (*xnn_f16_vsqr_ukernel_function)( 1311 size_t batch, 1312 const void* input, 1313 void* output, 1314 const union xnn_f16_default_params* params); 1315 1316 typedef void (*xnn_f32_vsqr_ukernel_function)( 1317 size_t batch, 1318 const float* input, 1319 float* output, 1320 const union xnn_f32_default_params* params); 1321 1322 // VSQRT: Vector SQuare RooT elementwise 1323 1324 typedef void (*xnn_f16_vsqrt_ukernel_function)( 1325 size_t batch, 1326 const void* input, 1327 void* output, 1328 const union xnn_f16_sqrt_params* params); 1329 1330 typedef void (*xnn_f32_vsqrt_ukernel_function)( 1331 size_t batch, 1332 const float* input, 1333 float* output, 1334 const union xnn_f32_sqrt_params* params); 1335 1336 // VSQRTSHIFT: Vector SQuare RooT and SHIFT elementwise 1337 1338 typedef void (*xnn_u64_u32_vsqrtshift_ukernel_function)( 1339 size_t batch, 1340 const uint64_t* input, 1341 uint32_t* output, 1342 uint32_t shift); 1343 1344 // LUT: vector LookUp Table elementwise 1345 1346 typedef void (*xnn_x8_lut_ukernel_function)( 1347 size_t batch, 1348 const uint8_t* input, 1349 uint8_t* output, 1350 const uint8_t* table); 1351 1352 // LUT32NORM: vector LookUp Table of 32-bit elements and NORMalize elementwise 1353 1354 typedef void (*xnn_u8_lut32norm_ukernel_function)( 1355 size_t n, 1356 const uint8_t* x, 1357 const uint32_t* t, 1358 uint8_t* y); 1359 1360 // VBINARY: Vector BINARY elementwise 1361 1362 typedef void (*xnn_vbinary_ukernel_function)( 1363 size_t batch, 1364 const void* input_x, 1365 const void* input_y, 1366 void* output, 1367 const void* params); 1368 1369 typedef void (*xnn_f16_vbinary_ukernel_function)( 1370 size_t batch, 1371 const void* input_x, 1372 const void* input_y, 1373 void* output, 1374 const union xnn_f16_default_params* params); 1375 1376 typedef void (*xnn_f32_vbinary_ukernel_function)( 1377 size_t batch, 1378 const float* input_x, 1379 const float* input_y, 1380 float* output, 1381 const union 
xnn_f32_default_params* params); 1382 1383 // VBINARY: Vector BINARY elementwise with ReLU activation 1384 1385 typedef void (*xnn_f32_vbinary_relu_ukernel_function)( 1386 size_t batch, 1387 const float* input_x, 1388 const float* input_y, 1389 float* output, 1390 const union xnn_f32_relu_params* params); 1391 1392 // VBINARY: Vector BINARY elementwise with Min+Max activation 1393 1394 typedef void (*xnn_f16_vbinary_minmax_ukernel_function)( 1395 size_t batch, 1396 const void* input_x, 1397 const void* input_y, 1398 void* output, 1399 const union xnn_f16_minmax_params* params); 1400 1401 typedef void (*xnn_f32_vbinary_minmax_ukernel_function)( 1402 size_t batch, 1403 const float* input_x, 1404 const float* input_y, 1405 float* output, 1406 const union xnn_f32_minmax_params* params); 1407 1408 // VADD: Vector ADD elementwise with Min+Max activation 1409 1410 typedef void (*xnn_qs8_vadd_minmax_ukernel_function)( 1411 size_t batch, 1412 const int8_t* input_x, 1413 const int8_t* input_y, 1414 int8_t* output, 1415 const union xnn_qs8_add_minmax_params* params); 1416 1417 typedef void (*xnn_qu8_vadd_minmax_ukernel_function)( 1418 size_t batch, 1419 const uint8_t* input_x, 1420 const uint8_t* input_y, 1421 uint8_t* output, 1422 const union xnn_qu8_add_minmax_params* params); 1423 1424 // VMUL: Vector MUL elementwise with Min+Max activation 1425 1426 typedef void (*xnn_qs8_vmul_minmax_ukernel_function)( 1427 size_t batch, 1428 const int8_t* input_x, 1429 const int8_t* input_y, 1430 int8_t* output, 1431 const union xnn_qs8_mul_minmax_params* params); 1432 1433 typedef void (*xnn_qu8_vmul_minmax_ukernel_function)( 1434 size_t batch, 1435 const uint8_t* input_x, 1436 const uint8_t* input_y, 1437 uint8_t* output, 1438 const union xnn_qu8_mul_minmax_params* params); 1439 1440 1441 /***************** Microkernel pointers for sparse inference *****************/ 1442 1443 // SpMM: Sparse Matrix-Matrix multiplication 1444 1445 typedef void (*xnn_spmm_ukernel_function)( 1446 size_t 
batch_size, 1447 size_t output_channels, 1448 const void* input, 1449 const void* weights, 1450 const int32_t* widx_dmap, 1451 const uint32_t* nidx_nnzmap, 1452 void* output, 1453 size_t output_stride, 1454 const void* params); 1455 1456 typedef void (*xnn_f16_spmm_minmax_ukernel_function)( 1457 size_t batch_size, 1458 size_t output_channels, 1459 const void* input, 1460 const void* weights, 1461 const int32_t* widx_dmap, 1462 const uint32_t* nidx_nnzmap, 1463 void* output, 1464 size_t output_stride, 1465 const union xnn_f16_minmax_params* params); 1466 1467 typedef void (*xnn_f32_spmm_minmax_ukernel_function)( 1468 size_t batch_size, 1469 size_t output_channels, 1470 const float* input, 1471 const float* weights, 1472 const int32_t* widx_dmap, 1473 const uint32_t* nidx_nnzmap, 1474 float* output, 1475 size_t output_stride, 1476 const union xnn_f32_minmax_params* params); 1477 1478 // CONV-HWC2CHW: direct CONVolution from HWC-layout tensor to CHW-layout tensor 1479 1480 typedef void (*xnn_conv_hwc2chw_ukernel_function)( 1481 size_t input_height, 1482 size_t input_width, 1483 size_t output_y_start, 1484 size_t output_y_end, 1485 const void* input, 1486 const void* zero, 1487 const void* weights, 1488 void* output, 1489 size_t input_padding_top, 1490 size_t output_channels, 1491 size_t output_height_stride, 1492 size_t output_channel_stride, 1493 const void* params); 1494 1495 typedef void (*xnn_f16_conv_hwc2chw_ukernel_function)( 1496 size_t input_height, 1497 size_t input_width, 1498 size_t output_y_start, 1499 size_t output_y_end, 1500 const void* input, 1501 const void* zero, 1502 const void* weights, 1503 void* output, 1504 size_t input_padding_top, 1505 size_t output_channels, 1506 size_t output_height_stride, 1507 size_t output_channel_stride, 1508 const union xnn_f16_minmax_params* params); 1509 1510 typedef void (*xnn_f32_conv_hwc2chw_ukernel_function)( 1511 size_t input_height, 1512 size_t input_width, 1513 size_t output_y_start, 1514 size_t output_y_end, 
1515 const float* input, 1516 const float* zero, 1517 const float* weights, 1518 float* output, 1519 size_t input_padding_top, 1520 size_t output_channels, 1521 size_t output_height_stride, 1522 size_t output_channel_stride, 1523 const union xnn_f32_minmax_params* params); 1524 1525 // DWCONV2D-CHW: direct 2D DepthWise CONVolution in CHW layout 1526 1527 typedef void (*xnn_dwconv2d_chw_ukernel_function)( 1528 size_t input_height, 1529 size_t input_width, 1530 const void* input, 1531 const void* weights, 1532 const void* zero, 1533 void* output, 1534 uint32_t padding_top, 1535 const void* params); 1536 1537 typedef void (*xnn_f16_dwconv2d_chw_ukernel_function)( 1538 size_t input_height, 1539 size_t input_width, 1540 const void* input, 1541 const void* weights, 1542 const void* zero, 1543 void* output, 1544 uint32_t padding_top, 1545 const union xnn_f16_chw_params* params); 1546 1547 typedef void (*xnn_f32_dwconv2d_chw_ukernel_function)( 1548 size_t input_height, 1549 size_t input_width, 1550 const float* input, 1551 const float* weights, 1552 const float* zero, 1553 float* output, 1554 uint32_t padding_top, 1555 const union xnn_f32_chw_params* params); 1556 1557 // IBILINEAR-CHW: Indirect BILINEAR interpolation in CHW layout 1558 1559 typedef void (*xnn_ibilinear_chw_ukernel_function)( 1560 size_t output_pixels, 1561 size_t channels, 1562 const void** input, 1563 size_t input_offset, 1564 const void* weights, 1565 void* output, 1566 size_t input_increment); 1567 1568 typedef void (*xnn_f16_ibilinear_chw_ukernel_function)( 1569 size_t output_pixels, 1570 size_t channels, 1571 const void** input, 1572 size_t input_offset, 1573 const void* weights, 1574 void* output, 1575 size_t input_increment); 1576 1577 typedef void (*xnn_f32_ibilinear_chw_ukernel_function)( 1578 size_t output_pixels, 1579 size_t channels, 1580 const float** input, 1581 size_t input_offset, 1582 const float* weights, 1583 float* output, 1584 size_t input_increment); 1585 1586 // GAVGPOOL-CW: Global 
AVeraGe POOLing in CW layout. 1587 1588 typedef void (*xnn_gavgpool_cw_ukernel_function)( 1589 size_t elements, 1590 size_t channels, 1591 const float* input, 1592 float* output, 1593 const void* params); 1594 1595 typedef void (*xnn_f16_gavgpool_cw_ukernel_function)( 1596 size_t elements, 1597 size_t channels, 1598 const void* input, 1599 void* output, 1600 const union xnn_f16_gavgpool_params* params); 1601 1602 typedef void (*xnn_f32_gavgpool_cw_ukernel_function)( 1603 size_t elements, 1604 size_t channels, 1605 const float* input, 1606 float* output, 1607 const union xnn_f32_gavgpool_params* params); 1608 1609 1610 /********************* JIT microkernel generator pointers ********************/ 1611 1612 // JIT GEMM: used by GEMM/IGEMM microkernel generators. 1613 1614 struct jit_gemm_params { 1615 struct { 1616 float min; 1617 float max; 1618 } f32_minmax; 1619 size_t num_post_operations; 1620 struct xnn_post_operation* post_operations; 1621 }; 1622 1623 typedef xnn_status_t (*xnn_jit_gemm_code_generator_function)( 1624 struct xnn_code_buffer *code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void *params); 1625 typedef xnn_status_t (*xnn_jit_igemm_code_generator_function)( 1626 struct xnn_code_buffer *code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void *params); 1627 1628 1629 /***************** Audio pre-processing microkernel pointers *****************/ 1630 1631 typedef void (*xnn_s16_rmaxabs_ukernel_function)( 1632 size_t batch_size, 1633 const int16_t* x, 1634 uint16_t* y); 1635 1636 typedef void (*xnn_s16_window_ukernel_function)( 1637 size_t rows, 1638 size_t batch_size, 1639 const int16_t* input, 1640 const int16_t* weights, 1641 int16_t* output, 1642 uint32_t shift); 1643 1644 typedef void (*xnn_u32_filterbank_accumulate_ukernel_function)( 1645 size_t rows, 1646 const uint32_t* input, 1647 const uint8_t* weight_widths, 1648 const uint16_t* weights, 1649 uint64_t* output); 1650 1651 typedef void 
(*xnn_u32_filterbank_subtract_ukernel_function)( 1652 size_t batch_size, 1653 const uint32_t* input, 1654 uint32_t smoothing, 1655 uint32_t alternate_smoothing, 1656 uint32_t one_minus_smoothing, 1657 uint32_t alternate_one_minus_smoothing, 1658 uint32_t min_signal_remaining, 1659 uint32_t smoothing_bits, 1660 uint32_t spectral_subtraction_bits, 1661 uint32_t* noise_estimate, 1662 uint32_t* output); 1663 1664 typedef void (*xnn_s16_vlshift_ukernel_function)( 1665 size_t batch, 1666 const int16_t* input, 1667 int16_t* output, 1668 uint32_t shift); 1669 1670 typedef void (*xnn_cs16_vsquareabs_ukernel_function)( 1671 size_t batch_size, 1672 const int16_t* input, 1673 uint32_t* output); 1674 1675 typedef void (*xnn_u32_vlog_ukernel_function)( 1676 size_t batch_size, 1677 const uint32_t* input, 1678 uint32_t input_lshift, 1679 uint32_t output_scale, 1680 uint16_t* output); 1681 1682 typedef void (*xnn_cs16_bfly4_ukernel_function)( 1683 size_t samples, 1684 int16_t* data, 1685 size_t stride, 1686 const int16_t* twiddle); 1687 1688 typedef void (*xnn_cs16_fftr_ukernel_function)( 1689 size_t samples, 1690 int16_t* data, 1691 const int16_t* twiddle); 1692 1693 1694 /********************* Experimental microkernel pointers *********************/ 1695 1696 // ZIPC: ZIP Constant number of arrays 1697 1698 typedef void (*xnn_zipc_ukernel_function)( 1699 size_t n, 1700 const void* x, 1701 void* y); 1702 1703 typedef void (*xnn_x8_zipc_ukernel_function)( 1704 size_t n, 1705 const uint8_t* x, 1706 uint8_t* y); 1707 1708 typedef void (*xnn_x32_zipc_ukernel_function)( 1709 size_t n, 1710 const uint32_t* x, 1711 uint32_t* y); 1712 1713 // ZIPV: ZIP Variable number of arrays 1714 1715 typedef void (*xnn_zipv_ukernel_function)( 1716 size_t n, 1717 size_t m, 1718 const void* x, 1719 void* y); 1720 1721 typedef void (*xnn_x8_zipv_ukernel_function)( 1722 size_t n, 1723 size_t m, 1724 const uint8_t* x, 1725 uint8_t* y); 1726 1727 typedef void (*xnn_x32_zipv_ukernel_function)( 1728 size_t n, 
1729 size_t m, 1730 const uint32_t* x, 1731 uint32_t* y); 1732 1733 // RADDEXPMINUSMAX: Reduce-ADD EXP(x_i MINUS MAX[x_i]) 1734 1735 typedef void (*xnn_f32_raddexpminusmax_ukernel_function)( 1736 size_t batch, 1737 const float* input, 1738 float* sum, 1739 float max); 1740 1741 // VSCALEEXPMINUSMAX: Vector SCALE EXP(x_i MINUS MAX[x_i]) 1742 1743 typedef void (*xnn_f32_vscaleexpminusmax_ukernel_function)( 1744 size_t batch, 1745 const float* input, 1746 float* output, 1747 float max, 1748 float scale); 1749 1750 // RADDEXTEXP: Reduce-ADD EXTended ("mantissa" + "exponent") EXPonentials 1751 typedef void (*xnn_f32_raddextexp_ukernel_function)( 1752 size_t batch, 1753 const float* input, 1754 float* sum); 1755 1756 // VSCALEEXTEXP: Vector SCALE EXTended ("mantissa" + "exponent") EXPonentials 1757 typedef void (*xnn_f32_vscaleextexp_ukernel_function)( 1758 size_t batch, 1759 const float* input, 1760 float* output, 1761 float scale_mantissa, 1762 float scale_exponent); 1763 1764 1765 /***************** Microkernel parameter initializer pointers ****************/ 1766 1767 typedef size_t (*xnn_init_f16_f32_cvt_params_fn)( 1768 union xnn_f16_f32_cvt_params params[XNN_MIN_ELEMENTS(1)]); 1769 1770 typedef size_t (*xnn_init_f32_f16_cvt_params_fn)( 1771 union xnn_f32_f16_cvt_params params[XNN_MIN_ELEMENTS(1)]); 1772 1773 typedef size_t (*xnn_init_f32_qs8_cvt_params_fn)( 1774 union xnn_f32_qs8_cvt_params params[XNN_MIN_ELEMENTS(1)], 1775 float scale, 1776 int8_t output_zero_point, 1777 int8_t output_min, 1778 int8_t output_max); 1779 1780 typedef size_t (*xnn_init_f32_qu8_cvt_params_fn)( 1781 union xnn_f32_qu8_cvt_params params[XNN_MIN_ELEMENTS(1)], 1782 float scale, 1783 uint8_t output_zero_point, 1784 uint8_t output_min, 1785 uint8_t output_max); 1786 1787 typedef size_t (*xnn_init_qs8_cvt_params_fn)( 1788 union xnn_qs8_cvt_params params[XNN_MIN_ELEMENTS(1)], 1789 float input_output_scale, 1790 int8_t input_zero_point, 1791 int8_t output_zero_point); 1792 1793 typedef size_t 
(*xnn_init_qs8_f32_cvt_params_fn)( 1794 union xnn_qs8_f32_cvt_params params[XNN_MIN_ELEMENTS(1)], 1795 float scale, 1796 int8_t zero_point); 1797 1798 typedef size_t (*xnn_init_qu8_cvt_params_fn)( 1799 union xnn_qu8_cvt_params params[XNN_MIN_ELEMENTS(1)], 1800 float input_output_scale, 1801 uint8_t input_zero_point, 1802 uint8_t output_zero_point); 1803 1804 typedef size_t (*xnn_init_qu8_f32_cvt_params_fn)( 1805 union xnn_qu8_f32_cvt_params params[XNN_MIN_ELEMENTS(1)], 1806 float scale, 1807 uint8_t zero_point); 1808 1809 typedef size_t (*xnn_init_qc8_conv_minmax_params_fn)( 1810 union xnn_qc8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)], 1811 int8_t output_zero_point, 1812 int8_t output_min, 1813 int8_t output_max); 1814 1815 typedef size_t (*xnn_init_qs8_conv_minmax_params_fn)( 1816 union xnn_qs8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)], 1817 float scale, 1818 int8_t output_zero_point, 1819 int8_t output_min, 1820 int8_t output_max); 1821 1822 typedef size_t (*xnn_init_qu8_conv_minmax_params_fn)( 1823 union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)], 1824 uint8_t kernel_zero_point, 1825 float scale, 1826 uint8_t output_zero_point, 1827 uint8_t output_min, 1828 uint8_t output_max); 1829 1830 typedef size_t (*xnn_init_qs8_avgpool_minmax_params_fn)( 1831 union xnn_qs8_avgpool_minmax_params params[XNN_MIN_ELEMENTS(1)], 1832 int32_t bias, 1833 float scale, 1834 int8_t output_zero_point, 1835 int8_t output_min, 1836 int8_t output_max); 1837 1838 typedef size_t (*xnn_init_qu8_avgpool_minmax_params_fn)( 1839 union xnn_qu8_avgpool_minmax_params params[XNN_MIN_ELEMENTS(1)], 1840 int32_t bias, 1841 float scale, 1842 uint8_t output_zero_point, 1843 uint8_t output_min, 1844 uint8_t output_max); 1845 1846 typedef void (*xnn_update_qs8_avgpool_minmax_params_fn)( 1847 union xnn_qs8_avgpool_minmax_params params[XNN_MIN_ELEMENTS(1)], 1848 int32_t bias, 1849 float scale); 1850 1851 typedef void (*xnn_update_qu8_avgpool_minmax_params_fn)( 1852 union 
xnn_qu8_avgpool_minmax_params params[XNN_MIN_ELEMENTS(1)], 1853 int32_t bias, 1854 float scale); 1855 1856 typedef size_t (*xnn_init_qs8_add_minmax_params_fn)( 1857 union xnn_qs8_add_minmax_params params[XNN_MIN_ELEMENTS(1)], 1858 int8_t input_x_zero_point, 1859 int8_t input_y_zero_point, 1860 int8_t output_zero_point, 1861 float input_x_output_scale, 1862 float input_y_output_scale, 1863 int8_t output_min, 1864 int8_t output_max); 1865 1866 typedef size_t (*xnn_init_qu8_add_minmax_params_fn)( 1867 union xnn_qu8_add_minmax_params params[XNN_MIN_ELEMENTS(1)], 1868 uint8_t input_x_zero_point, 1869 uint8_t input_y_zero_point, 1870 uint8_t output_zero_point, 1871 float input_x_output_scale, 1872 float input_y_output_scale, 1873 uint8_t output_min, 1874 uint8_t output_max); 1875 1876 typedef size_t (*xnn_init_qs8_mul_minmax_params_fn)( 1877 union xnn_qs8_mul_minmax_params params[XNN_MIN_ELEMENTS(1)], 1878 int8_t input_x_zero_point, 1879 int8_t input_y_zero_point, 1880 int8_t output_zero_point, 1881 float product_output_scale, 1882 int8_t output_min, 1883 int8_t output_max); 1884 1885 typedef size_t (*xnn_init_qu8_mul_minmax_params_fn)( 1886 union xnn_qu8_mul_minmax_params params[XNN_MIN_ELEMENTS(1)], 1887 uint8_t input_x_zero_point, 1888 uint8_t input_y_zero_point, 1889 uint8_t output_zero_point, 1890 float product_output_scale, 1891 uint8_t output_min, 1892 uint8_t output_max); 1893 1894 typedef size_t (*xnn_init_f16_abs_params_fn)( 1895 union xnn_f16_abs_params params[XNN_MIN_ELEMENTS(1)]); 1896 1897 typedef size_t (*xnn_init_f32_abs_params_fn)( 1898 union xnn_f32_abs_params params[XNN_MIN_ELEMENTS(1)]); 1899 1900 typedef size_t (*xnn_init_f16_default_params_fn)( 1901 union xnn_f16_default_params params[XNN_MIN_ELEMENTS(1)]); 1902 1903 typedef size_t (*xnn_init_f32_default_params_fn)( 1904 union xnn_f32_default_params params[XNN_MIN_ELEMENTS(1)]); 1905 1906 typedef size_t (*xnn_init_f16_expminus_params_fn)( 1907 union xnn_f16_expminus_params 
params[XNN_MIN_ELEMENTS(1)]); 1908 1909 typedef size_t (*xnn_init_f32_expminus_params_fn)( 1910 union xnn_f32_expminus_params params[XNN_MIN_ELEMENTS(1)]); 1911 1912 typedef size_t (*xnn_init_f16_elu_params_fn)( 1913 union xnn_f16_elu_params params[XNN_MIN_ELEMENTS(1)], 1914 uint16_t prescale, 1915 uint16_t alpha, 1916 uint16_t beta); 1917 1918 typedef size_t (*xnn_init_f32_elu_params_fn)( 1919 union xnn_f32_elu_params params[XNN_MIN_ELEMENTS(1)], 1920 float prescale, 1921 float alpha, 1922 float beta); 1923 1924 typedef size_t (*xnn_init_f16_hswish_params_fn)( 1925 union xnn_f16_hswish_params params[XNN_MIN_ELEMENTS(1)]); 1926 1927 typedef size_t (*xnn_init_f32_hswish_params_fn)( 1928 union xnn_f32_hswish_params params[XNN_MIN_ELEMENTS(1)]); 1929 1930 typedef size_t (*xnn_init_f16_lrelu_params_fn)( 1931 union xnn_f16_lrelu_params params[XNN_MIN_ELEMENTS(1)], 1932 uint16_t slope); 1933 1934 typedef size_t (*xnn_init_f32_lrelu_params_fn)( 1935 union xnn_f32_lrelu_params params[XNN_MIN_ELEMENTS(1)], 1936 float slope); 1937 1938 typedef size_t (*xnn_init_qs8_lrelu_params_fn)( 1939 union xnn_qs8_lrelu_params params[XNN_MIN_ELEMENTS(1)], 1940 float positive_slope, 1941 float negative_slope, 1942 int8_t input_zero_point, 1943 int8_t output_zero_point); 1944 1945 typedef size_t (*xnn_init_qu8_lrelu_params_fn)( 1946 union xnn_qu8_lrelu_params params[XNN_MIN_ELEMENTS(1)], 1947 float positive_slope, 1948 float negative_slope, 1949 uint8_t input_zero_point, 1950 uint8_t output_zero_point); 1951 1952 typedef size_t (*xnn_init_bf16_minmax_params_fn)( 1953 union xnn_bf16_minmax_params params[XNN_MIN_ELEMENTS(1)], 1954 uint16_t min, 1955 uint16_t max); 1956 1957 typedef size_t (*xnn_init_f16_minmax_params_fn)( 1958 union xnn_f16_minmax_params params[XNN_MIN_ELEMENTS(1)], 1959 uint16_t min, 1960 uint16_t max); 1961 1962 typedef size_t (*xnn_init_f32_minmax_params_fn)( 1963 union xnn_f32_minmax_params params[XNN_MIN_ELEMENTS(1)], 1964 float min, 1965 float max); 1966 1967 typedef 
size_t (*xnn_init_s8_minmax_params_fn)( 1968 union xnn_s8_minmax_params params[XNN_MIN_ELEMENTS(1)], 1969 int8_t min, 1970 int8_t max); 1971 1972 typedef size_t (*xnn_init_u8_minmax_params_fn)( 1973 union xnn_u8_minmax_params params[XNN_MIN_ELEMENTS(1)], 1974 uint8_t min, 1975 uint8_t max); 1976 1977 typedef size_t (*xnn_init_f16_neg_params_fn)( 1978 union xnn_f16_neg_params params[XNN_MIN_ELEMENTS(1)]); 1979 1980 typedef size_t (*xnn_init_f32_neg_params_fn)( 1981 union xnn_f32_neg_params params[XNN_MIN_ELEMENTS(1)]); 1982 1983 typedef size_t (*xnn_init_f16_rnd_params_fn)( 1984 union xnn_f16_rnd_params params[XNN_MIN_ELEMENTS(1)]); 1985 1986 typedef size_t (*xnn_init_f32_rnd_params_fn)( 1987 union xnn_f32_rnd_params params[XNN_MIN_ELEMENTS(1)]); 1988 1989 typedef size_t (*xnn_init_f16_scaleminmax_params_fn)( 1990 union xnn_f16_scaleminmax_params params[XNN_MIN_ELEMENTS(1)], 1991 uint16_t scale, 1992 uint16_t min, 1993 uint16_t max); 1994 1995 typedef void (*xnn_update_f16_scaleminmax_params_fn)( 1996 union xnn_f16_scaleminmax_params params[XNN_MIN_ELEMENTS(1)], 1997 uint16_t scale); 1998 1999 typedef size_t (*xnn_init_f32_scaleminmax_params_fn)( 2000 union xnn_f32_scaleminmax_params params[XNN_MIN_ELEMENTS(1)], 2001 float scale, 2002 float min, 2003 float max); 2004 2005 typedef void (*xnn_update_f32_scaleminmax_params_fn)( 2006 union xnn_f32_scaleminmax_params params[XNN_MIN_ELEMENTS(1)], 2007 float scale); 2008 2009 typedef size_t (*xnn_init_f16_sigmoid_params_fn)( 2010 union xnn_f16_sigmoid_params params[XNN_MIN_ELEMENTS(1)]); 2011 2012 typedef size_t (*xnn_init_f32_sigmoid_params_fn)( 2013 union xnn_f32_sigmoid_params params[XNN_MIN_ELEMENTS(1)]); 2014 2015 typedef size_t (*xnn_init_f16_sqrt_params_fn)( 2016 union xnn_f16_sqrt_params params[XNN_MIN_ELEMENTS(1)]); 2017 2018 typedef size_t (*xnn_init_f32_sqrt_params_fn)( 2019 union xnn_f32_sqrt_params params[XNN_MIN_ELEMENTS(1)]); 2020 2021 typedef void (*xnn_init_qc8_scale_params_fn)( 2022 size_t channels, 2023 
size_t channels_tile, 2024 size_t stride, 2025 const float scale[XNN_MIN_ELEMENTS(1)], 2026 void* packed_w); 2027 2028 typedef size_t (*xnn_init_f16_gavgpool_neonfp16arith_params_fn)( 2029 union xnn_f16_gavgpool_params params[XNN_MIN_ELEMENTS(1)], 2030 uint16_t multiplier, 2031 uint16_t output_min, 2032 uint16_t output_max, 2033 uint32_t width); 2034 2035 typedef size_t (*xnn_init_f16_chw_params_fn)( 2036 union xnn_f16_chw_params params[XNN_MIN_ELEMENTS(1)], 2037 uint32_t width, 2038 uint16_t output_min, 2039 uint16_t output_max); 2040