• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2# Copyright 2019 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import bisect
9import codecs
10import collections
11import os
12import sys
13import yaml
14import zlib
15
16sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
17from primes import next_prime
18import xngen
19import xnncommon
20
# Command-line interface of the generator: one spec file in, one or more
# generated source files out ("-o" may be repeated; values accumulate in a list).
parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
    "-s",
    "--spec",
    required=True,
    metavar="FILE",
    help="Spec (YAML) file")
parser.add_argument(
    "-o", "--output",
    required=True,
    action="append",
    metavar="FILE",
    help="Output (C++ source) file(s)")
parser.set_defaults(defines=[])
32
33
def split_ukernel_name(name):
  """Decodes the tile parameters embedded in a micro-kernel name.

  The name is cut at the first "__": the left part is split on "_" and its
  last field is read as "<mr>x<nr>[c<kr>][s<sr>]"; the right part is handed
  to xnncommon.parse_target_name().

  Args:
    name: micro-kernel name containing a "__" separator.

  Returns:
    Tuple (mr, nr, kr, sr, xw, requantization, arch, isa). kr and sr default
    to 1 when the corresponding suffix is absent; xw is True when the left
    part contains "gemm_xw_"; requantization is the third-from-last field of
    the left part if it is "fp32" or "rndnu", otherwise None.
  """
  kernel_part, target_part = name.split("__", 1)
  fields = kernel_part.split("_")
  xw = "gemm_xw_" in kernel_part

  # Peel the optional suffixes off the tile spec from the right: "s<sr>"
  # first, then "c<kr>"; what remains is "<mr>x<nr>".
  tile, has_sr, sr_text = fields[-1].partition("s")
  sr = int(sr_text) if has_sr else 1
  tile, has_kr, kr_text = tile.partition("c")
  kr = int(kr_text) if has_kr else 1
  mr, nr = (int(dim) for dim in tile.split("x"))

  arch, isa = xnncommon.parse_target_name(target_part)

  candidate = fields[-3]
  requantization = candidate if candidate in ("fp32", "rndnu") else None

  return mr, nr, kr, sr, xw, requantization, arch, isa
57
58
# Template for the C++ unit tests emitted for one GEMM/IGEMM micro-kernel.
#
# generate_test_cases() passes this text to xngen.preprocess() together with a
# dictionary of values. The template refers to the following names:
#   TEST_NAME, TEST_ARGS, UKERNEL_TYPE, DATATYPE, ACTIVATION, MR, NR, KR, SR,
#   EXTENDED_WEIGHTS, KBLOCK, ADJKBLOCK, IS_PIPELINED, ISA_CHECK, next_prime.
#
# NOTE(review): the `$if` lines and `${...}` spans look like xngen directives
# evaluated as Python expressions, with nesting driven by indentation --
# confirm against xngen before re-indenting anything in this string.
GEMM_TEST_CODE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    $if EXTENDED_WEIGHTS:
      .extended_weights(true)
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .Test(${", ".join(TEST_ARGS)});
}

TEST(${TEST_NAME}, strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    $if EXTENDED_WEIGHTS:
      .extended_weights(true)
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cn_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .a_stride(${next_prime(KBLOCK + 1)})
      .Test(${", ".join(TEST_ARGS)});
  }

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n <= ${NR}; n++) {
    for (uint32_t m = 1; m <= ${MR}; m++) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(m)
        .n(n)
        .k(${KBLOCK})
        .iterations(1)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_m) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = 1; m <= ${MR}; m++) {
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(m)
      .n(${NR})
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}


TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_n) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n <= ${NR}; n++) {
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(n)
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if IS_PIPELINED:
  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK * 2})
      .Test(${", ".join(TEST_ARGS)});
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(${KBLOCK * 2})
        .a_stride(${next_prime(KBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }

  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = 1; n <= ${NR}; n++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(${KBLOCK * 2})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(ADJKBLOCK + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        for (uint32_t m = 1; m <= ${MR}; m++) {
          GemmMicrokernelTester()
            $if EXTENDED_WEIGHTS:
              .extended_weights(true)
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${ADJKBLOCK * 10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(k)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if UKERNEL_TYPE.startswith("GEMM"):
  TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_stride(${next_prime(10 if ADJKBLOCK == 1 else ADJKBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    for (uint32_t n = 1; n <= ${NR}; n++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE.startswith("GEMM"):
    TEST(${TEST_NAME}, k_div_${KBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(KBLOCK * 10 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        for (uint32_t m = 1; m <= ${MR}; m++) {
          GemmMicrokernelTester()
            $if EXTENDED_WEIGHTS:
              .extended_weights(true)
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_gt_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_gt_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_div_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_div_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, small_kernel_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        for (uint32_t m = 1; m <= ${MR}; m++) {
          GemmMicrokernelTester()
            $if EXTENDED_WEIGHTS:
              .extended_weights(true)
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .ks(3)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

  TEST(${TEST_NAME}, n_gt_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, n_div_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, strided_cm_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
    for (uint32_t n = 1; n <= ${NR}; n++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .cm_stride(${next_prime(NR + 1)})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, a_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, zero) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t mz = 0; mz < ${MR}; mz++) {
        GemmMicrokernelTester()
          $if EXTENDED_WEIGHTS:
            .extended_weights(true)
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .ks(3)
          .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
          .zero_index(mz)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if ACTIVATION == "MINMAX":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .qmin(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      $if EXTENDED_WEIGHTS:
        .extended_weights(true)
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .qmax(128)
      .Test(${", ".join(TEST_ARGS)});
  }

TEST(${TEST_NAME}, strided_cm) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    $if EXTENDED_WEIGHTS:
      .extended_weights(true)
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cm_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if DATATYPE == "qu8":
  TEST(${TEST_NAME}, no_a_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_b_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        $if EXTENDED_WEIGHTS:
          .extended_weights(true)
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
"""
849
850
def generate_test_cases(ukernel, mr, nr, kr, sr, xw, k_block, init_fn,
                        requantization, is_pipelined, isa, jit):
  """Generates all test cases for a GEMM micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    mr: MR parameter of the GEMM micro-kernel.
    nr: NR parameter of the GEMM micro-kernel.
    kr: KR parameter of the GEMM micro-kernel.
    sr: SR parameter of the GEMM micro-kernel.
    xw: boolean indicator for microkernel with extended weights.
    k_block: Number of K values processed per one iteration of the main loop of
      the micro-kernel.
    init_fn: C name of the function to initialize microkernel parameters, or
      None when the micro-kernel has no such function.
    requantization: name of the requantization scheme used by the microkernel.
    is_pipelined: Indicates if the micro-kernel is implemented with software
      pipelining. Additional test cases are generated for software pipelined
      micro-kernels to separately test prologue + epilogue of the pipelined loop
      and iteration of the pipelined loop.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.
    jit: if we are generating test code for JIT codegen.

  Returns:
    Code for the test case.
  """
  _, ukernel_name = ukernel.split("_", 1)

  if jit:
    # For JIT names the first two "_"-separated fields are skipped and the
    # activation is not taken from the name; it is inferred from init_fn.
    _, _, datatype, ukernel_type, _ = ukernel.split("_", 4)
    # init_fn may be None or may not mention "minmax". Fall back to "linear"
    # (as the non-JIT path does) instead of leaving activation as None, which
    # used to fail on `"minmax" in init_fn` or on `activation.upper()` below.
    activation = "minmax" if init_fn and "minmax" in init_fn else "linear"
  else:
    _, datatype, ukernel_type, activation, _ = ukernel.split("_", 4)
    # A name with no activation field has "ukernel" in that position.
    if activation == "ukernel":
      activation = "linear"

  # Arguments forwarded to GemmMicrokernelTester::Test() in the template.
  test_args = [ukernel]
  if init_fn:
    test_args.append(init_fn)
    if requantization:
      # qc8 kernels use the qs8 requantization functions.
      requantization_datatype = {"qc8": "qs8"}.get(datatype, datatype)
      test_args.append("xnn_%s_requantize_%s" % \
        (requantization_datatype, requantization))

  return xngen.preprocess(
      GEMM_TEST_CODE, {
          "TEST_NAME": ukernel_name.upper().replace("UKERNEL_", ""),
          "TEST_ARGS": test_args,
          "UKERNEL_TYPE": ukernel_type.upper(),
          "DATATYPE": datatype,
          "ACTIVATION": activation.upper(),
          "MR": mr,
          "NR": nr,
          "KR": kr,
          "SR": sr,
          "EXTENDED_WEIGHTS": xw,
          "KBLOCK": k_block,
          # Pipelined kernels need two K blocks before the main loop runs.
          "ADJKBLOCK": 2 * k_block if is_pipelined else k_block,
          "IS_PIPELINED": is_pipelined,
          "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
          "next_prime": next_prime,
      })
917
918
def main(args):
  """Generates the GEMM unit-test source files described by a YAML spec.

  Every micro-kernel listed in the spec is rendered to a test case and
  appended to one of the output files, chosen by a CRC32 hash of the kernel
  name. An output file is rewritten only if its contents would change.

  Args:
    args: command-line arguments, without the program name.

  Raises:
    ValueError: if the top-level YAML value in the spec is not a list.
  """
  options = parser.parse_args(args)
  output_paths = options.output
  num_output_files = len(output_paths)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
    if not isinstance(spec_yaml, list):
      raise ValueError("expected a list of micro-kernels in the spec")

    header = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/allocator.h>
#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/gemm.h>
#include <xnnpack/igemm.h>
#include <xnnpack/ppmm.h>
#include "gemm-microkernel-tester.h"
""".format(generator=sys.argv[0], specification=options.spec)

    # Every output starts from the common header, including outputs that end
    # up with no micro-kernel assigned to them.
    outputs = {path: header for path in output_paths}

    for ukernel_spec in spec_yaml:
      name = ukernel_spec["name"]
      k_block = int(ukernel_spec["k-block"])
      init_fn = ukernel_spec.get("init")
      pipelined = bool(ukernel_spec.get("pipelined", False))
      assembly = bool(ukernel_spec.get("assembly", False))
      jit = name.startswith("xnn_generate")
      mr, nr, kr, sr, xw, requantization, arch, isa = split_ukernel_name(name)

      # specification can override architecture
      arch = ukernel_spec.get("arch", arch)

      test_case = generate_test_cases(
          name, mr, nr, kr, sr, xw, k_block, init_fn, requantization,
          pipelined, isa, jit)

      # Hash the name of each microkernel and figure out which output file to
      # write it to.
      bucket = output_paths[
          zlib.crc32(name.encode("utf-8")) % num_output_files]
      outputs[bucket] = outputs[bucket] + "\n\n" + \
          xnncommon.postprocess_test_case(test_case, arch, isa, assembly, jit)

    for output_name in output_paths:
      new_text = outputs[output_name]

      # Leave an up-to-date file untouched.
      if os.path.exists(output_name):
        with codecs.open(output_name, "r", encoding="utf-8") as existing_file:
          if existing_file.read() == new_text:
            continue

      with codecs.open(output_name, "w", encoding="utf-8") as output_file:
        output_file.write(new_text)
988
989
# Script entry point: forward the command-line arguments (without the program
# name) to main() when this file is executed directly.
if __name__ == "__main__":
  main(sys.argv[1:])
992