#!/usr/bin/env python
# Copyright 2021 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import codecs
import math
import os
import re
import sys
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import xngen
import xnncommon


parser = argparse.ArgumentParser(
  description='Vector conversion operation microkernel test generator')
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Specification (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help='Output (C++ source) file')
parser.set_defaults(defines=list())


# Micro-kernel names look like
#   xnn_<input>_<output>_vcvt_ukernel__<target>_x<batch_tile>
# where <input>/<output> are each one of f16, f32, qs8, qu8
# (e.g. xnn_f32_qs8_vcvt_ukernel__sse2_x16). Compiled once at import time.
_UKERNEL_NAME_RE = re.compile(
  r"^xnn_(f16|f32|qs8|qu8)_(f16|f32|qs8|qu8)_vcvt_ukernel__(.+)_x(\d+)$")


def split_ukernel_name(name):
  """Splits a vcvt micro-kernel name into its components.

  Args:
    name: C name of the micro-kernel, of the form
      xnn_<input>_<output>_vcvt_ukernel__<target>_x<batch_tile>.

  Returns:
    Tuple (input_datatype, output_datatype, batch_tile, arch, isa): the two
    lower-case datatype strings, the batch tile as an int, and the arch/isa
    pair returned by xnncommon.parse_target_name for <target>.

  Raises:
    ValueError: if name does not follow the expected pattern.
  """
  match = _UKERNEL_NAME_RE.match(name)
  if match is None:
    raise ValueError("Unexpected microkernel name: " + name)

  input_datatype, output_datatype, target_name, batch_tile = match.groups()
  arch, isa = xnncommon.parse_target_name(target_name=target_name)
  return input_datatype, output_datatype, int(batch_tile), arch, isa


# xngen template for the unit tests of one micro-kernel. `$if`/`$elif` lines
# are evaluated by xngen.preprocess and `${...}` substitutes a Python
# expression over the variables passed in generate_test_cases.
# Tests emitted:
#   - batch_eq / batch_gt always; batch_div / batch_lt only if BATCH_TILE > 1;
#   - scale / zero_point only for F32<->QS8/QU8 conversions, plus
#     saturation / overflow when the input is F32;
#   - qmin / qmax whenever the output is QS8 or QU8.
# NOTE: whitespace inside this string ends up in the generated C++ and nested
# `$if` bodies are presumably dedented by xngen, so edit indentation carefully.
CVT_TEST_TEMPLATE = """\
TEST(${TEST_NAME}, batch_eq_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  VCvtMicrokernelTester()
    .batch_size(${BATCH_TILE})
    $if OUTPUT_DATATYPE == "QS8":
      .qmin(std::numeric_limits<int8_t>::min())
      .qmax(std::numeric_limits<int8_t>::max())
    $elif OUTPUT_DATATYPE == "QU8":
      .qmin(std::numeric_limits<uint8_t>::min())
      .qmax(std::numeric_limits<uint8_t>::max())
    .Test(${", ".join(TEST_ARGS)});
}

$if BATCH_TILE > 1:
  TEST(${TEST_NAME}, batch_div_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = ${BATCH_TILE*2}; batch_size < ${BATCH_TILE*10}; batch_size += ${BATCH_TILE}) {
      VCvtMicrokernelTester()
        .batch_size(batch_size)
        $if OUTPUT_DATATYPE == "QS8":
          .qmin(std::numeric_limits<int8_t>::min())
          .qmax(std::numeric_limits<int8_t>::max())
        $elif OUTPUT_DATATYPE == "QU8":
          .qmin(std::numeric_limits<uint8_t>::min())
          .qmax(std::numeric_limits<uint8_t>::max())
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, batch_lt_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = 1; batch_size < ${BATCH_TILE}; batch_size++) {
      VCvtMicrokernelTester()
        .batch_size(batch_size)
        $if OUTPUT_DATATYPE == "QS8":
          .qmin(std::numeric_limits<int8_t>::min())
          .qmax(std::numeric_limits<int8_t>::max())
        $elif OUTPUT_DATATYPE == "QU8":
          .qmin(std::numeric_limits<uint8_t>::min())
          .qmax(std::numeric_limits<uint8_t>::max())
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, batch_gt_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t batch_size = ${BATCH_TILE+1}; batch_size < ${10 if BATCH_TILE == 1 else BATCH_TILE*2}; batch_size++) {
    VCvtMicrokernelTester()
      .batch_size(batch_size)
      $if OUTPUT_DATATYPE == "QS8":
        .qmin(std::numeric_limits<int8_t>::min())
        .qmax(std::numeric_limits<int8_t>::max())
      $elif OUTPUT_DATATYPE == "QU8":
        .qmin(std::numeric_limits<uint8_t>::min())
        .qmax(std::numeric_limits<uint8_t>::max())
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if (INPUT_DATATYPE, OUTPUT_DATATYPE) in [("F32", "QS8"), ("F32", "QU8"), ("QS8", "F32"), ("QU8", "F32")]:
  TEST(${TEST_NAME}, scale) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
      VCvtMicrokernelTester()
        .batch_size(batch_size)
        .scale(50)
        $if OUTPUT_DATATYPE == "QS8":
          .qmin(std::numeric_limits<int8_t>::min())
          .qmax(std::numeric_limits<int8_t>::max())
        $elif OUTPUT_DATATYPE == "QU8":
          .zero_point(100)
          .qmin(std::numeric_limits<uint8_t>::min())
          .qmax(std::numeric_limits<uint8_t>::max())
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t zero_point = 0; zero_point < 5; zero_point += 2) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .zero_point(zero_point)
          $if OUTPUT_DATATYPE == "QS8":
            .qmin(std::numeric_limits<int8_t>::min())
            .qmax(std::numeric_limits<int8_t>::max())
          $elif OUTPUT_DATATYPE == "QU8":
            .qmin(std::numeric_limits<uint8_t>::min())
            .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  $if INPUT_DATATYPE == "F32":
    TEST(${TEST_NAME}, saturation) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          $if OUTPUT_DATATYPE == "QS8":
            .qmin(std::numeric_limits<int8_t>::min())
            .qmax(std::numeric_limits<int8_t>::max())
          $elif OUTPUT_DATATYPE == "QU8":
            .zero_point(128)
            .qmin(std::numeric_limits<uint8_t>::min())
            .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, overflow) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(4294967296.0f)
          $if OUTPUT_DATATYPE == "QS8":
            .qmin(std::numeric_limits<int8_t>::min())
            .qmax(std::numeric_limits<int8_t>::max())
          $elif OUTPUT_DATATYPE == "QU8":
            .qmin(std::numeric_limits<uint8_t>::min())
            .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }

$if OUTPUT_DATATYPE == "QS8":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t qmin = -128; qmin < 127; qmin += 51) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          .qmin(qmin)
          .qmax(std::numeric_limits<int8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t qmax = -127; qmax <= 127; qmax += 51) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          .qmin(std::numeric_limits<int8_t>::min())
          .qmax(qmax)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if OUTPUT_DATATYPE == "QU8":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t qmin = 0; qmin < 255; qmin += 51) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          .zero_point(128)
          .qmin(qmin)
          .qmax(std::numeric_limits<uint8_t>::max())
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (int16_t qmax = 1; qmax <= 255; qmax += 51) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VCvtMicrokernelTester()
          .batch_size(batch_size)
          .scale(500)
          .zero_point(128)
          .qmin(std::numeric_limits<uint8_t>::min())
          .qmax(qmax)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
"""


