1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // Copyright 2019 Google LLC 5 // 6 // This source code is licensed under the BSD-style license found in the 7 // LICENSE file in the root directory of this source tree. 8 9 #pragma once 10 11 #include <stdint.h> 12 #include <stddef.h> 13 14 #include <xnnpack/common.h> 15 #include <xnnpack/operator.h> 16 17 18 #ifdef __cplusplus 19 extern "C" { 20 #endif 21 22 23 struct xnn_qu8_packing_params { 24 uint8_t input_zero_point; 25 uint8_t kernel_zero_point; 26 }; 27 28 struct xnn_qs8_packing_params { 29 int8_t input_zero_point; 30 }; 31 32 33 typedef void (*xnn_pack_gemm_goi_w_function)( 34 size_t g, 35 size_t nc, 36 size_t kc, 37 size_t nr, 38 size_t kr, 39 size_t sr, 40 const void* k, 41 const void* b, 42 void* packed_w, 43 size_t extra_bytes, 44 const void* params); 45 46 XNN_INTERNAL void xnn_pack_f32_gemm_goi_w( 47 size_t g, 48 size_t nc, 49 size_t kc, 50 size_t nr, 51 size_t kr, 52 size_t sr, 53 const float* k, 54 const float* b, 55 float* packed_w, 56 size_t extra_bytes, 57 const void* params); 58 59 XNN_INTERNAL void xnn_pack_f16_gemm_goi_w( 60 size_t g, 61 size_t nc, 62 size_t kc, 63 size_t nr, 64 size_t kr, 65 size_t sr, 66 const uint16_t* k, 67 const uint16_t* b, 68 uint16_t* packed_w, 69 size_t extra_bytes, 70 const void* params); 71 72 XNN_INTERNAL void xnn_pack_f32_to_f16_gemm_goi_w( 73 size_t g, 74 size_t nc, 75 size_t kc, 76 size_t nr, 77 size_t kr, 78 size_t sr, 79 const float* k, 80 const float* b, 81 uint16_t* packed_w, 82 size_t extra_bytes, 83 const void* params); 84 85 XNN_INTERNAL void xnn_pack_qu8_gemm_goi_w( 86 size_t g, 87 size_t nc, 88 size_t kc, 89 size_t nr, 90 size_t kr, 91 size_t sr, 92 const uint8_t* k, 93 const int32_t* b, 94 void* packed_w, 95 size_t extra_bytes, 96 const struct xnn_qu8_packing_params* params); 97 98 XNN_INTERNAL void xnn_pack_qs8_gemm_goi_w( 99 size_t g, 100 size_t nc, 101 size_t kc, 102 size_t nr, 103 size_t kr, 104 size_t sr, 105 const int8_t* k, 106 const int32_t* b, 107 void* packed_w, 108 size_t extra_bytes, 109 const struct xnn_qs8_packing_params* params); 110 111 XNN_INTERNAL void xnn_pack_qs8_gemm_xw_goi_w( 112 size_t g, 113 size_t nc, 114 size_t kc, 115 size_t nr, 116 size_t kr, 117 size_t sr, 118 const int8_t* k, 119 const int32_t* b, 120 void* packed_w, 121 size_t extra_bytes, 122 const struct xnn_qs8_packing_params* params); 123 124 125 typedef void (*xnn_pack_gemm_io_w_function)( 126 size_t nc, 127 size_t kc, 128 size_t nr, 129 size_t kr, 130 size_t sr, 131 const void* k, 132 const void* b, 133 void* packed_w, 134 const void* params); 135 136 XNN_INTERNAL void xnn_pack_f32_gemm_io_w( 137 size_t nc, 138 size_t kc, 139 size_t nr, 140 size_t kr, 141 size_t sr, 142 const float* k, 143 const float* b, 144 float* packed_w, 145 const void* params); 146 147 XNN_INTERNAL void xnn_pack_f16_gemm_io_w( 148 size_t nc, 149 size_t kc, 150 size_t nr, 151 size_t kr, 152 size_t sr, 153 const uint16_t* k, 154 const uint16_t* b, 155 uint16_t* packed_w, 156 const void* params); 157 158 XNN_INTERNAL void xnn_pack_f32_to_f16_gemm_io_w( 159 size_t nc, 160 size_t kc, 161 size_t nr, 162 size_t kr, 163 size_t sr, 164 const float* k, 165 const float* b, 166 uint16_t* packed_w, 167 const void* params); 168 169 XNN_INTERNAL void xnn_pack_qu8_gemm_io_w( 170 size_t nc, 171 size_t kc, 172 size_t nr, 173 size_t kr, 174 size_t sr, 175 const uint8_t* k, 176 const int32_t* b, 177 void* packed_w, 178 const struct xnn_qu8_packing_params* params); 179 180 XNN_INTERNAL void xnn_pack_qs8_gemm_io_w( 181 size_t nc, 182 size_t kc, 183 size_t nr, 184 size_t kr, 185 size_t sr, 186 const int8_t* k, 187 const int32_t* b, 188 void* packed_w, 189 const struct xnn_qs8_packing_params* params); 190 191 192 typedef void (*xnn_pack_conv_goki_w_function)( 193 size_t g, 194 size_t nc, 195 size_t ks, 196 size_t kc, 197 size_t nr, 198 size_t kr, 199 size_t sr, 200 const void* k, 201 const void* b, 202 void* packed_w, 203 size_t extra_bytes, 204 const void* params); 205 206 XNN_INTERNAL void xnn_pack_f32_conv_goki_w( 207 size_t g, 208 size_t nc, 209 size_t ks, 210 size_t kc, 211 size_t nr, 212 size_t kr, 213 size_t sr, 214 const float* k, 215 const float* b, 216 float* packed_w, 217 size_t extra_bytes, 218 const void* params); 219 220 XNN_INTERNAL void xnn_pack_f16_conv_goki_w( 221 size_t g, 222 size_t nc, 223 size_t ks, 224 size_t kc, 225 size_t nr, 226 size_t kr, 227 size_t sr, 228 const uint16_t* k, 229 const uint16_t* b, 230 uint16_t* packed_w, 231 size_t extra_bytes, 232 const void* params); 233 234 XNN_INTERNAL void xnn_pack_f32_to_f16_conv_goki_w( 235 size_t g, 236 size_t nc, 237 size_t ks, 238 size_t kc, 239 size_t nr, 240 size_t kr, 241 size_t sr, 242 const float* k, 243 const float* b, 244 uint16_t* packed_w, 245 size_t extra_bytes, 246 const void* params); 247 248 XNN_INTERNAL void xnn_pack_qu8_conv_goki_w( 249 size_t g, 250 size_t nc, 251 size_t ks, 252 size_t kc, 253 size_t nr, 254 size_t kr, 255 size_t sr, 256 const uint8_t* k, 257 const int32_t* b, 258 void* packed_w, 259 size_t extra_bytes, 260 const struct xnn_qu8_packing_params* params); 261 262 XNN_INTERNAL void xnn_pack_qs8_conv_goki_w( 263 size_t g, 264 size_t nc, 265 size_t ks, 266 size_t kc, 267 size_t nr, 268 size_t kr, 269 size_t sr, 270 const int8_t* k, 271 const int32_t* b, 272 void* packed_w, 273 size_t extra_bytes, 274 const struct xnn_qs8_packing_params* params); 275 276 277 typedef void (*xnn_pack_conv_kgo_w_function)( 278 size_t g, 279 size_t nc, 280 size_t ks, 281 size_t nr, 282 size_t kr, 283 size_t sr, 284 const void* k, 285 const void* b, 286 void* packed_w, 287 size_t extra_bytes, 288 const void* params); 289 290 XNN_INTERNAL void xnn_pack_f32_conv_kgo_w( 291 size_t g, 292 size_t nc, 293 size_t ks, 294 size_t nr, 295 size_t kr, 296 size_t sr, 297 const float* k, 298 const float* b, 299 float* packed_w, 300 size_t extra_bytes, 301 const void* params); 302 303 XNN_INTERNAL void xnn_pack_f16_conv_kgo_w( 304 size_t g, 305 size_t nc, 306 size_t ks, 307 size_t nr, 308 size_t kr, 309 size_t sr, 310 const uint16_t* k, 311 const uint16_t* b, 312 uint16_t* packed_w, 313 size_t extra_bytes, 314 const void* params); 315 316 XNN_INTERNAL void xnn_pack_f32_to_f16_conv_kgo_w( 317 size_t g, 318 size_t nc, 319 size_t ks, 320 size_t nr, 321 size_t kr, 322 size_t sr, 323 const float* k, 324 const float* b, 325 uint16_t* packed_w, 326 size_t extra_bytes, 327 const void* params); 328 329 XNN_INTERNAL void xnn_pack_qu8_conv_kgo_w( 330 size_t g, 331 size_t nc, 332 size_t ks, 333 size_t nr, 334 size_t kr, 335 size_t sr, 336 const uint8_t* k, 337 const int32_t* b, 338 void* packed_w, 339 size_t extra_bytes, 340 const struct xnn_qu8_packing_params* params); 341 342 XNN_INTERNAL void xnn_pack_qs8_conv_kgo_w( 343 size_t g, 344 size_t nc, 345 size_t ks, 346 size_t nr, 347 size_t kr, 348 size_t sr, 349 const int8_t* k, 350 const int32_t* b, 351 void* packed_w, 352 size_t extra_bytes, 353 const struct xnn_qs8_packing_params* params); 354 355 356 typedef void (*xnn_pack_deconv_goki_w_function)( 357 size_t g, 358 size_t nc, 359 size_t kh, 360 size_t kw, 361 size_t kc, 362 size_t sh, 363 size_t sw, 364 size_t nr, 365 size_t kr, 366 size_t sr, 367 const void* k, 368 const void* b, 369 void* packed_w, 370 struct subconvolution_params* subconv_params, 371 const void* params); 372 373 XNN_INTERNAL void xnn_pack_f32_deconv_goki_w( 374 size_t g, 375 size_t nc, 376 size_t kh, 377 size_t kw, 378 size_t kc, 379 size_t sh, 380 size_t sw, 381 size_t nr, 382 size_t kr, 383 size_t sr, 384 const float* k, 385 const float* b, 386 float* packed_w, 387 struct subconvolution_params* subconv_params, 388 const void* params); 389 390 XNN_INTERNAL void xnn_pack_f16_deconv_goki_w( 391 size_t g, 392 size_t nc, 393 size_t kh, 394 size_t kw, 395 size_t kc, 396 size_t sh, 397 size_t sw, 398 size_t nr, 399 size_t kr, 400 size_t sr, 401 const uint16_t* k, 402 const uint16_t* b, 403 uint16_t* packed_w, 404 struct subconvolution_params* subconv_params, 405 const void* params); 406 407 XNN_INTERNAL void xnn_pack_qs8_deconv_goki_w( 408 size_t g, 409 size_t nc, 410 size_t kh, 411 size_t kw, 412 size_t kc, 413 size_t sh, 414 size_t sw, 415 size_t nr, 416 size_t kr, 417 size_t sr, 418 const int8_t* k, 419 const int32_t* b, 420 void* packed_w, 421 struct subconvolution_params* subconv_params, 422 const struct xnn_qs8_packing_params* params); 423 424 XNN_INTERNAL void xnn_pack_qu8_deconv_goki_w( 425 size_t g, 426 size_t nc, 427 size_t kh, 428 size_t kw, 429 size_t kc, 430 size_t sh, 431 size_t sw, 432 size_t nr, 433 size_t kr, 434 size_t sr, 435 const uint8_t* k, 436 const int32_t* b, 437 void* packed_w, 438 struct subconvolution_params* subconv_params, 439 const struct xnn_qu8_packing_params* params); 440 441 442 typedef void (*xnn_pack_dwconv_ghw_w_function)( 443 size_t h, 444 size_t w, 445 size_t c, 446 size_t cr, 447 const void* k, 448 const void* b, 449 void* packed_w, 450 size_t extra_bytes, 451 const void* params); 452 453 XNN_INTERNAL void xnn_pack_f32_dwconv_ghw_w( 454 size_t h, 455 size_t w, 456 size_t c, 457 size_t cr, 458 const float* k, 459 const float* b, 460 float* packed_w, 461 size_t extra_bytes, 462 const void* params); 463 464 XNN_INTERNAL void xnn_pack_f16_dwconv_ghw_w( 465 size_t h, 466 size_t w, 467 size_t c, 468 size_t cr, 469 const uint16_t* k, 470 const uint16_t* b, 471 uint16_t* packed_w, 472 size_t extra_bytes, 473 const void* params); 474 475 XNN_INTERNAL void xnn_pack_f32_to_f16_dwconv_ghw_w( 476 size_t h, 477 size_t w, 478 size_t c, 479 size_t cr, 480 const float* k, 481 const float* b, 482 uint16_t* packed_w, 483 size_t extra_bytes, 484 const void* params); 485 486 XNN_INTERNAL void xnn_pack_qu8_dwconv_ghw_w( 487 size_t h, 488 size_t w, 489 size_t c, 490 size_t cr, 491 const uint8_t* k, 492 const int32_t* b, 493 void* packed_w, 494 size_t extra_bytes, 495 const struct xnn_qu8_packing_params* params); 496 497 XNN_INTERNAL void xnn_pack_qs8_dwconv_ghw_w( 498 size_t h, 499 size_t w, 500 size_t c, 501 size_t cr, 502 const int8_t* k, 503 const int32_t* b, 504 void* packed_w, 505 size_t extra_bytes, 506 const struct xnn_qs8_packing_params* params); 507 508 509 typedef void (*xnn_pack_dwconv_hwg_w_function)( 510 size_t h, 511 size_t w, 512 size_t c, 513 size_t cr, 514 const void* k, 515 const void* b, 516 void* packed_w, 517 size_t extra_bytes, 518 const void* params); 519 520 XNN_INTERNAL void xnn_pack_f32_dwconv_hwg_w( 521 size_t h, 522 size_t w, 523 size_t c, 524 size_t cr, 525 const float* k, 526 const float* b, 527 float* packed_w, 528 size_t extra_bytes, 529 const void* params); 530 531 XNN_INTERNAL void xnn_pack_f16_dwconv_hwg_w( 532 size_t h, 533 size_t w, 534 size_t c, 535 size_t cr, 536 const uint16_t* k, 537 const uint16_t* b, 538 uint16_t* packed_w, 539 size_t extra_bytes, 540 const void* params); 541 542 XNN_INTERNAL void xnn_pack_f32_to_f16_dwconv_hwg_w( 543 size_t h, 544 size_t w, 545 size_t c, 546 size_t cr, 547 const float* k, 548 const float* b, 549 uint16_t* packed_w, 550 size_t extra_bytes, 551 const void* params); 552 553 XNN_INTERNAL void xnn_pack_qu8_dwconv_hwg_w( 554 size_t h, 555 size_t w, 556 size_t c, 557 size_t cr, 558 const uint8_t* k, 559 const int32_t* b, 560 void* packed_w, 561 size_t extra_bytes, 562 const struct xnn_qu8_packing_params* params); 563 564 XNN_INTERNAL void xnn_pack_qs8_dwconv_hwg_w( 565 size_t h, 566 size_t w, 567 size_t c, 568 size_t cr, 569 const int8_t* k, 570 const int32_t* b, 571 void* packed_w, 572 size_t extra_bytes, 573 const struct xnn_qs8_packing_params* params); 574 575 576 XNN_INTERNAL void xnn_pack_f32_gemminc_goi_w( 577 size_t g, 578 size_t nc, 579 size_t kc, 580 size_t nr, 581 size_t kr, 582 size_t sr, 583 const float* k, 584 float* packed_w, 585 const void* params); 586 587 XNN_INTERNAL void xnn_pack_f16_gemminc_goi_w( 588 size_t g, 589 size_t nc, 590 size_t kc, 591 size_t nr, 592 size_t kr, 593 size_t sr, 594 const uint16_t* k, 595 uint16_t* packed_w, 596 const void* params); 597 598 599 XNN_INTERNAL void xnn_pack_f32_dconv_oki_w( 600 size_t nc, 601 size_t kc, 602 size_t nr, 603 size_t kh, 604 size_t kw, 605 const float* k, 606 const float* b, 607 float* packed_w, 608 const void* params); 609 610 XNN_INTERNAL void xnn_pack_f16_dconv_oki_w( 611 size_t nc, 612 size_t kc, 613 size_t nr, 614 size_t kh, 615 size_t kw, 616 const uint16_t* k, 617 const uint16_t* b, 618 uint16_t* packed_w, 619 const void* params); 620 621 622 XNN_INTERNAL void xnn_pack_f32_chw_dwconv_ghw_w( 623 size_t kernel_size, 624 size_t groups, 625 const float* kernel, 626 const float* bias, 627 float* packed_weights, 628 const void* params); 629 630 XNN_INTERNAL void xnn_pack_f16_chw_dwconv_ghw_w( 631 size_t kernel_size, 632 size_t groups, 633 const uint16_t* kernel, 634 const uint16_t* bias, 635 uint16_t* packed_weights, 636 const void* params); 637 638 639 XNN_INTERNAL void xnn_pack_f32_chw_dwconv_hwg_w( 640 size_t kernel_size, 641 size_t groups, 642 const float* kernel, 643 const float* bias, 644 float* packed_weights, 645 const void* params); 646 647 648 typedef void (*xnn_pack_vmulcaddc_w_function)( 649 size_t c, 650 size_t cr, 651 const void* s, 652 const void* b, 653 void* packed_w, 654 const void* params); 655 656 XNN_INTERNAL void xnn_pack_f32_vmulcaddc_w( 657 size_t c, 658 size_t cr, 659 const float* s, 660 const float* b, 661 float* packed_w, 662 const void* params); 663 664 XNN_INTERNAL void xnn_pack_f16_vmulcaddc_w( 665 size_t c, 666 size_t cr, 667 const uint16_t* s, 668 const uint16_t* b, 669 uint16_t* packed_w, 670 const void* params); 671 672 XNN_INTERNAL void xnn_pack_f32_to_f16_vmulcaddc_w( 673 size_t c, 674 size_t cr, 675 const float* s, 676 const float* b, 677 uint16_t* packed_w, 678 const void* params); 679 680 681 typedef void (*xnn_pack_prelu_w_function)( 682 size_t c, 683 const void* s, 684 void* packed_w); 685 686 XNN_INTERNAL void xnn_pack_f32_prelu_w( 687 size_t c, 688 const float* s, 689 float* packed_w); 690 691 XNN_INTERNAL void xnn_pack_f16_prelu_w( 692 size_t c, 693 const uint16_t* s, 694 uint16_t* packed_w); 695 696 XNN_INTERNAL void xnn_pack_f32_to_f16_prelu_w( 697 size_t c, 698 const float* s, 699 uint16_t* packed_w); 700 701 702 #ifdef __cplusplus 703 } // extern "C" 704 #endif 705