#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import bisect
import codecs
import os
import sys
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from primes import next_prime
import xngen
import xnncommon


parser = argparse.ArgumentParser(description='XNNPACK generator')
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Spec (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help='Output (C++ source) file')
parser.set_defaults(defines=list())


def split_ukernel_name(name):
  """Extracts tiling parameters and the target from a GEMM micro-kernel name.

  The part of the name before the first "__" must end (after its last "_")
  with a parameter spec of the form "<MR>x<NR>[c<KR>][s<SR>]", for example
  "4x8" or "4x8c2s4". The part after the first "__" is decoded by
  xnncommon.parse_target_name.

  Args:
    name: micro-kernel name, e.g. "xnn_f32_gemm_minmax_ukernel_4x8__neon".

  Returns:
    Tuple (mr, nr, kr, sr, arch, isa). kr and sr default to 1 when the
    corresponding "c"/"s" suffix is absent.
  """
  common_name, target_name = name.split("__", 1)
  common_parts = common_name.split("_")
  param_spec = common_parts[-1]
  # "s<SR>" is the trailing suffix, so it is peeled off before "c<KR>".
  if "s" in param_spec:
    param_spec, sr = param_spec.split("s", 1)
    sr = int(sr)
  else:
    sr = 1
  if "c" in param_spec:
    param_spec, kr = param_spec.split("c", 1)
    kr = int(kr)
  else:
    kr = 1
  mr, nr = map(int, param_spec.split("x"))
  arch, isa = xnncommon.parse_target_name(target_name)
  return mr, nr, kr, sr, arch, isa


GEMM_TEST_CODE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .Test(${", ".join(TEST_ARGS)});
}

TEST(${TEST_NAME}, strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cn_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .a_stride(${next_prime(KBLOCK + 1)})
      .Test(${", ".join(TEST_ARGS)});
  }

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = 1; m <= ${MR}; m++) {
    for (uint32_t n = 1; n <= ${NR}; n++) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(m)
        .n(n)
        .k(${KBLOCK})
        .iterations(1)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_m) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = 1; m <= ${MR}; m++) {
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(m)
      .n(${NR})
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_n) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n <= ${NR}; n++) {
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(n)
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if IS_PIPELINED:
  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK * 2})
      .Test(${", ".join(TEST_ARGS)});
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(${KBLOCK * 2})
        .a_stride(${next_prime(KBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }

  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t m = 1; m <= ${MR}; m++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(${KBLOCK * 2})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(ADJKBLOCK + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          GemmMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${K_GT_LIMIT}; k++) {
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(k)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if UKERNEL_TYPE.startswith("GEMM"):
  TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + 1}; k < ${K_GT_LIMIT}; k++) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_stride(${next_prime(K_GT_LIMIT + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${K_GT_LIMIT}; k++) {
    for (uint32_t m = 1; m <= ${MR}; m++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE.startswith("GEMM"):
    TEST(${TEST_NAME}, k_div_${KBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(KBLOCK * 10 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          GemmMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_gt_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_gt_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_div_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_div_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, small_kernel_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          GemmMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .ks(3)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

  TEST(${TEST_NAME}, n_gt_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, n_div_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, strided_cm_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
    for (uint32_t m = 1; m <= ${MR}; m++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .cm_stride(${next_prime(NR + 1)})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, a_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, zero) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t mz = 0; mz < ${MR}; mz++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .ks(3)
          .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
          .zero_index(mz)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if ACTIVATION == "MINMAX":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .qmin(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .qmax(128)
      .Test(${", ".join(TEST_ARGS)});
  }

TEST(${TEST_NAME}, strided_cm) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cm_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if DATATYPE == "qu8":
  TEST(${TEST_NAME}, no_a_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_b_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
"""


def generate_test_cases(ukernel, mr, nr, kr, sr,
                        k_block, is_pipelined, isa):
  """Generates all tests cases for a GEMM micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    mr: MR parameter of the GEMM micro-kernel.
    nr: NR parameter of the GEMM micro-kernel.
    kr: KR parameter of the GEMM micro-kernel.
    sr: SR parameter of the GEMM micro-kernel.
    k_block: Number of K values processed per one iteration of the main loop of
             the micro-kernel.
    is_pipelined: Indicates if the micro-kernel is implemented with software
                  pipelining. Additional test cases are generated for software
                  pipelined micro-kernels to separately test prologue + epilogue
                  of the pipelined loop and iteration of the pipelined loop.
    isa: instruction set required to run the micro-kernel. Generated unit test
         will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.
  """
  _, test_name = ukernel.split("_", 1)
  _, datatype, ukernel_type, activation, _ = ukernel.split("_", 4)
  if activation == "ukernel":
    activation = "linear"
  test_args = [ukernel]
  if activation not in ["linear", "relu"] and not isa:
    test_args.append("GemmMicrokernelTester::Variant::Scalar")
  # For software-pipelined kernels the k_lt/k_gt boundary tests are measured
  # against two K blocks instead of one.
  adj_k_block = 2 * k_block if is_pipelined else k_block
  # Exclusive upper bound of the k_gt_* loops, which start at adj_k_block + 1.
  # Deriving it from adj_k_block (not k_block) keeps those loops non-empty for
  # pipelined kernels; for non-pipelined kernels it equals the old bound.
  k_gt_limit = 10 if k_block == 1 else 2 * adj_k_block
  return xngen.preprocess(GEMM_TEST_CODE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "UKERNEL_TYPE": ukernel_type.upper(),
      "DATATYPE": datatype,
      "ACTIVATION": activation.upper(),
      "MR": mr,
      "NR": nr,
      "KR": kr,
      "SR": sr,
      "KBLOCK": k_block,
      "ADJKBLOCK": adj_k_block,
      "K_GT_LIMIT": k_gt_limit,
      "IS_PIPELINED": is_pipelined,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
  })


def main(args):
  """Reads a YAML micro-kernel spec and writes the generated C++ unit tests.

  Args:
    args: command-line arguments, excluding the program name.

  Raises:
    ValueError: if the spec file does not contain a list of micro-kernels.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/gemm.h>
#include <xnnpack/igemm.h>
#include <xnnpack/ppmm.h>
#include "gemm-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  test_cases = []
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    k_block = int(ukernel_spec["k-block"])
    pipelined = bool(ukernel_spec.get("pipelined", False))
    assembly = bool(ukernel_spec.get("assembly", False))
    mr, nr, kr, sr, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(
        name, mr, nr, kr, sr, k_block, pipelined, isa)
    test_cases.append(
        xnncommon.postprocess_test_case(test_case, arch, isa, assembly))

  # Every test case is preceded by two newlines, after the header.
  tests = header + "".join("\n\n" + case for case in test_cases)

  with codecs.open(options.output, "w", encoding="utf-8") as output_file:
    output_file.write(tests)


if __name__ == "__main__":
  main(sys.argv[1:])