• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2# Copyright 2019 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import bisect
9import codecs
10import os
11import sys
12import yaml
13
14sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15from primes import next_prime
16import xngen
17import xnncommon
18
19
# Command-line interface of the generator: one YAML spec in, one C++ test
# source file out. Both flags are mandatory.
parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
    "-s", "--spec",
    required=True,
    metavar="FILE",
    help="Spec (YAML) file")
parser.add_argument(
    "-o", "--output",
    required=True,
    metavar="FILE",
    help="Output (C++ source) file")
# None of the options above set `defines`; it defaults to an empty list on the
# parsed namespace.
parser.set_defaults(defines=[])
26
27
def split_ukernel_name(name):
  """Extracts the tile sizes and the target from a micro-kernel name.

  The part of `name` before the first double underscore must end with an
  `<MR>x<NR>` token (after the last single underscore). The part after the
  double underscore is handed to xnncommon.parse_target_name.

  Args:
    name: C name of the micro-kernel.

  Returns:
    Tuple (mr, nr, arch, isa): mr and nr as ints, arch and isa exactly as
    returned by xnncommon.parse_target_name.

  Raises:
    ValueError: if the name has no double underscore or the tile token is not
      two integers separated by "x".
  """
  head, target_name = name.split("__", 1)
  tile_token = head.rsplit("_", 1)[-1]
  # Converted lazily while unpacking, like map(int, ...), so a malformed token
  # fails here before the target name is parsed.
  mr, nr = (int(dim) for dim in tile_token.split("x"))
  arch, isa = xnncommon.parse_target_name(target_name)
  return mr, nr, arch, isa
35
36
# Test-case template expanded by xngen.preprocess() once per micro-kernel.
# Variables substituted by generate_test_cases():
#   TEST_NAME    - gtest suite name;
#   TEST_ARGS    - arguments passed to SpMMMicrokernelTester::Test();
#   ISA_CHECK    - statement guarding the test on the required ISA (may be
#                  empty, in which case it is omitted);
#   MR, NR       - tile sizes parsed from the micro-kernel name;
#   KBLOCK       - K values processed per main-loop iteration;
#   ADJKBLOCK    - KBLOCK, doubled when IS_PIPELINED;
#   IS_PIPELINED - whether the kernel is software-pipelined.
# `$if` lines are xngen conditionals.
#
# The k_gt_* tests sweep K over [ADJKBLOCK + 1, upper) with upper = 10 when
# KBLOCK == 1 and 2 * ADJKBLOCK otherwise. Using ADJKBLOCK (not KBLOCK) for
# the upper bound keeps the range non-empty for pipelined kernels, where
# ADJKBLOCK == 2 * KBLOCK.
TEST_TEMPLATE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  SpMMMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .sparsity(0.0f)
    .Test(${", ".join(TEST_ARGS)});
}

$if NR > 1:
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = 1; n <= ${NR}; n++) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(n)
        .k(${KBLOCK})
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

$if IS_PIPELINED:
  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    SpMMMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK * 2})
      .sparsity(0.0f)
      .Test(${", ".join(TEST_ARGS)});
  }

  $if NR > 1:
    TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (uint32_t n = 1; n <= ${NR}; n++) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(${MR})
          .n(n)
          .k(${KBLOCK * 2})
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if NR > 1:
    TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          SpMMMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .m(${MR})
            .n(n)
            .k(k)
            .sparsity(0.0f)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if KBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    SpMMMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .m(${MR})
      .n(${NR})
      .k(k)
      .sparsity(0.0f)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if NR > 1:
  TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if KBLOCK == 1 else ADJKBLOCK * 2}; k++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(${MR})
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if NR > 1:
    TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          SpMMMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .m(${MR})
            .n(n)
            .k(k)
            .sparsity(0.0f)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

TEST(${TEST_NAME}, n_gt_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${max(10, NR * 2)}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(n)
        .k(k)
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if NR > 1:
  TEST(${TEST_NAME}, n_div_${NR}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(${MR})
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, m_lt_${MR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = ${1}; m < ${MR}; m++) {
    for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(m)
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, m_div_${MR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = ${MR * 2}; m <= ${MR * 3}; m += ${MR}) {
    for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(m)
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, m_gt_${MR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = ${MR + 1}; m < ${MR * 2}; m++) {
    for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(m)
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, qmin) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(0.0f)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, qmax) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(0.0f)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, half_sparse) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(0.5f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, zero_weights) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(1.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}
"""
359
360
def generate_test_cases(ukernel, mr, nr, k_block, is_pipelined, isa):
  """Generates all test cases for an SpMM micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    mr: MR tile parameter of the SpMM micro-kernel.
    nr: NR tile parameter of the SpMM micro-kernel.
    k_block: Number of K values processed per one iteration of the main loop of
             the micro-kernel.
    is_pipelined: Indicates if the micro-kernel is implemented with software
                  pipelining. Additional test cases are generated for software
                  pipelined micro-kernels to separately test prologue + epilogue
                  of the pipelined loop and iteration of the pipelined loop.
    isa: instruction set required to run the micro-kernel. Generated unit test
         will skip execution if the host processor doesn't support this ISA.
         Micro-kernels with no ISA, or with the "psimd" ISA, additionally pass
         SpMMMicrokernelTester::Variant::Scalar to the tester.

  Returns:
    C++ code for the test cases, produced by expanding TEST_TEMPLATE with
    xngen.preprocess().
  """
  # Drop the first underscore-separated field; the remainder names the suite.
  _, test_name = ukernel.split("_", 1)
  # The second and third underscore-separated fields are forwarded to the
  # template as DATATYPE and UKERNEL_TYPE.
  _, datatype, ukernel_type, _ = ukernel.split("_", 3)
  test_args = [ukernel]
  if not isa or isa == "psimd":
    test_args.append("SpMMMicrokernelTester::Variant::Scalar")
  return xngen.preprocess(TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "UKERNEL_TYPE": ukernel_type.upper(),
      "DATATYPE": datatype,
      "MR": mr,
      "NR": nr,
      "KBLOCK": k_block,
      # The template's K-boundary tests (k_lt/k_gt/k_div) are anchored on
      # ADJKBLOCK, so pipelined kernels get them shifted up by one K block.
      "ADJKBLOCK": 2 * k_block if is_pipelined else k_block,
      "IS_PIPELINED": is_pipelined,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      # Exposed to the template for expressions; not referenced by the
      # current TEST_TEMPLATE.
      "next_prime": next_prime,
    })
398
399
def main(args):
  """Writes the C++ unit-test source for every micro-kernel in a YAML spec.

  Args:
    args: command-line arguments, excluding the program name (see `parser`).

  Raises:
    ValueError: if the spec file does not contain a YAML list.
  """
  options = parser.parse_args(args)

  # Load the whole spec up front so the input file is closed before any
  # generation work or output writing happens.
  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/spmm.h>
#include "spmm-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  # The header and each micro-kernel's test cases are separated by two
  # newlines; collecting the pieces and joining once avoids repeated string
  # concatenation in the loop.
  sections = [header]
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    k_block = int(ukernel_spec["k-block"])
    pipelined = bool(ukernel_spec.get("pipelined", False))
    mr, nr, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, mr, nr, k_block, pipelined, isa)
    sections.append(xnncommon.postprocess_test_case(test_case, arch, isa))

  with codecs.open(options.output, "w", encoding="utf-8") as output_file:
    output_file.write("\n\n".join(sections))
442
443
if __name__ == "__main__":
  # Run as a script: forward everything after the program name to main().
  main(args=sys.argv[1:])
446