#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import codecs
import math
import os
import re
import sys
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from primes import next_prime
import xngen
import xnncommon


parser = argparse.ArgumentParser(
  description='VMulCAddC microkernel test generator')
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Specification (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help='Output (C++ source) file')
parser.set_defaults(defines=list())


def split_ukernel_name(name):
  """Extracts tiling and target parameters from a VMULCADDC kernel name.

  Accepted names have the form
  xnn_<f16|f32>_vmulcaddc[_minmax]_ukernel_c<channel_tile>__<target>_<row_tile>x

  Args:
    name: C name of the micro-kernel function.

  Returns:
    Tuple (channel_tile, row_tile, arch, isa). arch and isa are obtained by
    passing the <target> component to xnncommon.parse_target_name.

  Raises:
    ValueError: if the name does not follow the naming scheme above.
  """
  match = re.match(r"^xnn_(f16|f32)_vmulcaddc(_(minmax))?_ukernel_c(\d+)__(.+)_(\d+)x$", name)
  # Explicit check rather than `assert`: asserts are stripped under
  # `python -O`, and a bare AssertionError does not name the offending kernel.
  if match is None:
    raise ValueError("unexpected VMULCADDC micro-kernel name: %s" % name)
  channel_tile = int(match.group(4))
  row_tile = int(match.group(6))

  arch, isa = xnncommon.parse_target_name(target_name=match.group(5))
  return channel_tile, row_tile, arch, isa


VMULCADDC_TEST_TEMPLATE = """\
TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  VMulCAddCMicrokernelTester()
    .channel_tile(${CHANNEL_TILE})
    .channels(${CHANNEL_TILE})
    .rows(${ROW_TILE})
    .Test(${", ".join(TEST_ARGS)});
}

$if CHANNEL_TILE > 1:
  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*10}; channels += ${CHANNEL_TILE}) {
      VMulCAddCMicrokernelTester()
        .channel_tile(${CHANNEL_TILE})
        .channels(channels)
        .rows(${ROW_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
      VMulCAddCMicrokernelTester()
        .channel_tile(${CHANNEL_TILE})
        .channels(channels)
        .rows(${ROW_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
    VMulCAddCMicrokernelTester()
      .channel_tile(${CHANNEL_TILE})
      .channels(channels)
      .rows(${ROW_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if ROW_TILE > 1:
  TEST(${TEST_NAME}, rows_lt_${ROW_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = 1; rows < ${ROW_TILE}; rows++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        VMulCAddCMicrokernelTester()
          .channel_tile(${CHANNEL_TILE})
          .channels(channels)
          .rows(rows)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, rows_div_${ROW_TILE}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${ROW_TILE*2}; rows <= ${ROW_TILE*4}; rows += ${ROW_TILE}) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        VMulCAddCMicrokernelTester()
          .channel_tile(${CHANNEL_TILE})
          .channels(channels)
          .rows(rows)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, rows_gt_${ROW_TILE}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t rows = ${ROW_TILE+1}; rows < ${ROW_TILE*2}; rows++) {
    for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
      VMulCAddCMicrokernelTester()
        .channel_tile(${CHANNEL_TILE})
        .channels(channels)
        .rows(rows)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, input_stride) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t rows = 1; rows <= ${ROW_TILE*3}; rows += ${max(1, ROW_TILE-1)}) {
    for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
      VMulCAddCMicrokernelTester()
        .channel_tile(${CHANNEL_TILE})
        .channels(channels)
        .rows(rows)
        .input_stride(${next_prime(CHANNEL_TILE*5+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, output_stride) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t rows = 1; rows <= ${ROW_TILE*3}; rows += ${max(1, ROW_TILE-1)}) {
    for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
      VMulCAddCMicrokernelTester()
        .channel_tile(${CHANNEL_TILE})
        .channels(channels)
        .rows(rows)
        .output_stride(${next_prime(CHANNEL_TILE*5+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, inplace) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t rows = 1; rows <= ${ROW_TILE*3}; rows += ${max(1, ROW_TILE-1)}) {
    for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
      VMulCAddCMicrokernelTester()
        .channel_tile(${CHANNEL_TILE})
        .channels(channels)
        .rows(rows)
        .inplace(true)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, qmin) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t rows = 1; rows <= ${ROW_TILE*3}; rows += ${max(1, ROW_TILE-1)}) {
    for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
      VMulCAddCMicrokernelTester()
        .channel_tile(${CHANNEL_TILE})
        .channels(channels)
        .rows(rows)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, qmax) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t rows = 1; rows <= ${ROW_TILE*3}; rows += ${max(1, ROW_TILE-1)}) {
    for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
      VMulCAddCMicrokernelTester()
        .channel_tile(${CHANNEL_TILE})
        .channels(channels)
        .rows(rows)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}
"""


def generate_test_cases(ukernel, channel_tile, row_tile, isa):
  """Generates all tests cases for a VMULCADDC micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    channel_tile: Number of channels processed per one iteration of the inner
      loop of the micro-kernel.
    row_tile: Number of rows processed per one iteration of the outer loop of
      the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.
  """
  # Drop the first "_"-separated component (the "xnn" prefix); "UKERNEL_" is
  # stripped from the upper-cased remainder below.
  _, test_name = ukernel.split("_", 1)
  # The second "_"-separated component is the datatype (e.g. "f32").
  datatype = ukernel.split("_", 2)[1]
  test_args = [ukernel]
  if not isa:
    # Kernels with no ISA requirement are exercised as the scalar variant.
    test_args.append("VMulCAddCMicrokernelTester::Variant::Scalar")
  return xngen.preprocess(VMULCADDC_TEST_TEMPLATE, {
    "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
    "TEST_ARGS": test_args,
    "DATATYPE": datatype,
    "CHANNEL_TILE": channel_tile,
    "ROW_TILE": row_tile,
    "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
    "next_prime": next_prime,
  })


def main(args):
  """Generates the VMulCAddC gtest source from a YAML micro-kernel spec.

  Args:
    args: command-line arguments, excluding the program name.

  Raises:
    ValueError: if the spec is not a list, or if a micro-kernel name does not
      follow the naming scheme expected by split_ukernel_name.
  """
  options = parser.parse_args(args)

  # Load the spec in full so the input file is closed before generation starts.
  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  tests = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
// Specification: {specification}
// Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/vmulcaddc.h>
#include "vmulcaddc-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  test_cases = []
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    channel_tile, row_tile, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, channel_tile, row_tile, isa)
    test_cases.append(xnncommon.postprocess_test_case(test_case, arch, isa))

  # Each test case is preceded by two newlines, as in the header layout above.
  tests += "".join("\n\n" + test_case for test_case in test_cases)

  with codecs.open(options.output, "w", encoding="utf-8") as output_file:
    output_file.write(tests)


if __name__ == "__main__":
  main(sys.argv[1:])