def generate_test_cases(ukernel, init_fn, input_datatype, output_datatype,
                        batch_tile, isa):
  """Generates all test cases for a Vector Convert Operation micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    init_fn: C name of the function to initialize microkernel parameters, or
      a falsy value if the micro-kernel takes no init function.
    input_datatype: input conversion data type.
    output_datatype: output conversion data type.
    batch_tile: Number of batch elements processed per one iteration of the
      inner loop of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.
  """
  # Drop the leading "xnn_" and the "UKERNEL_" infix, so that e.g.
  # xnn_f32_qs8_vcvt_ukernel__sse2_x16 becomes F32_QS8_VCVT__SSE2_X16.
  _, test_name = ukernel.split("_", 1)
  test_args = [ukernel]
  if init_fn:
    test_args.append(init_fn)
  return xngen.preprocess(CVT_TEST_TEMPLATE, {
    "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
    "TEST_ARGS": test_args,
    "BATCH_TILE": batch_tile,
    "INPUT_DATATYPE": input_datatype.upper(),
    "OUTPUT_DATATYPE": output_datatype.upper(),
    "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
  })


def _write_if_changed(path, text):
  """Writes text to path as UTF-8 unless the file already contains text.

  Leaving an up-to-date file untouched preserves its modification time.
  newline="" disables newline translation for both the read and the write,
  so the comparison sees the file's exact line endings and "\\n" is written
  as-is on every platform.
  """
  if os.path.exists(path):
    with open(path, "r", encoding="utf-8", newline="") as output_file:
      if output_file.read() == text:
        return

  with open(path, "w", encoding="utf-8", newline="") as output_file:
    output_file.write(text)


def main(args):
  """Generates the vcvt micro-kernel unit-test source from a YAML spec.

  Args:
    args: command-line arguments without the program name, parsed by the
      module-level `parser` (--spec and --output are required).

  Raises:
    ValueError: if the spec is not a YAML list, or if it names a micro-kernel
      that split_ukernel_name does not recognize.
  """
  options = parser.parse_args(args)

  with open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
    if not isinstance(spec_yaml, list):
      raise ValueError("expected a list of micro-kernels in the spec")

    header = """\
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/vcvt.h>
#include "vcvt-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

    # The header and each test case are separated by one blank line pair.
    sections = [header]
    for ukernel_spec in spec_yaml:
      name = ukernel_spec["name"]
      init_fn = ukernel_spec.get("init")
      input_datatype, output_datatype, batch_tile, arch, isa = \
        split_ukernel_name(name)

      # The spec may override the architecture derived from the kernel name.
      arch = ukernel_spec.get("arch", arch)

      test_case = generate_test_cases(
        name, init_fn, input_datatype, output_datatype, batch_tile, isa)
      sections.append(xnncommon.postprocess_test_case(test_case, arch, isa))

    _write_if_changed(options.output, "\n\n".join(sections))


if __name__ == "__main__":
  main(sys.argv[1:])