• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2# Copyright 2021 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import codecs
9import math
10import os
11import re
12import sys
13import yaml
14
15sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
16import xngen
17import xnncommon
18
19
# Command-line interface. Both options are mandatory file paths: the YAML
# spec listing the microkernels, and the C++ test file to (re)generate.
parser = argparse.ArgumentParser(
  description='Vector conversion operation microkernel test generator')
for _flags, _help in (
    (("-s", "--spec"), "Specification (YAML) file"),
    (("-o", "--output"), 'Output (C++ source) file'),
):
  parser.add_argument(*_flags, metavar="FILE", required=True, help=_help)
# Do not leak the loop variables into the module namespace.
del _flags, _help
# NOTE(review): `defines` is not read by main() in this file; presumably kept
# for parity with sibling generator scripts -- confirm before removing.
parser.set_defaults(defines=[])
27
28
def split_ukernel_name(name):
  """Decomposes a vcvt microkernel name into its components.

  Accepted names look like xnn_<in>_<out>_vcvt_ukernel__<target>_x<tile>,
  where <in> and <out> are each one of f16, f32, qs8 or qu8 (lower case).

  Args:
    name: C name of the microkernel.

  Returns:
    A tuple (input_datatype, output_datatype, batch_tile, arch, isa): the two
    datatype strings exactly as spelled in the name, the tile as an int, and
    the (arch, isa) pair that xnncommon.parse_target_name() derives from
    <target>.

  Raises:
    ValueError: if `name` does not match the expected pattern.
  """
  parsed = re.match(
      r"^xnn_(f16|f32|qs8|qu8)_(f16|f32|qs8|qu8)_vcvt_ukernel__(.+)_x(\d+)$",
      name)
  if not parsed:
    raise ValueError("Unexpected microkernel name: " + name)

  input_datatype, output_datatype, target, tile = parsed.groups()
  batch_tile = int(tile)
  arch, isa = xnncommon.parse_target_name(target_name=target)
  return input_datatype, output_datatype, batch_tile, arch, isa
