1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // Copyright 2019 Google LLC 5 // 6 // This source code is licensed under the BSD-style license found in the 7 // LICENSE file in the root directory of this source tree. 8 9 #pragma once 10 11 #include <stdbool.h> 12 #include <stddef.h> 13 #include <stdint.h> 14 15 #include <xnnpack.h> 16 #include <xnnpack/common.h> 17 18 struct xnn_f16_output_params { 19 uint16_t scale; 20 uint16_t max; 21 uint16_t min; 22 }; 23 24 union xnn_f32_output_params { 25 struct { 26 float max; 27 float min; 28 } scalar; 29 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 30 struct { 31 XNN_ALIGN(16) float max[4]; 32 XNN_ALIGN(16) float min[4]; 33 } sse; 34 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 35 }; 36 37 union xnn_f32_spchw_params { 38 struct { 39 float max; 40 float min; 41 } scalar; 42 #if XNN_ARCH_ARM || XNN_ARCH_ARM64 43 struct { 44 float min; 45 float max; 46 XNN_ALIGN(16) uint32_t mask_even[4]; // used by stride 2 kernels 47 XNN_ALIGN(16) uint32_t mask_odd[4]; // used by stride 2 kernels 48 XNN_ALIGN(16) uint32_t mask[4]; // used by stride 1 kernels 49 } neon; 50 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 51 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 52 struct { 53 XNN_ALIGN(16) float max[4]; 54 XNN_ALIGN(16) float min[4]; 55 XNN_ALIGN(16) uint32_t mask_even[4]; // used by stride 2 kernels 56 XNN_ALIGN(16) uint32_t mask_odd[4]; // used by stride 2 kernels 57 XNN_ALIGN(16) uint32_t mask[4]; // used by stride 1 kernels 58 } sse; 59 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 60 }; 61 62 union xnn_u8_output_params { 63 struct { 64 int32_t max; 65 int32_t min; 66 } scalar; 67 #if XNN_ARCH_ARM || XNN_ARCH_ARM64 68 struct { 69 uint8_t max; 70 uint8_t min; 71 } neon; 72 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 73 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 74 struct { 75 XNN_ALIGN(16) uint8_t max[16]; 76 XNN_ALIGN(16) uint8_t min[16]; 77 } sse2; 78 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 79 }; 80 81 union xnn_f32_avgpool_params { 82 struct { 83 float multiplier; 84 float output_min; 85 float output_max; 86 } scalar; 87 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 88 struct { 89 XNN_ALIGN(16) float multiplier[4]; 90 XNN_ALIGN(16) float output_max[4]; 91 XNN_ALIGN(16) float output_min[4]; 92 } sse2; 93 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 94 #if XNN_ARCH_ARM || XNN_ARCH_ARM64 95 struct { 96 XNN_ALIGN(16) float multiplier; 97 XNN_ALIGN(16) float output_max; 98 XNN_ALIGN(16) float output_min; 99 } neon; 100 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 101 }; 102 103 union xnn_f32_gavgpool_params { 104 struct { 105 float multiplier; 106 float output_min; 107 float output_max; 108 } scalar; 109 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 110 struct { 111 XNN_ALIGN(16) float multiplier[4]; 112 XNN_ALIGN(16) float output_max[4]; 113 XNN_ALIGN(16) float output_min[4]; 114 XNN_ALIGN(16) uint32_t mask[4]; 115 } sse; 116 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 117 #if XNN_ARCH_ARM || XNN_ARCH_ARM64 118 struct { 119 XNN_ALIGN(16) float multiplier; 120 XNN_ALIGN(16) float output_max; 121 XNN_ALIGN(16) float output_min; 122 XNN_ALIGN(16) uint32_t mask[4]; 123 } neon; 124 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 */ 125 }; 126 127 union xnn_f32_hswish_params { 128 struct { 129 float sixth; 130 float half; 131 float one; 132 } scalar; 133 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 134 struct { 135 XNN_ALIGN(16) float sixth[4]; 136 XNN_ALIGN(16) float half[4]; 137 XNN_ALIGN(16) float one[4]; 138 } sse; 139 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 140 }; 141 142 union xnn_q8_gemm_params { 143 struct { 144 int32_t kernel_zero_point; 145 int32_t input_zero_point; 146 int32_t multiplier; 147 int32_t remainder_mask; 148 int32_t remainder_threshold; 149 uint32_t shift; 150 int32_t output_min_less_zero_point; 151 int32_t output_max_less_zero_point; 152 int32_t output_zero_point; 153 } scalar; 154 #if XNN_ARCH_ARM || XNN_ARCH_ARM64 155 struct { 156 int16_t kernel_zero_point; 157 int16_t input_zero_point; 158 int32_t multiplier; 159 int32_t right_shift; 160 int16_t output_zero_point; 161 uint8_t output_max; 162 uint8_t output_min; 163 } neon; 164 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 165 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 166 struct { 167 XNN_ALIGN(16) int16_t kernel_zero_point[8]; 168 XNN_ALIGN(16) int16_t input_zero_point[8]; 169 XNN_ALIGN(16) uint32_t multiplier[4]; 170 XNN_ALIGN(16) uint64_t rounding[2]; 171 XNN_ALIGN(16) int32_t remainder_mask[4]; 172 XNN_ALIGN(16) int32_t remainder_threshold[4]; 173 XNN_ALIGN(16) uint64_t shift[2]; 174 XNN_ALIGN(16) int16_t output_zero_point[8]; 175 XNN_ALIGN(16) uint8_t output_max[16]; 176 XNN_ALIGN(16) uint8_t output_min[16]; 177 } sse2; 178 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 179 }; 180 181 union xnn_q8_add_params { 182 struct { 183 int32_t zero_point_product; 184 uint32_t a_multiplier; 185 uint32_t b_multiplier; 186 uint32_t shift; 187 int32_t remainder_mask; 188 int32_t remainder_threshold; 189 int32_t y_zero_point; 190 int32_t y_max; 191 int32_t y_min; 192 } scalar; 193 #if XNN_ARCH_ARM || XNN_ARCH_ARM64 194 struct { 195 uint8_t a_zero_point; 196 uint8_t b_zero_point; 197 int16_t y_zero_point; 198 int32_t a_multiplier; 199 int32_t b_multiplier; 200 int32_t right_shift; 201 uint8_t y_max; 202 uint8_t y_min; 203 } neon; 204 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 205 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 206 struct { 207 XNN_ALIGN(16) int32_t zero_point_product[4]; 208 XNN_ALIGN(16) uint16_t a_multiplier_lo[8]; 209 XNN_ALIGN(16) uint16_t a_multiplier_hi[8]; 210 XNN_ALIGN(16) uint16_t b_multiplier_lo[8]; 211 XNN_ALIGN(16) uint16_t b_multiplier_hi[8]; 212 XNN_ALIGN(16) int32_t remainder_mask[4]; 213 XNN_ALIGN(16) int32_t remainder_threshold[4]; 214 XNN_ALIGN(16) int16_t y_zero_point[8]; 215 XNN_ALIGN(16) uint8_t y_max[16]; 216 XNN_ALIGN(16) uint8_t y_min[16]; 217 uint32_t shift; 218 uint32_t a_multiplier; 219 uint32_t b_multiplier; 220 } sse2; 221 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 222 }; 223 224 union xnn_q8_avgpool_params { 225 struct { 226 int32_t bias; 227 int32_t multiplier; 228 int64_t rounding; 229 uint32_t right_shift; 230 int32_t output_min_less_zero_point; 231 int32_t output_max_less_zero_point; 232 int32_t output_zero_point; 233 } scalar; 234 #if XNN_ARCH_ARM || XNN_ARCH_ARM64 235 struct { 236 int32_t bias; 237 int32_t multiplier; 238 int64_t left_shift; 239 int16_t output_zero_point; 240 uint8_t output_max; 241 uint8_t output_min; 242 } neon; 243 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 244 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 245 struct { 246 XNN_ALIGN(16) int32_t bias[4]; 247 XNN_ALIGN(16) uint32_t multiplier[4]; 248 XNN_ALIGN(16) uint64_t rounding[2]; 249 XNN_ALIGN(16) uint64_t right_shift[2]; 250 XNN_ALIGN(16) int16_t output_zero_point[8]; 251 XNN_ALIGN(16) uint8_t output_max[16]; 252 XNN_ALIGN(16) uint8_t output_min[16]; 253 } sse2; 254 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 255 }; 256 257 union xnn_fp32_requantization_params { 258 struct { 259 float scale; 260 float min_less_zero_point; 261 float max_less_zero_point; 262 float magic; 263 int32_t magic_less_zero_point; 264 } scalar; 265 struct { 266 float scale; 267 float max; 268 float min; 269 float magic; 270 int32_t magic_less_zero_point; 271 } neon; 272 struct { 273 float scale; 274 int16_t zero_point; 275 uint8_t max; 276 uint8_t min; 277 } neonv8; 278 struct { 279 XNN_ALIGN(16) float scale[4]; 280 XNN_ALIGN(16) int16_t zero_point[8]; 281 XNN_ALIGN(16) uint8_t max[16]; 282 XNN_ALIGN(16) uint8_t min[16]; 283 } sse2; 284 struct { 285 XNN_ALIGN(16) float scale[4]; 286 XNN_ALIGN(16) float min_less_zero_point[4]; 287 XNN_ALIGN(16) float max_less_zero_point[4]; 288 XNN_ALIGN(16) float magic[4]; 289 XNN_ALIGN(16) int32_t magic_less_zero_point[4]; 290 } psimd; 291 }; 292 293 union xnn_precise_requantization_params { 294 struct { 295 uint32_t multiplier; 296 uint32_t rounding_lo; 297 uint32_t rounding_hi; 298 uint32_t shift_less_32; 299 int32_t min_less_zero_point; 300 int32_t max_less_zero_point; 301 int32_t zero_point; 302 } scalar; 303 struct { 304 int32_t multiplier; 305 int32_t right_shift; 306 int16_t zero_point; 307 uint8_t max; 308 uint8_t min; 309 } neon; 310 struct { 311 XNN_ALIGN(16) uint32_t multiplier[4]; 312 XNN_ALIGN(16) uint64_t rounding[2]; 313 XNN_ALIGN(16) uint32_t shift[4]; 314 XNN_ALIGN(16) int16_t zero_point[8]; 315 XNN_ALIGN(16) uint8_t max[16]; 316 XNN_ALIGN(16) uint8_t min[16]; 317 } sse2; 318 }; 319 320 union xnn_q31_requantization_params { 321 struct { 322 int32_t multiplier; 323 int32_t remainder_mask; 324 int32_t remainder_threshold; 325 uint32_t shift; 326 int32_t min_less_zero_point; 327 int32_t max_less_zero_point; 328 int32_t zero_point; 329 } scalar; 330 #if XNN_ARCH_ARM || XNN_ARCH_ARM64 331 struct { 332 int32_t multiplier; 333 int32_t right_shift; 334 int16_t zero_point; 335 uint8_t max; 336 uint8_t min; 337 } neon; 338 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 339 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 340 struct { 341 XNN_ALIGN(16) uint32_t multiplier[4]; 342 XNN_ALIGN(16) uint64_t rounding[2]; 343 XNN_ALIGN(16) int32_t remainder_mask[4]; 344 XNN_ALIGN(16) int32_t remainder_threshold[4]; 345 XNN_ALIGN(16) uint64_t shift[2]; 346 XNN_ALIGN(16) int16_t zero_point[8]; 347 XNN_ALIGN(16) uint8_t max[16]; 348 XNN_ALIGN(16) uint8_t min[16]; 349 } sse2; 350 #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 351 }; 352 353 union xnn_requantization_params { 354 union xnn_precise_requantization_params precise; 355 union xnn_fp32_requantization_params fp32; 356 union xnn_q31_requantization_params q31; 357 }; 358 359 typedef void (*xnn_ppmm_ukernel_function)( 360 size_t mr, 361 size_t nc, 362 size_t kc, 363 const void* a, 364 const void* w, 365 void* c, 366 size_t cm_stride, 367 size_t cn_stride, 368 const void* params); 369 370 typedef void (*xnn_f32_ppmm_ukernel_function)( 371 size_t mr, 372 size_t nc, 373 size_t kc, 374 const float* a, 375 const float* w, 376 float* c, 377 size_t cm_stride, 378 size_t cn_stride, 379 const union xnn_f32_output_params* params); 380 381 typedef void (*xnn_f16_ppmm_ukernel_function)( 382 size_t mr, 383 size_t nc, 384 size_t kc, 385 const void* a, 386 const void* w, 387 void* c, 388 size_t cm_stride, 389 size_t cn_stride, 390 const struct xnn_f16_output_params* params); 391 392 typedef void (*xnn_gemm_ukernel_function)( 393 size_t mr, 394 size_t nr, 395 size_t k, 396 const void* a, 397 size_t a_stride, 398 const void* w, 399 void* c, 400 size_t cm_stride, 401 size_t cn_stride, 402 const void* params); 403 404 typedef void (*xnn_f32_gemm_ukernel_function)( 405 size_t mr, 406 size_t nr, 407 size_t k, 408 const float* a, 409 size_t a_stride, 410 const float* w, 411 float* c, 412 size_t cm_stride, 413 size_t cn_stride, 414 const union xnn_f32_output_params* params); 415 416 typedef void (*xnn_f32_gemminc_ukernel_function)( 417 size_t mr, 418 size_t nr, 419 size_t k, 420 const float* a, 421 size_t a_stride, 422 const float* w, 423 float* c, 424 size_t cm_stride, 425 size_t cn_stride, 426 const float* acc, 427 const union xnn_f32_output_params* params); 428 429 typedef void (*xnn_f16_gemm_ukernel_function)( 430 size_t mr, 431 size_t nr, 432 size_t k, 433 const void* a, 434 size_t a_stride, 435 const void* w, 436 void* c, 437 size_t cm_stride, 438 size_t cn_stride, 439 const struct xnn_f16_output_params* params); 440 441 typedef void (*xnn_q8_gemm_ukernel_function)( 442 size_t mr, 443 size_t nr, 444 size_t k, 445 const uint8_t* a, 446 size_t a_stride, 447 const void* w, 448 uint8_t* c, 449 size_t cm_stride, 450 size_t cn_stride, 451 const union xnn_q8_gemm_params* params); 452 453 typedef void (*xnn_igemm_ukernel_function)( 454 size_t mr, 455 size_t nr, 456 size_t kc, 457 size_t ks, 458 const void** a, 459 const void* w, 460 void* c, 461 size_t cm_stride, 462 size_t cn_stride, 463 size_t a_offset, 464 const void* zero, 465 const void* params); 466 467 typedef void (*xnn_f32_igemm_ukernel_function)( 468 size_t mr, 469 size_t nr, 470 size_t kc, 471 size_t ks, 472 const float** a, 473 const float* w, 474 float* c, 475 size_t cm_stride, 476 size_t cn_stride, 477 size_t a_offset, 478 const float* zero, 479 const union xnn_f32_output_params* params); 480 481 typedef void (*xnn_q8_igemm_ukernel_function)( 482 size_t mr, 483 size_t nr, 484 size_t kc, 485 size_t ks, 486 const uint8_t** a, 487 const void* w, 488 uint8_t* c, 489 size_t cm_stride, 490 size_t cn_stride, 491 size_t a_offset, 492 const uint8_t* zero, 493 const union xnn_q8_gemm_params* params); 494 495 typedef void (*xnn_conv_hwc_ukernel_function)( 496 size_t input_height, 497 size_t input_width, 498 size_t output_y_start, 499 size_t output_y_end, 500 const void* input, 501 const void* zero, 502 const void* weights, 503 void* output, 504 size_t input_padding_top, 505 size_t output_channels, 506 size_t output_height_stride, 507 size_t output_width_stride, 508 const void* params); 509 510 typedef void (*xnn_f32_conv_hwc_ukernel_function)( 511 size_t input_height, 512 size_t input_width, 513 size_t output_y_start, 514 size_t output_y_end, 515 const float* input, 516 const float* zero, 517 const float* weights, 518 float* output, 519 size_t input_padding_top, 520 size_t output_channels, 521 size_t output_height_stride, 522 size_t output_width_stride, 523 const union xnn_f32_output_params* params); 524 525 typedef void (*xnn_conv_hwc2spchw_ukernel_function)( 526 size_t input_height, 527 size_t input_width, 528 size_t output_y_start, 529 size_t output_y_end, 530 const void* input, 531 const void* zero, 532 const void* weights, 533 void* output, 534 size_t input_padding_top, 535 size_t output_channels, 536 size_t output_height_stride, 537 size_t output_channel_stride, 538 const void* params); 539 540 typedef void (*xnn_f32_conv_hwc2spchw_ukernel_function)( 541 size_t input_height, 542 size_t input_width, 543 size_t output_y_start, 544 size_t output_y_end, 545 const float* input, 546 const float* zero, 547 const float* weights, 548 float* output, 549 size_t input_padding_top, 550 size_t output_channels, 551 size_t output_height_stride, 552 size_t output_channel_stride, 553 const union xnn_f32_output_params* params); 554 555 typedef void (*xnn_spmm_ukernel_function)( 556 uint32_t m, 557 uint32_t n, 558 const void* a, 559 const void* w, 560 const int32_t* dmap, 561 const uint32_t* nmap, 562 void* c, 563 const void* params); 564 565 typedef void (*xnn_f16_spmm_ukernel_function)( 566 uint32_t m, 567 uint32_t n, 568 const void* a, 569 const void* w, 570 const int32_t* dmap, 571 const uint32_t* nmap, 572 void* c, 573 const struct xnn_f16_output_params* params); 574 575 typedef void (*xnn_f32_spmm_ukernel_function)( 576 uint32_t m, 577 uint32_t n, 578 const float* a, 579 const float* w, 580 const int32_t* dmap, 581 const uint32_t* nmap, 582 float* c, 583 const union xnn_f32_output_params* params); 584 585 typedef void (*xnn_packx_ukernel_function)( 586 size_t m, 587 size_t k, 588 const void* x, 589 size_t x_stride, 590 void* y); 591 592 typedef void (*xnn_x32_packx_ukernel_function)( 593 size_t m, 594 size_t k, 595 const uint32_t* x, 596 size_t x_stride, 597 uint32_t* y); 598 599 typedef void (*xnn_pad_ukernel_function)( 600 size_t m, 601 size_t n, 602 size_t l, 603 size_t r, 604 uint32_t c, 605 const void* x, 606 size_t x_stride, 607 void* y, 608 size_t y_stride); 609 610 typedef void (*xnn_unpool_ukernel_function)( 611 size_t p, 612 size_t c, 613 uint32_t f, 614 const void* input, 615 const uint32_t* index, 616 void** output); 617 618 typedef void (*xnn_x32_unpool_ukernel_function)( 619 size_t p, 620 size_t c, 621 uint32_t f, 622 const uint32_t* input, 623 const uint32_t* index, 624 uint32_t** output); 625 626 typedef void (*xnn_zipc_ukernel_function)( 627 size_t n, 628 const void* x, 629 void* y); 630 631 typedef void (*xnn_x8_zipc_ukernel_function)( 632 size_t n, 633 const uint8_t* x, 634 uint8_t* y); 635 636 typedef void (*xnn_x32_zipc_ukernel_function)( 637 size_t n, 638 const uint32_t* x, 639 uint32_t* y); 640 641 typedef void (*xnn_zipv_ukernel_function)( 642 size_t n, 643 size_t m, 644 const void* x, 645 void* y); 646 647 typedef void (*xnn_x8_zipv_ukernel_function)( 648 size_t n, 649 size_t m, 650 const uint8_t* x, 651 uint8_t* y); 652 653 typedef void (*xnn_x32_zipv_ukernel_function)( 654 size_t n, 655 size_t m, 656 const uint32_t* x, 657 uint32_t* y); 658 659 typedef void (*xnn_x8_lut_ukernel_function)( 660 size_t n, 661 const uint8_t* x, 662 const uint8_t* t, 663 uint8_t* y); 664 665 typedef void (*xnn_dwconv_spchw_ukernel_function)( 666 size_t output_height, 667 size_t input_width, 668 const void* input, 669 const void* weights, 670 void* output, 671 size_t input_tuple_stride, 672 size_t output_tuple_stride, 673 size_t input_height_stride, 674 size_t output_height_stride, 675 const void* params); 676 677 typedef void (*xnn_f32_dwconv_spchw_ukernel_function)( 678 size_t output_height, 679 size_t input_width, 680 const float* input, 681 const float* weights, 682 float* output, 683 size_t input_tuple_stride, 684 size_t output_tuple_stride, 685 size_t input_height_stride, 686 size_t output_height_stride, 687 const union xnn_f32_spchw_params* params); 688 689 typedef void (*xnn_dwconv_up_ukernel_function)( 690 size_t channels, 691 size_t output_width, 692 const void** input, 693 const void* weights, 694 void* output, 695 size_t input_stride, 696 size_t output_increment, 697 const void* params); 698 699 typedef void (*xnn_f32_dwconv_up_ukernel_function)( 700 size_t channels, 701 size_t output_width, 702 const float** input, 703 const float* weights, 704 float* output, 705 size_t input_stride, 706 size_t output_increment, 707 const union xnn_f32_output_params* params); 708 709 typedef void (*xnn_q8_dwconv_up_ukernel_function)( 710 size_t channels, 711 size_t output_width, 712 const uint8_t** input, 713 const void* weights, 714 uint8_t* output, 715 size_t input_stride, 716 size_t output_increment, 717 const union xnn_q8_gemm_params* params); 718 719 typedef void (*xnn_dwconv_mp_ukernel_function)( 720 size_t channels, 721 size_t output_width, 722 const void** input, 723 const void* weights, 724 void* buffer, 725 void* output, 726 size_t input_stride, 727 size_t output_increment, 728 const void* params); 729 730 typedef void (*xnn_f32_bilinear_ukernel_function)( 731 size_t output_pixels, 732 size_t channels, 733 const float** input, 734 size_t input_offset, 735 const float* weights, 736 float* output, 737 size_t output_increment); 738 739 typedef void (*xnn_bilinear_ukernel_function)( 740 size_t output_pixels, 741 size_t channels, 742 const void** input, 743 size_t input_offset, 744 const void* weights, 745 void* output, 746 size_t output_increment); 747 748 typedef void (*xnn_gavgpool_up_ukernel_function)( 749 size_t m, 750 size_t n, 751 const void* x, 752 size_t x_stride, 753 const void* zero, 754 void* y, 755 const void* params); 756 757 typedef void (*xnn_f32_gavgpool_up_ukernel_function)( 758 size_t m, 759 size_t n, 760 const float* x, 761 size_t x_stride, 762 const float* zero, 763 float* y, 764 const union xnn_f32_avgpool_params* params); 765 766 typedef void (*xnn_gavgpool_spchw_ukernel_function)( 767 size_t elements, 768 size_t channels, 769 const float* input, 770 float* output, 771 const void* params); 772 773 typedef void (*xnn_f32_gavgpool_spchw_ukernel_function)( 774 size_t elements, 775 size_t channels, 776 const float* input, 777 float* output, 778 const union xnn_f32_gavgpool_params* params); 779 780 typedef void (*xnn_q8_gavgpool_up_ukernel_function)( 781 size_t m, 782 size_t n, 783 const uint8_t* x, 784 size_t x_stride, 785 const uint8_t* zero, 786 uint8_t* y, 787 const union xnn_q8_avgpool_params* params); 788 789 typedef void (*xnn_gavgpool_mp_ukernel_function)( 790 size_t m, 791 size_t n, 792 const void* x, 793 size_t x_stride, 794 const void* zero, 795 void* buffer, 796 void* y, 797 const void* params); 798 799 typedef void (*xnn_f32_gavgpool_mp_ukernel_function)( 800 size_t m, 801 size_t n, 802 const float* x, 803 size_t x_stride, 804 const float* zero, 805 float* buffer, 806 float* y, 807 const union xnn_f32_avgpool_params* params); 808 809 typedef void (*xnn_q8_gavgpool_mp_ukernel_function)( 810 size_t m, 811 size_t n, 812 const uint8_t* x, 813 size_t x_stride, 814 const uint8_t* zero, 815 int32_t* buffer, 816 uint8_t* y, 817 const union xnn_q8_avgpool_params* params); 818 819 typedef void (*xnn_avgpool_up_ukernel_function)( 820 size_t n, 821 size_t ks, 822 size_t kc, 823 const void** x, 824 const void* zero, 825 void* y, 826 size_t x_increment, 827 size_t y_increment, 828 const void* params); 829 830 typedef void (*xnn_f32_avgpool_up_ukernel_function)( 831 size_t n, 832 size_t ks, 833 size_t kc, 834 const float** x, 835 const float* zero, 836 float* y, 837 size_t x_increment, 838 size_t y_increment, 839 const union xnn_f32_avgpool_params* params); 840 841 typedef void (*xnn_q8_avgpool_up_ukernel_function)( 842 size_t n, 843 size_t ks, 844 size_t kc, 845 const uint8_t** x, 846 const uint8_t* zero, 847 uint8_t* y, 848 size_t x_increment, 849 size_t y_increment, 850 const union xnn_q8_avgpool_params* params); 851 852 typedef void (*xnn_avgpool_mp_ukernel_function)( 853 size_t n, 854 size_t ks, 855 size_t kc, 856 const void** x, 857 const void* zero, 858 void* buffer, 859 void* y, 860 size_t x_increment, 861 size_t y_increment, 862 const void* params); 863 864 typedef void (*xnn_f32_avgpool_mp_ukernel_function)( 865 size_t n, 866 size_t ks, 867 size_t kc, 868 const float** x, 869 const float* zero, 870 float* buffer, 871 float* y, 872 size_t x_increment, 873 size_t y_increment, 874 const union xnn_f32_avgpool_params* params); 875 876 typedef void (*xnn_q8_avgpool_mp_ukernel_function)( 877 size_t n, 878 size_t ks, 879 size_t kc, 880 const uint8_t** x, 881 const uint8_t* zero, 882 int32_t* buffer, 883 uint8_t* y, 884 size_t x_increment, 885 size_t y_increment, 886 const union xnn_q8_avgpool_params* params); 887 888 typedef void (*xnn_pavgpool_up_ukernel_function)( 889 size_t n, 890 size_t ks, 891 size_t kc, 892 const void** x, 893 const void* zero, 894 const void* multiplier, 895 void* y, 896 size_t x_increment, 897 size_t y_increment, 898 const void* params); 899 900 typedef void (*xnn_f32_pavgpool_up_ukernel_function)( 901 size_t n, 902 size_t ks, 903 size_t kc, 904 const float** x, 905 const float* zero, 906 const float* multiplier, 907 float* y, 908 size_t x_increment, 909 size_t y_increment, 910 const union xnn_f32_output_params* params); 911 912 typedef void (*xnn_pavgpool_mp_ukernel_function)( 913 size_t n, 914 size_t ks, 915 size_t kc, 916 const void** x, 917 const void* zero, 918 const void* multiplier, 919 void* buffer, 920 void* y, 921 size_t x_increment, 922 size_t y_increment, 923 const void* params); 924 925 typedef void (*xnn_f32_pavgpool_mp_ukernel_function)( 926 size_t n, 927 size_t ks, 928 size_t kc, 929 const float** x, 930 const float* zero, 931 const float* multiplier, 932 float* buffer, 933 float* y, 934 size_t x_increment, 935 size_t y_increment, 936 const union xnn_f32_output_params* params); 937 938 typedef void (*xnn_maxpool_ukernel_function)( 939 size_t output_pixels, 940 size_t kernel_elements, 941 size_t channels, 942 const void** input, 943 size_t input_offset, 944 void* output, 945 size_t input_increment, 946 size_t output_increment, 947 const void* params); 948 949 typedef void (*xnn_f32_maxpool_ukernel_function)( 950 size_t output_pixels, 951 size_t kernel_elements, 952 size_t channels, 953 const float** input, 954 size_t input_offset, 955 float* output, 956 size_t input_increment, 957 size_t output_increment, 958 const union xnn_f32_output_params* params); 959 960 typedef void (*xnn_u8_maxpool_ukernel_function)( 961 size_t output_pixels, 962 size_t kernel_elements, 963 size_t channels, 964 const uint8_t** input, 965 size_t input_offset, 966 uint8_t* output, 967 size_t input_increment, 968 size_t output_increment, 969 const union xnn_u8_output_params* params); 970 971 typedef void (*xnn_argmaxpool_up_ukernel_function)( 972 size_t output_pixels, 973 size_t kernel_elements, 974 size_t channels, 975 const void** input, 976 size_t input_offset, 977 void* output, 978 uint32_t* index, 979 size_t input_increment, 980 size_t output_increment, 981 const void* params); 982 983 typedef void (*xnn_f32_argmaxpool_up_ukernel_function)( 984 size_t output_pixels, 985 size_t kernel_elements, 986 size_t channels, 987 const float** input, 988 size_t input_offset, 989 float* output, 990 uint32_t* index, 991 size_t input_increment, 992 size_t output_increment, 993 const union xnn_f32_output_params* params); 994 995 typedef void (*xnn_argmaxpool_mp_ukernel_function)( 996 size_t output_pixels, 997 size_t kernel_elements, 998 size_t channels, 999 const void** input, 1000 size_t input_offset, 1001 void* accumulation_buffer, 1002 uint32_t* index_buffer, 1003 void* output, 1004 uint32_t* index, 1005 size_t input_increment, 1006 size_t output_increment, 1007 const void* params); 1008 1009 typedef void (*xnn_f32_argmaxpool_mp_ukernel_function)( 1010 size_t output_pixels, 1011 size_t kernel_elements, 1012 size_t channels, 1013 const float** input, 1014 size_t input_offset, 1015 float* accumulation_buffer, 1016 uint32_t* index_buffer, 1017 float* output, 1018 uint32_t* index, 1019 size_t input_increment, 1020 size_t output_increment, 1021 const union xnn_f32_output_params* params); 1022 1023 typedef void (*xnn_univector_ukernel_function)( 1024 size_t n, 1025 const void* x, 1026 void* y, 1027 const void* params); 1028 1029 typedef void (*xnn_f32_clamp_ukernel_function)( 1030 size_t n, 1031 const float* x, 1032 float* y, 1033 const union xnn_f32_output_params* params); 1034 1035 typedef void (*xnn_u8_clamp_ukernel_function)( 1036 size_t n, 1037 const uint8_t* x, 1038 uint8_t* y, 1039 const union xnn_u8_output_params* params); 1040 1041 typedef void (*xnn_f32_hswish_ukernel_function)( 1042 size_t n, 1043 const float* x, 1044 float* y, 1045 const union xnn_f32_hswish_params* params); 1046 1047 typedef void (*xnn_rmax_ukernel_function)( 1048 size_t n, 1049 const void* x, 1050 void* y); 1051 1052 typedef void (*xnn_u8_rmax_ukernel_function)( 1053 size_t n, 1054 const uint8_t* x, 1055 uint8_t* y); 1056 1057 typedef void (*xnn_f32_rmax_ukernel_function)( 1058 size_t n, 1059 const float* x, 1060 float* y); 1061 1062 typedef void (*xnn_u8_lut32norm_ukernel_function)( 1063 size_t n, 1064 const uint8_t* x, 1065 const uint32_t* t, 1066 uint8_t* y); 1067 1068 typedef void (*xnn_vadd_ukernel_function)( 1069 size_t n, 1070 const void* a, 1071 const void* b, 1072 void* y, 1073 const void* params); 1074 1075 typedef void (*xnn_f32_vadd_ukernel_function)( 1076 size_t n, 1077 const float* a, 1078 const float* b, 1079 float* y, 1080 const union xnn_f32_output_params* params); 1081 1082 typedef void (*xnn_q8_vadd_ukernel_function)( 1083 size_t n, 1084 const uint8_t* a, 1085 const uint8_t* b, 1086 uint8_t* y, 1087 const union xnn_q8_add_params* params); 1088 1089 typedef void (*xnn_vbinary_ukernel_function)( 1090 size_t n, 1091 const void* a, 1092 const void* b, 1093 void* y, 1094 const void* params); 1095 1096 typedef void (*xnn_f32_vbinary_ukernel_function)( 1097 size_t n, 1098 const float* a, 1099 const float* b, 1100 float* y, 1101 const union xnn_f32_output_params* params); 1102 1103 typedef void (*xnn_vunary_ukernel_function)( 1104 size_t n, 1105 const void* x, 1106 void* y, 1107 const void* params); 1108 1109 typedef void (*xnn_f32_vunary_ukernel_function)( 1110 size_t n, 1111 const float* x, 1112 float* y, 1113 const void* params); 1114 1115 typedef void (*xnn_vmulcaddc_ukernel_function)( 1116 size_t m, 1117 size_t c, 1118 const void* x, 1119 size_t x_stride, 1120 const void* w, 1121 void* y, 1122 size_t y_stride, 1123 const void* params); 1124 1125 typedef void (*xnn_f32_vmulcaddc_ukernel_function)( 1126 size_t m, 1127 size_t c, 1128 const float* x, 1129 size_t x_stride, 1130 const float* w, 1131 float* y, 1132 size_t y_stride, 1133 const union xnn_f32_output_params* params); 1134 1135 typedef void (*xnn_prelu_ukernel_function)( 1136 size_t mr, 1137 size_t n, 1138 const void* x, 1139 size_t x_stride, 1140 const void* w, 1141 void* y, 1142 size_t y_stride, 1143 const void* params); 1144 1145 typedef void (*xnn_f32_prelu_ukernel_function)( 1146 size_t mr, 1147 size_t n, 1148 const float* x, 1149 size_t x_stride, 1150 const float* w, 1151 float* y, 1152 size_t y_stride, 1153 const union xnn_f32_output_params* params); 1154 1155 typedef void (*xnn_f32_raddexpminusmax_ukernel_function)( 1156 size_t n, 1157 const float* input, 1158 float* sum, 1159 float max); 1160 1161 typedef void (*xnn_f32_raddstoreexpminusmax_ukernel_function)( 1162 size_t n, 1163 const float* input, 1164 float* output, 1165 float* sum, 1166 float max); 1167 1168 typedef void (*xnn_f32_vscaleexpminusmax_ukernel_function)( 1169 size_t n, 1170 const float* input, 1171 float* output, 1172 float max, 1173 float scale); 1174 1175 typedef void (*xnn_f32_vscale_ukernel_function)( 1176 size_t n, 1177 const float* x, 1178 float* y, 1179 float c); 1180 1181 // Reduce-Add Extended ("mantissa" + "exponent") Exponentials 1182 typedef void (*xnn_f32_raddextexp_ukernel_function)( 1183 size_t n, 1184 const float* input, 1185 float* sum); 1186 1187 // Vector Scale Extended ("mantissa" + "exponent") Exponentials 1188 typedef void (*xnn_f32_vscaleextexp_ukernel_function)( 1189 size_t n, 1190 const float* input, 1191 float* output, 1192 float scale_mantissa, 1193 float scale_exponent); 1194 1195 1196 struct gemm_parameters { 1197 xnn_gemm_ukernel_function gemm; 1198 xnn_igemm_ukernel_function igemm; 1199 // Optional GEMM and IGEMM micro-kernels with MR=1 and the same NR and KR parameters. 1200 xnn_gemm_ukernel_function gemm1; 1201 xnn_igemm_ukernel_function igemm1; 1202 uint8_t mr; 1203 uint8_t nr; 1204 uint8_t log2_kr; 1205 uint8_t log2_sr; 1206 }; 1207 1208 struct vbinary_parameters { 1209 xnn_vbinary_ukernel_function op_ukernel; 1210 xnn_vbinary_ukernel_function opc_ukernel; 1211 xnn_vbinary_ukernel_function ropc_ukernel; 1212 // Number of elements in a tile. 1213 // For best efficiency, micro-kernel must process a multiple of this number of elements in each call. 1214 uint8_t element_tile; 1215 }; 1216 1217 struct spmm_parameters { 1218 xnn_spmm_ukernel_function ukernel; 1219 // Number of M-dimension elements in a tile. 1220 // Corresponds to a block of pixels in 1x1 Convolution and a block of batch size in Fully Connected operator. 1221 uint8_t mr; 1222 // Number of N-dimension elements in a tile. 1223 // Corresponds to a block of output channels/features in 1x1 Convolution and Fully Connected operator. 1224 uint8_t nr; 1225 }; 1226 1227 struct hwc2spchw_dconv_parameters { 1228 xnn_conv_hwc2spchw_ukernel_function ukernel_with_symm_padding; 1229 // Number of output channels in a tile. 1230 // This parameter must be passed as is to weight packing function. 1231 uint8_t output_channel_tile; 1232 // Number of output height pixels in a tile. 1233 // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call. 1234 uint8_t output_height_tile; 1235 // Number of output width pixes in a tile. 1236 uint8_t output_width_tile; 1237 }; 1238 1239 struct spchw_dwconv_parameters { 1240 xnn_dwconv_spchw_ukernel_function ukernel; 1241 // Number of input width pixels in a tile. 1242 uint8_t input_width_tile; 1243 // Number of output width pixels in a tile. 1244 uint8_t output_width_tile; 1245 // Number of output height pixels in a tile. 1246 // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call. 1247 uint8_t output_height_tile; 1248 }; 1249 1250 struct spchw_gavgpool_parameters { 1251 xnn_gavgpool_spchw_ukernel_function ukernel; 1252 // Number of channels in a tile. 1253 // For best efficiency, micro-kernel must process a multiple of this number of channels in each call. 1254 uint8_t channel_tile; 1255 }; 1256 1257 struct dwconv_parameters { 1258 union { 1259 xnn_dwconv_up_ukernel_function up; 1260 xnn_dwconv_mp_ukernel_function mp; 1261 }; 1262 uint8_t cr; 1263 uint8_t mr; 1264 uint8_t qr; 1265 }; 1266 1267 struct gavgpool_parameters { 1268 xnn_gavgpool_up_ukernel_function up; 1269 xnn_gavgpool_mp_ukernel_function mp; 1270 uint8_t mr; 1271 }; 1272 1273 struct avgpool_parameters { 1274 xnn_avgpool_up_ukernel_function up; 1275 xnn_avgpool_mp_ukernel_function mp; 1276 uint8_t mr; 1277 uint8_t qr; 1278 }; 1279 1280 struct pavgpool_parameters { 1281 xnn_pavgpool_up_ukernel_function up; 1282 xnn_pavgpool_mp_ukernel_function mp; 1283 uint8_t mr; 1284 uint8_t qr; 1285 }; 1286 1287 struct argmaxpool_parameters { 1288 union { 1289 xnn_argmaxpool_up_ukernel_function up; 1290 xnn_argmaxpool_mp_ukernel_function mp; 1291 }; 1292 uint8_t mr; 1293 uint8_t qr; 1294 }; 1295 1296 struct maxpool_parameters { 1297 xnn_maxpool_ukernel_function ukernel; 1298 uint8_t mr; 1299 uint8_t qr; 1300 }; 1301 1302 struct bilinear_parameters { 1303 xnn_bilinear_ukernel_function ukernel; 1304 // Number of output pixels in a tile. 1305 // For best efficiency, micro-kernel must produce a multiple of this number of pixels in each call. 1306 uint8_t pixel_tile; 1307 // Number of channels in a tile. 1308 // For best efficiency, micro-kernel must process a multiple of this number of channels in each call. 1309 uint8_t channel_tile; 1310 }; 1311 1312 struct zip_parameters { 1313 xnn_zipc_ukernel_function x2; 1314 xnn_zipc_ukernel_function x3; 1315 xnn_zipc_ukernel_function x4; 1316 xnn_zipv_ukernel_function xm; 1317 }; 1318 1319 struct prelu_parameters { 1320 xnn_prelu_ukernel_function ukernel; 1321 uint16_t row_tile; 1322 uint16_t channel_tile; 1323 }; 1324 1325 struct pad_parameters { 1326 xnn_pad_ukernel_function ukernel; 1327 uint8_t mr; 1328 }; 1329 1330 struct vmulcaddc_parameters { 1331 xnn_vmulcaddc_ukernel_function ukernel; 1332 uint8_t channel_tile; 1333 uint8_t row_tile; 1334 }; 1335 1336 #define XNN_MAX_Q8_DWCONV_UKERNELS 1 1337 #define XNN_MAX_F32_DWCONV_UKERNELS 3 1338 #define XNN_MAX_F32_ARGMAXPOOL_UKERNELS 3 1339 1340 struct xnn_parameters { 1341 bool initialized; 1342 struct xnn_allocator allocator; 1343 struct { 1344 struct gemm_parameters gemm; 1345 struct dwconv_parameters dwconv[XNN_MAX_Q8_DWCONV_UKERNELS]; 1346 struct avgpool_parameters avgpool; 1347 struct gavgpool_parameters gavgpool; 1348 xnn_vadd_ukernel_function vadd; 1349 } q8; 1350 struct { 1351 struct maxpool_parameters maxpool; 1352 xnn_univector_ukernel_function clamp; 1353 xnn_u8_lut32norm_ukernel_function lut32norm; 1354 xnn_u8_rmax_ukernel_function rmax; 1355 } u8; 1356 struct { 1357 xnn_x8_lut_ukernel_function lut; 1358 struct zip_parameters zip; 1359 } x8; 1360 struct { 1361 struct gemm_parameters gemm; 1362 struct gemm_parameters gemm2; 1363 struct dwconv_parameters dwconv[XNN_MAX_F32_DWCONV_UKERNELS]; 1364 struct avgpool_parameters avgpool; 1365 struct pavgpool_parameters pavgpool; 1366 struct gavgpool_parameters gavgpool; 1367 struct maxpool_parameters maxpool; 1368 struct argmaxpool_parameters argmaxpool[XNN_MAX_F32_ARGMAXPOOL_UKERNELS]; 1369 // Bilinear interpolation (2D). 1370 struct bilinear_parameters bilinear; 1371 xnn_univector_ukernel_function clamp; 1372 xnn_univector_ukernel_function hswish; 1373 xnn_univector_ukernel_function sigmoid; 1374 struct prelu_parameters prelu; 1375 struct vbinary_parameters vadd; 1376 struct vbinary_parameters vdiv; 1377 struct vbinary_parameters vmax; 1378 struct vbinary_parameters vmin; 1379 struct vbinary_parameters vmul; 1380 struct vbinary_parameters vsub; 1381 struct vmulcaddc_parameters vmulcaddc; 1382 xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax; 1383 xnn_f32_rmax_ukernel_function rmax; 1384 // Sparse Matrix-Dense Matrix Multiplication (NR=1 block). 1385 struct spmm_parameters spmm; 1386 // Sparse Matrix-Dense Matrix Multiplication (NR=2 block). 1387 struct spmm_parameters spmm2; 1388 // Sparse Matrix-Dense Matrix Multiplication (NR=4 block). 1389 struct spmm_parameters spmm4; 1390 // Direct 3x3 stride-2 Convolution with 3 input channels and HWC->SpCHW layout conversion. 1391 struct hwc2spchw_dconv_parameters hwc2spchw_dconv3x3c3s2; 1392 // Direct 3x3 stride-1 Convolution with padding 1 on left and right in SpCHW layout. 1393 struct spchw_dwconv_parameters spchw_dwconv3x3; 1394 // Direct 3x3 stride-2 Convolution with padding 1 on left and right in SpCHW layout. 1395 struct spchw_dwconv_parameters spchw_dwconv3x3s2; 1396 // Direct 5x5 stride-1 Convolution with padding 2 on left and right in SpCHW layout. 1397 struct spchw_dwconv_parameters spchw_dwconv5x5; 1398 // Direct 5x5 stride-2 Convolution with padding 2 on left and right in SpCHW layout. 1399 struct spchw_dwconv_parameters spchw_dwconv5x5s2; 1400 // Global Average Pooling in SpCHW layout. 1401 struct spchw_gavgpool_parameters spchw_gavgpool; 1402 } f32; 1403 struct { 1404 struct pad_parameters pad; 1405 xnn_unpool_ukernel_function unpool; 1406 struct zip_parameters zip; 1407 } x32; 1408 }; 1409 1410 extern XNN_INTERNAL struct xnn_parameters xnn_params; 1411