#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import bisect
import codecs
import collections
import os
import sys
import yaml
import zlib

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from primes import next_prime
import xngen
import xnncommon

parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
    "-s", "--spec", metavar="FILE", required=True, help="Spec (YAML) file")
parser.add_argument(
    "-o",
    "--output",
    action="append",
    metavar="FILE",
    required=True,
    help="Output (C++ source) file(s)")
parser.set_defaults(defines=list())


def split_ukernel_name(name):
  """Decodes the tile parameters encoded in a micro-kernel function name.

  The name is split at the first "__" into a common part and a target part.
  The last "_"-separated component of the common part is parsed as
  "<mr>x<nr>[c<kr>][s<sr>]"; the target part is handed to
  xnncommon.parse_target_name.

  Args:
    name: C name of the micro-kernel function.

  Returns:
    Tuple (mr, nr, kr, sr, xw, requantization, arch, isa). kr and sr default
    to 1 when the name has no "c"/"s" suffix. xw is True when the common part
    contains "gemm_xw_". requantization is the third component from the end
    of the common part when it is "fp32" or "rndnu", otherwise None.
  """
  common_name, target_name = name.split("__", 1)
  common_parts = common_name.split("_")
  xw = "gemm_xw_" in common_name
  param_spec = common_parts[-1]
  # Suffixes are peeled right-to-left: "...s<sr>" first, then "...c<kr>".
  if "s" in param_spec:
    param_spec, sr = param_spec.split("s", 1)
    sr = int(sr)
  else:
    sr = 1
  if "c" in param_spec:
    param_spec, kr = param_spec.split("c", 1)
    kr = int(kr)
  else:
    kr = 1
  mr, nr = map(int, param_spec.split("x"))
  arch, isa = xnncommon.parse_target_name(target_name)

  requantization = common_parts[-3]
  if requantization not in ["fp32", "rndnu"]:
    requantization = None

  return mr, nr, kr, sr, xw, requantization, arch, isa


GEMM_TEST_CODE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    $if EXTENDED_WEIGHTS:
      .extended_weights(true)
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .Test(${", ".join(TEST_ARGS)});
}

TEST(${TEST_NAME}, strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    $if EXTENDED_WEIGHTS:
      .extended_weights(true)
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cn_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .a_stride(${next_prime(KBLOCK + 1)})
      .Test(${", ".join(TEST_ARGS)});
  }

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n <= ${NR}; n++) {
    for (uint32_t m = 1; m <= ${MR}; m++) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(m)
        .n(n)
        .k(${KBLOCK})
        .iterations(1)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_m) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = 1; m <= ${MR}; m++) {
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(m)
      .n(${NR})
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}


TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_n) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n <= ${NR}; n++) {
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(n)
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if IS_PIPELINED:
  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK * 2})
      .Test(${", ".join(TEST_ARGS)});
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(${KBLOCK * 2})
        .a_stride(${next_prime(KBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }

  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = 1; n <= ${NR}; n++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(${KBLOCK * 2})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(ADJKBLOCK + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        for (uint32_t m = 1; m <= ${MR}; m++) {
          GemmMicrokernelTester()
            $if EXTENDED_WEIGHTS:
              .extended_weights(true)
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${ADJKBLOCK * 10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(k)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if UKERNEL_TYPE.startswith("GEMM"):
  TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_stride(${next_prime(10 if ADJKBLOCK == 1 else ADJKBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    for (uint32_t n = 1; n <= ${NR}; n++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE.startswith("GEMM"):
    TEST(${TEST_NAME}, k_div_${KBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(KBLOCK * 10 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        for (uint32_t m = 1; m <= ${MR}; m++) {
          GemmMicrokernelTester()
            $if EXTENDED_WEIGHTS:
              .extended_weights(true)
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_gt_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_gt_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_div_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_div_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, small_kernel_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        for (uint32_t m = 1; m <= ${MR}; m++) {
          GemmMicrokernelTester()
            $if EXTENDED_WEIGHTS:
              .extended_weights(true)
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .ks(3)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

  TEST(${TEST_NAME}, n_gt_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, n_div_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, strided_cm_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
    for (uint32_t n = 1; n <= ${NR}; n++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .cm_stride(${next_prime(NR + 1)})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, a_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, zero) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t mz = 0; mz < ${MR}; mz++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .ks(3)
          .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
          .zero_index(mz)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if ACTIVATION == "MINMAX":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .qmin(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .qmax(128)
      .Test(${", ".join(TEST_ARGS)});
  }

TEST(${TEST_NAME}, strided_cm) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    $if EXTENDED_WEIGHTS:
      .extended_weights(true)
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cm_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if DATATYPE == "qu8":
  TEST(${TEST_NAME}, no_a_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_b_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
"""


def generate_test_cases(ukernel, mr, nr, kr, sr, xw, k_block, init_fn,
                        requantization, is_pipelined, isa, jit):
  """Generates all test cases for a GEMM micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    mr: MR parameter of the GEMM micro-kernel.
    nr: NR parameter of the GEMM micro-kernel.
    kr: KR parameter of the GEMM micro-kernel.
    sr: SR parameter of the GEMM micro-kernel.
    xw: boolean indicator for microkernel with extended weights.
    k_block: Number of K values processed per one iteration of the main loop of
      the micro-kernel.
    init_fn: C name of the function to initialize microkernel parameters, or
      None when the spec does not provide one.
    requantization: name of the requantization scheme used by the microkernel.
    is_pipelined: Indicates if the micro-kernel is implemented with software
      pipelining. Additional test cases are generated for software pipelined
      micro-kernels to separately test prologue + epilogue of the pipelined loop
      and iteration of the pipelined loop.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.
    jit: if we are generating test code for JIT codegen.

  Returns:
    Code for the test case.
  """
  _, ukernel_name = ukernel.split("_", 1)

  if jit:
    # JIT names have one extra leading component before the datatype and no
    # activation component; the activation is inferred from init_fn instead.
    # Defaulting to "linear" avoids crashing on a missing or non-minmax
    # init_fn (previously activation stayed None and .upper() failed).
    _, _, datatype, ukernel_type, _ = ukernel.split("_", 4)
    activation = "minmax" if init_fn and "minmax" in init_fn else "linear"
  else:
    _, datatype, ukernel_type, activation, _ = ukernel.split("_", 4)
    # A name without an explicit activation has "ukernel" in that position.
    if activation == "ukernel":
      activation = "linear"

  test_args = [ukernel]
  if init_fn:
    test_args.append(init_fn)
    if requantization:
      # qc8 kernels are paired with the qs8 requantization helpers.
      requantization_datatype = {"qc8": "qs8"}.get(datatype, datatype)
      test_args.append("xnn_%s_requantize_%s" % \
        (requantization_datatype, requantization))

  return xngen.preprocess(
      GEMM_TEST_CODE, {
          "TEST_NAME": ukernel_name.upper().replace("UKERNEL_", ""),
          "TEST_ARGS": test_args,
          "UKERNEL_TYPE": ukernel_type.upper(),
          "DATATYPE": datatype,
          "ACTIVATION": activation.upper(),
          "MR": mr,
          "NR": nr,
          "KR": kr,
          "SR": sr,
          "EXTENDED_WEIGHTS": xw,
          "KBLOCK": k_block,
          # Pipelined kernels handle two K blocks per prologue + epilogue.
          "ADJKBLOCK": 2 * k_block if is_pipelined else k_block,
          "IS_PIPELINED": is_pipelined,
          "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
          "next_prime": next_prime,
      })


def main(args):
  """Generates the GEMM unit-test source file(s) described by a YAML spec.

  Each micro-kernel in the spec is rendered into test code and appended to one
  of the output files, chosen by a CRC32 hash of the kernel name. An output
  file is rewritten only when its content would change.

  Args:
    args: command-line arguments, without the program name.

  Raises:
    ValueError: if the spec's top-level YAML value is not a list.
  """
  options = parser.parse_args(args)
  num_output_files = len(options.output)

  # The spec file only needs to stay open while it is being parsed.
  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  tests = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/allocator.h>
#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/gemm.h>
#include <xnnpack/igemm.h>
#include <xnnpack/ppmm.h>
#include "gemm-microkernel-tester.h"
""".format(
    specification=options.spec, generator=sys.argv[0])

  # Every output file starts from the common header, including files that end
  # up with no micro-kernels assigned to them.
  outputs = collections.defaultdict(lambda: tests)

  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    k_block = int(ukernel_spec["k-block"])
    init_fn = ukernel_spec.get("init")
    pipelined = bool(ukernel_spec.get("pipelined", False))
    assembly = bool(ukernel_spec.get("assembly", False))
    jit = name.startswith("xnn_generate")
    mr, nr, kr, sr, xw, requantization, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, mr, nr, kr, sr, xw, k_block,
                                    init_fn, requantization, pipelined, isa,
                                    jit)

    # Hash the name of each microkernel and figure out which output file to
    # write it to.
    output_index = zlib.crc32(bytes(name, 'utf-8')) % num_output_files
    outputs[options.output[output_index]] += "\n\n" + xnncommon.postprocess_test_case(
        test_case, arch, isa, assembly, jit)

  for output_name in options.output:
    # Skip the write when the file already holds identical text.
    txt_changed = True
    if os.path.exists(output_name):
      with codecs.open(output_name, "r", encoding="utf-8") as output_file:
        txt_changed = output_file.read() != outputs[output_name]

    if txt_changed:
      with codecs.open(output_name, "w", encoding="utf-8") as output_file:
        output_file.write(outputs[output_name])


if __name__ == "__main__":
  main(sys.argv[1:])