40
41
# xngen template for the gtest cases of a single vcvt microkernel. It is
# expanded by generate_test_cases() through xngen.preprocess() with the
# variables TEST_NAME, TEST_ARGS (microkernel name, plus the init function
# when one is given), BATCH_TILE, INPUT_DATATYPE / OUTPUT_DATATYPE (upper-cased,
# e.g. "F32", "QS8") and ISA_CHECK. Lines starting with `$if` / `$elif` are
# template conditionals and `${...}` interpolates an expression (presumably
# per xngen's syntax -- see xngen for the exact rules).
#
# Tests emitted:
#   - batch_eq_<tile>, batch_gt_<tile>: always.
#   - batch_div_<tile>, batch_lt_<tile>: only when BATCH_TILE > 1.
#   - scale, zero_point: F32->QS8, F32->QU8, QS8->F32 and QU8->F32 only.
#   - saturation, overflow: the F32->QS8 / F32->QU8 subset of the above.
#   - qmin, qmax: QS8 or QU8 output (the QU8 variants use zero_point 128).
# Tests with QS8/QU8 output pin qmin/qmax to the full int8/uint8 range,
# except where the test sweeps one of those bounds.
CVT_TEST_TEMPLATE = """\
TEST(${TEST_NAME}, batch_eq_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  VCvtMicrokernelTester()
    .batch_size(${BATCH_TILE})
    $if OUTPUT_DATATYPE == "QS8":
      .qmin(std::numeric_limits<int8_t>::min())
      .qmax(std::numeric_limits<int8_t>::max())
    $elif OUTPUT_DATATYPE == "QU8":
      .qmin(std::numeric_limits<uint8_t>::min())
      .qmax(std::numeric_limits<uint8_t>::max())
    .Test(${", ".join(TEST_ARGS)});
}

$if BATCH_TILE > 1:
  TEST(${TEST_NAME}, batch_div_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = ${BATCH_TILE*2}; batch_size < ${BATCH_TILE*10}; batch_size += ${BATCH_TILE}) {
      VCvtMicrokernelTester()
        .batch_size(batch_size)
        $if OUTPUT_DATATYPE == "QS8":
          .qmin(std::numeric_limits<int8_t>::min())
          .qmax(std::numeric_limits<int8_t>::max())
        $elif OUTPUT_DATATYPE == "QU8":
          .qmin(std::numeric_limits<uint8_t>::min())
          .qmax(std::numeric_limits<uint8_t>::max())
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, batch_lt_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = 1; batch_size < ${BATCH_TILE}; batch_size++) {
      VCvtMicrokernelTester()
        .batch_size(batch_size)
        $if OUTPUT_DATATYPE == "QS8":
          .qmin(std::numeric_limits<int8_t>::min())
          .qmax(std::numeric_limits<int8_t>::max())
        $elif OUTPUT_DATATYPE == "QU8":
          .qmin(std::numeric_limits<uint8_t>::min())
          .qmax(std::numeric_limits<uint8_t>::max())
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, batch_gt_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t batch_size = ${BATCH_TILE+1}; batch_size < ${10 if BATCH_TILE == 1 else BATCH_TILE*2}; batch_size++) {
    VCvtMicrokernelTester()
      .batch_size(batch_size)
      $if OUTPUT_DATATYPE == "QS8":
        .qmin(std::numeric_limits<int8_t>::min())
        .qmax(std::numeric_limits<int8_t>::max())
      $elif OUTPUT_DATATYPE == "QU8":
        .qmin(std::numeric_limits<uint8_t>::min())
        .qmax(std::numeric_limits<uint8_t>::max())
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if (INPUT_DATATYPE, OUTPUT_DATATYPE) in [("F32", "QS8"), ("F32", "QU8"), ("QS8", "F32"), ("QU8", "F32")]:
  TEST(${TEST_NAME}, scale) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
      VCvtMicrokernelTester()
        .batch_size(batch_size)
        .scale(50)
        $if OUTPUT_DATATYPE == "QS8":
          .qmin(std::numeric_limits<int8_t>::min())
          .qmax(std::numeric_limits<int8_t>::max())
        $elif OUTPUT_DATATYPE == "QU8":
          .zero_point(100)
          .qmin(std::numeric_limits<uint8_t>::min())
          .qmax(std::numeric_limits<uint8_t>::max())
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t zero_point = 0; zero_point < 5; zero_point += 2) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .zero_point(zero_point)
          $if OUTPUT_DATATYPE == "QS8":
            .qmin(std::numeric_limits<int8_t>::min())
            .qmax(std::numeric_limits<int8_t>::max())
          $elif OUTPUT_DATATYPE == "QU8":
            .qmin(std::numeric_limits<uint8_t>::min())
            .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  $if INPUT_DATATYPE == "F32":
    TEST(${TEST_NAME}, saturation) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          $if OUTPUT_DATATYPE == "QS8":
            .qmin(std::numeric_limits<int8_t>::min())
            .qmax(std::numeric_limits<int8_t>::max())
          $elif OUTPUT_DATATYPE == "QU8":
            .zero_point(128)
            .qmin(std::numeric_limits<uint8_t>::min())
            .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, overflow) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(4294967296.0f)
          $if OUTPUT_DATATYPE == "QS8":
            .qmin(std::numeric_limits<int8_t>::min())
            .qmax(std::numeric_limits<int8_t>::max())
          $elif OUTPUT_DATATYPE == "QU8":
            .qmin(std::numeric_limits<uint8_t>::min())
            .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }

$if OUTPUT_DATATYPE == "QS8":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t qmin = -128; qmin < 127; qmin += 51) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          .qmin(qmin)
          .qmax(std::numeric_limits<int8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t qmax = -127; qmax <= 127; qmax += 51) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          .qmin(std::numeric_limits<int8_t>::min())
          .qmax(qmax)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if OUTPUT_DATATYPE == "QU8":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t qmin = 0; qmin < 255; qmin += 51) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          .zero_point(128)
          .qmin(qmin)
          .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t qmax = 1; qmax <= 255; qmax += 51) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          .zero_point(128)
          .qmin(std::numeric_limits<uint8_t>::min())
          .qmax(qmax)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
"""
244
245
def generate_test_cases(ukernel, init_fn, input_datatype, output_datatype,
                        batch_tile, isa):
  """Renders every gtest case for one vector-conversion (vcvt) microkernel.

  Args:
    ukernel: C name of the microkernel function.
    init_fn: C name of the function that initializes the microkernel
      parameters, or a falsy value when there is none. When present it is
      passed to Test() right after the microkernel.
    input_datatype: source data type as spelled in the microkernel name;
      upper-cased before it reaches the template.
    output_datatype: destination data type, same convention as above.
    batch_tile: number of batch elements processed per iteration of the
      microkernel's inner loop; the generated tests derive their batch sizes
      from it.
    isa: instruction set required by the microkernel. It is turned into an
      ISA check so the generated tests skip execution on hosts without it.

  Returns:
    The C++ source text of the generated test cases.
  """
  # Everything after the first underscore ("xnn_" for names accepted by
  # split_ukernel_name), upper-cased and without "UKERNEL_", becomes the
  # gtest suite name.
  _prefix, suffix = ukernel.split("_", 1)
  suite_name = suffix.upper().replace("UKERNEL_", "")

  tester_args = [ukernel] + ([init_fn] if init_fn else [])

  template_vars = {
      "TEST_NAME": suite_name,
      "TEST_ARGS": tester_args,
      "BATCH_TILE": batch_tile,
      "INPUT_DATATYPE": input_datatype.upper(),
      "OUTPUT_DATATYPE": output_datatype.upper(),
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
  }
  return xngen.preprocess(CVT_TEST_TEMPLATE, template_vars)
