• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2# Copyright 2019 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import bisect
9import codecs
10import os
11import sys
12import yaml
13
14sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15from primes import next_prime
16import xngen
17import xnncommon
18
19
# Command-line interface of the generator: the YAML spec to read and the C++
# file to write are both mandatory; `defines` defaults to an empty list.
parser = argparse.ArgumentParser(description='XNNPACK generator')
parser.add_argument(
    "-s", "--spec",
    metavar="FILE",
    required=True,
    help="Spec (YAML) file")
parser.add_argument(
    "-o", "--output",
    metavar="FILE",
    required=True,
    help='Output (C++ source) file')
parser.set_defaults(defines=[])
26
27
def split_ukernel_name(name):
  """Decodes tile parameters and the target from a micro-kernel name.

  The name is cut at the first "__". The text after it is handed to
  xnncommon.parse_target_name; the last "_"-separated token of the text
  before it is read as "<MR>x<NR>[c<KR>][s<SR>]".

  Args:
    name: micro-kernel function name containing a "__" separator.

  Returns:
    Tuple (mr, nr, kr, sr, arch, isa). kr and sr default to 1 when the name
    has no "c" / "s" suffix; arch and isa are the values returned by
    xnncommon.parse_target_name.
  """
  prefix, target = name.split("__", 1)
  tile = prefix.rsplit("_", 1)[-1]

  # Peel the optional suffixes off right-to-left: "s<SR>" first, then "c<KR>".
  tile, s_mark, s_digits = tile.partition("s")
  sr = int(s_digits) if s_mark else 1
  tile, c_mark, c_digits = tile.partition("c")
  kr = int(c_digits) if c_mark else 1

  # What remains must be exactly "<MR>x<NR>".
  mr, nr = (int(dim) for dim in tile.split("x"))
  arch, isa = xnncommon.parse_target_name(target)
  return mr, nr, kr, sr, arch, isa
45
46
# Template for the googletest cases emitted for one GEMM-family micro-kernel.
# generate_test_cases() passes it to xngen.preprocess() together with the
# variables referenced below: TEST_NAME, TEST_ARGS, UKERNEL_TYPE, DATATYPE,
# MR, NR, KR, SR, KBLOCK, ADJKBLOCK, IS_PIPELINED, ISA_CHECK and next_prime.
# ADJKBLOCK is 2 * KBLOCK for pipelined kernels and KBLOCK otherwise, so the
# k_lt / k_gt families are expressed in terms of ADJKBLOCK.
# Every test that loops over `n` must pass `.n(n)` to the tester; otherwise
# the loop variable is unused and the n > NR paths are never exercised.
GEMM_TEST_CODE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .Test(${", ".join(TEST_ARGS)});
}

TEST(${TEST_NAME}, strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cn_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .a_stride(${next_prime(KBLOCK + 1)})
      .Test(${", ".join(TEST_ARGS)});
  }

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = 1; m <= ${MR}; m++) {
    for (uint32_t n = 1; n <= ${NR}; n++) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(m)
        .n(n)
        .k(${KBLOCK})
        .iterations(1)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_m) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = 1; m <= ${MR}; m++) {
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(m)
      .n(${NR})
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}


TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_n) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n <= ${NR}; n++) {
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(n)
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if IS_PIPELINED:
  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK * 2})
      .Test(${", ".join(TEST_ARGS)});
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(${KBLOCK * 2})
        .a_stride(${next_prime(KBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }

  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t m = 1; m <= ${MR}; m++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(${KBLOCK * 2})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(ADJKBLOCK + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          GemmMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(k)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if UKERNEL_TYPE.startswith("GEMM"):
  TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_stride(${next_prime(10 if ADJKBLOCK == 1 else ADJKBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    for (uint32_t m = 1; m <= ${MR}; m++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE.startswith("GEMM"):
    TEST(${TEST_NAME}, k_div_${KBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(KBLOCK * 10 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          GemmMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_gt_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_gt_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_div_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_div_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, small_kernel_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          GemmMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .ks(3)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

  TEST(${TEST_NAME}, n_gt_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, n_div_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, strided_cm_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
    for (uint32_t m = 1; m <= ${MR}; m++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .cm_stride(${next_prime(NR + 1)})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, a_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, zero) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t mz = 0; mz < ${MR}; mz++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .ks(3)
          .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
          .zero_index(mz)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, qmin) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .qmin(128)
    .Test(${", ".join(TEST_ARGS)});
}

TEST(${TEST_NAME}, qmax) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .qmax(128)
    .Test(${", ".join(TEST_ARGS)});
}

TEST(${TEST_NAME}, strided_cm) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cm_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if DATATYPE == "q8":
  TEST(${TEST_NAME}, no_a_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_b_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
"""
758
759
def generate_test_cases(ukernel, mr, nr, kr, sr,
                        k_block, is_pipelined, isa):
  """Renders the GEMM_TEST_CODE template for one GEMM micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function. It must consist of at least
             four "_"-separated fields; the second is taken as the datatype
             and the third as the micro-kernel type.
    mr: MR parameter of the GEMM micro-kernel.
    nr: NR parameter of the GEMM micro-kernel.
    kr: KR parameter of the GEMM micro-kernel.
    sr: SR parameter of the GEMM micro-kernel.
    k_block: Number of K values processed per one iteration of the main loop
             of the micro-kernel.
    is_pipelined: Indicates if the micro-kernel is implemented with software
                  pipelining. Additional test cases are generated for such
                  micro-kernels to separately test the prologue + epilogue
                  of the pipelined loop and an iteration of the pipelined
                  loop.
    isa: instruction set required to run the micro-kernel. Generated unit
         tests will skip execution if the host processor doesn't support it.

  Returns:
    Code for the test cases, as produced by xngen.preprocess.
  """
  # Everything after the first "_" becomes the basis of the test-suite name.
  _, name_tail = ukernel.split("_", 1)
  _, datatype, kernel_kind, _ = ukernel.split("_", 3)

  # Kernels without an ISA requirement, and psimd kernels, are run through
  # the tester's Scalar variant.
  uses_scalar_variant = isa == "psimd" or not isa
  test_args = [ukernel]
  if uses_scalar_variant:
    test_args.append("GemmMicrokernelTester::Variant::Scalar")

  # Pipelined kernels use a doubled K block as the k_lt / k_gt threshold.
  adjusted_k_block = k_block
  if is_pipelined:
    adjusted_k_block = 2 * k_block

  template_vars = {
      "TEST_NAME": name_tail.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "UKERNEL_TYPE": kernel_kind.upper(),
      "DATATYPE": datatype,
      "MR": mr,
      "NR": nr,
      "KR": kr,
      "SR": sr,
      "KBLOCK": k_block,
      "ADJKBLOCK": adjusted_k_block,
      "IS_PIPELINED": is_pipelined,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
  }
  return xngen.preprocess(GEMM_TEST_CODE, template_vars)
802
803
def main(args):
  """Generates the C++ unit-test file described by a YAML spec.

  Reads the list of micro-kernel entries from the spec file, renders the test
  cases for each entry, and writes the license/include preamble followed by
  all rendered test cases to the output file.

  Args:
    args: command-line arguments without the program name (see `parser`).

  Raises:
    ValueError: if the top-level value of the YAML spec is not a list.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    kernel_specs = yaml.safe_load(spec_file)
    if not isinstance(kernel_specs, list):
      raise ValueError("expected a list of micro-kernels in the spec")

    # Preamble of the generated file: license, provenance and includes.
    preamble = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/gemm.h>
#include <xnnpack/igemm.h>
#include <xnnpack/ppmm.h>
#include "gemm-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

    # Collect the pieces and join them once at the end.
    pieces = [preamble]
    for entry in kernel_specs:
      name = entry["name"]
      k_block = int(entry["k-block"])
      pipelined = bool(entry.get("pipelined", False))
      assembly = bool(entry.get("assembly", False))
      mr, nr, kr, sr, arch, isa = split_ukernel_name(name)

      # An explicit "arch" in the spec wins over the one parsed from the name.
      arch = entry.get("arch", arch)

      rendered = generate_test_cases(
        name, mr, nr, kr, sr, k_block, pipelined, isa)
      pieces.append(
        "\n\n" + xnncommon.postprocess_test_case(rendered, arch, isa, assembly))

    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write("".join(pieces))
853
854
# Script entry point: run the generator with the command-line arguments,
# excluding the program name.
if __name__ == "__main__":
  main(sys.argv[1:])
857