#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Generates C++ unit tests for vector unary operation micro-kernels.

Reads a YAML specification that lists micro-kernels and writes a C++ source
file with one group of TEST() cases per micro-kernel, each driving
VUnaryMicrokernelTester.
"""

import argparse
import codecs
import math
import os
import re
import sys
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import xngen
import xnncommon


parser = argparse.ArgumentParser(
  description='Vector unary operation microkernel test generator')
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Specification (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help='Output (C++ source) file')
parser.set_defaults(defines=list())


# Accepted micro-kernel names:
#   xnn_<datatype>_v<op>_[fact_]ukernel__<target>_x<batch_tile>
# Groups: 1 = datatype, 2 = op mnemonic, 3 = optional "fact_",
#         4 = target, 5 = batch tile.
_UKERNEL_NAME_PATTERN = re.compile(r"^xnn_(s8|u8|f16|f32)_v(abs|clamp|elu|hswish|lrelu|neg|relu|rndd|rndne|rndu|rndz|sigmoid|sqr|sqrt)_(fact_)?ukernel__(.+)_x(\d+)$")

# Maps the op mnemonic from a micro-kernel name to the enumerator name that is
# emitted as VUnaryMicrokernelTester::OpType::<name> in the generated tests.
_OP_TYPES = {
  "abs": "Abs",
  "clamp": "Clamp",
  "elu": "ELU",
  "hswish": "HardSwish",
  "lrelu": "LeakyReLU",
  "neg": "Negate",
  "relu": "ReLU",
  "rndd": "RoundDown",
  "rndne": "RoundToNearestEven",
  "rndz": "RoundTowardsZero",
  "rndu": "RoundUp",
  "sigmoid": "Sigmoid",
  "sqr": "Square",
  "sqrt": "SquareRoot",
}


def split_ukernel_name(name):
  """Parses a vector unary micro-kernel name into its components.

  Args:
    name: C name of the micro-kernel, following the pattern
      xnn_<datatype>_v<op>_[fact_]ukernel__<target>_x<batch_tile>.

  Returns:
    A tuple (op_type, batch_tile, arch, isa): op_type is the OpType enumerator
    name (e.g. "Abs"), batch_tile is the integer after the final "_x", and
    arch/isa are whatever xnncommon.parse_target_name() derives from <target>.

  Raises:
    ValueError: if the name does not follow the expected pattern.
  """
  match = _UKERNEL_NAME_PATTERN.match(name)
  if match is None:
    raise ValueError("Unexpected microkernel name: " + name)
  # The regex only admits mnemonics that are keys of _OP_TYPES, so this
  # lookup cannot raise KeyError.
  op_type = _OP_TYPES[match.group(2)]
  batch_tile = int(match.group(5))

  arch, isa = xnncommon.parse_target_name(target_name=match.group(4))
  return op_type, batch_tile, arch, isa


# Template for all tests of one micro-kernel, expanded by xngen.preprocess()
# (see generate_test_cases). Despite the historical "BINOP" name it generates
# tests for *unary* micro-kernels; the name is kept so existing references
# keep working. The text is emitted verbatim apart from the $-directives, so
# do not reformat it.
BINOP_TEST_TEMPLATE = """\
TEST(${TEST_NAME}, batch_eq_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  VUnaryMicrokernelTester()
    .batch_size(${BATCH_TILE})
    .Test(${", ".join(TEST_ARGS)});
}

$if BATCH_TILE > 1:
  TEST(${TEST_NAME}, batch_div_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = ${BATCH_TILE*2}; batch_size < ${BATCH_TILE*10}; batch_size += ${BATCH_TILE}) {
      VUnaryMicrokernelTester()
        .batch_size(batch_size)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, batch_lt_${BATCH_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t batch_size = 1; batch_size < ${BATCH_TILE}; batch_size++) {
      VUnaryMicrokernelTester()
        .batch_size(batch_size)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, batch_gt_${BATCH_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t batch_size = ${BATCH_TILE+1}; batch_size < ${10 if BATCH_TILE == 1 else BATCH_TILE*2}; batch_size++) {
    VUnaryMicrokernelTester()
      .batch_size(batch_size)
      .Test(${", ".join(TEST_ARGS)});
  }
}

TEST(${TEST_NAME}, inplace) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
    VUnaryMicrokernelTester()
      .batch_size(batch_size)
      .inplace(true)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if OP_TYPE == "Clamp":
  TEST(${TEST_NAME}, qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint8_t qmin = 1; qmin < 255; qmin++) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .qmin(qmin)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint8_t qmax = 1; qmax < 255; qmax++) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .qmax(qmax)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if OP_TYPE == "ELU":
  TEST(${TEST_NAME}, prescale) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (float prescale : std::vector<float>({0.1f, 10.0f})) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .prescale(prescale)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, alpha) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (float alpha : std::vector<float>({0.3f, 3.0f})) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .alpha(alpha)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, beta) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (float beta : std::vector<float>({0.3f, 3.0f})) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .beta(beta)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if OP_TYPE == "LeakyReLU":
  TEST(${TEST_NAME}, slope) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (float slope : std::vector<float>({-0.7f, 0.3f, 1.3f})) {
      for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
        VUnaryMicrokernelTester()
          .batch_size(batch_size)
          .slope(slope)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
"""


def generate_test_cases(ukernel, op_type, init_fn, batch_tile, isa):
  """Generates all test cases for a Vector Unary Operation micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    op_type: Operation type.
    init_fn: C name of the function to initialize microkernel parameters, or
      None (an empty name is treated the same as None).
    batch_tile: Number of batch elements processed per one iteration of the
      inner loop of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.
  """
  # Drop the leading "xnn" component; the rest becomes the gtest suite name.
  _, test_name = ukernel.split("_", 1)
  _, datatype, _ = ukernel.split("_", 2)

  # Arguments forwarded verbatim to VUnaryMicrokernelTester::Test().
  test_args = [ukernel]
  is_rounding = op_type.startswith("Round")
  if init_fn or is_rounding:
    if is_rounding:
      test_args.append("VUnaryMicrokernelTester::OpType::" + op_type)
    # Same truthiness test as the enclosing condition: an empty `init:` entry
    # must not become an empty trailing argument in the generated C++ call.
    if init_fn:
      test_args.append(init_fn)
  elif op_type not in ("Abs", "Negate", "Square", "SquareRoot"):
    test_args.append("VUnaryMicrokernelTester::OpType::" + op_type)
  # A falsy ISA selects the tester's Scalar variant.
  if not isa:
    test_args.append("VUnaryMicrokernelTester::Variant::Scalar")

  return xngen.preprocess(BINOP_TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "DATATYPE": datatype,
      "BATCH_TILE": batch_tile,
      "OP_TYPE": op_type,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
  })


def main(args):
  """Generates the C++ test file described by the command-line options.

  Args:
    args: command-line arguments, excluding the program name.

  Raises:
    ValueError: if the spec is not a YAML list, or lists a micro-kernel whose
      name does not match the expected naming pattern.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  tests = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/vunary.h>
#include "vunary-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    init_fn = ukernel_spec.get("init")
    op_type, batch_tile, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, op_type, init_fn, batch_tile, isa)
    tests += "\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa)

  # Only rewrite the output when its content changes (presumably so that an
  # unchanged file does not trigger rebuilds).
  txt_changed = True
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as output_file:
      txt_changed = output_file.read() != tests

  if txt_changed:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(tests)


if __name__ == "__main__":
  main(sys.argv[1:])