275
276
def main(args):
  """Generates the vcvt microkernel unit-test source from a YAML spec.

  The spec must be a YAML list of entries, each with a "name" (microkernel C
  name) and optionally "init" (parameter-init function) and "arch" (overrides
  the architecture derived from the name). The output file is rewritten only
  when the generated text differs from its current content.

  Args:
    args: command-line arguments without the program name (see `parser`).

  Raises:
    ValueError: if the spec's top-level YAML value is not a list, or a
      microkernel name does not follow the vcvt naming scheme.
    KeyError: if a spec entry has no "name" field.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
    if not isinstance(spec_yaml, list):
      raise ValueError("expected a list of micro-kernels in the spec")

    # File banner recording which spec and which script produced the output.
    header = """\
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/vcvt.h>
#include "vcvt-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

    # Each rendered test case is preceded by a blank-line separator.
    sections = []
    for ukernel_spec in spec_yaml:
      name = ukernel_spec["name"]
      init_fn = ukernel_spec.get("init")
      input_datatype, output_datatype, batch_tile, arch, isa = (
          split_ukernel_name(name))
      # An explicit "arch" in the spec wins over the one parsed from the name.
      arch = ukernel_spec.get("arch", arch)
      case_source = generate_test_cases(
          name, init_fn, input_datatype, output_datatype, batch_tile, isa)
      sections.append(
          "\n\n" + xnncommon.postprocess_test_case(case_source, arch, isa))
    tests = header + "".join(sections)

    # Skip the write when the existing output already holds this exact text.
    up_to_date = False
    if os.path.exists(options.output):
      with codecs.open(options.output, "r", encoding="utf-8") as existing:
        up_to_date = existing.read() == tests

    if not up_to_date:
      with codecs.open(options.output, "w", encoding="utf-8") as output_file:
        output_file.write(tests)


if __name__ == "__main__":
  main(sys.argv[1:])
330