#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import bisect
import codecs
import os
import sys
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from primes import next_prime
import xngen
import xnncommon


parser = argparse.ArgumentParser(description='XNNPACK generator')
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Spec (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help='Output (C++ source) file')
parser.set_defaults(defines=list())


def split_ukernel_name(name):
  """Splits an SpMM micro-kernel name into its tile and target parameters.

  The name is expected to look like `<common>_<MR>x<NR>__<target>`: the part
  before the first "__" ends with an `MRxNR` tile spec, and the part after it
  is handed to `xnncommon.parse_target_name`.

  Args:
    name: C name of the micro-kernel function.

  Returns:
    Tuple (mr, nr, arch, isa) where mr and nr are ints and arch/isa are
    whatever `xnncommon.parse_target_name` returns for the target suffix.

  Raises:
    ValueError: if the name has no "__" separator or the tile spec is not
      of the form `<int>x<int>`.
  """
  common_name, target_name = name.split("__", 1)
  common_parts = common_name.split("_")
  param_spec = common_parts[-1]
  mr, nr = map(int, param_spec.split("x"))
  arch, isa = xnncommon.parse_target_name(target_name)
  return mr, nr, arch, isa


# xngen template for one micro-kernel's Google Test cases. Placeholders are
# filled in by generate_test_cases():
#   MR, NR     - tile size of the micro-kernel.
#   KBLOCK     - K values consumed per main-loop iteration.
#   ADJKBLOCK  - KBLOCK, or 2 * KBLOCK for software-pipelined kernels; the
#                k_lt/k_gt boundaries are relative to this value so that
#                pipelined kernels get a full prologue + epilogue.
#   ISA_CHECK  - optional statement that skips the test on unsupported hosts.
# NOTE(review): the n_div test does not call .sparsity(), so it runs with the
# tester's default sparsity while the other dense tests use 0.0f; confirm this
# is intentional.
TEST_TEMPLATE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  SpMMMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .sparsity(0.0f)
    .Test(${", ".join(TEST_ARGS)});
}

$if NR > 1:
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = 1; n <= ${NR}; n++) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(n)
        .k(${KBLOCK})
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

$if IS_PIPELINED:
  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    SpMMMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK * 2})
      .sparsity(0.0f)
      .Test(${", ".join(TEST_ARGS)});
  }

  $if NR > 1:
    TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (uint32_t n = 1; n <= ${NR}; n++) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(${MR})
          .n(n)
          .k(${KBLOCK * 2})
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if NR > 1:
    TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          SpMMMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .m(${MR})
            .n(n)
            .k(k)
            .sparsity(0.0f)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if KBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    SpMMMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .m(${MR})
      .n(${NR})
      .k(k)
      .sparsity(0.0f)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if NR > 1:
  TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if KBLOCK == 1 else ADJKBLOCK * 2}; k++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(${MR})
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if NR > 1:
    TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          SpMMMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .m(${MR})
            .n(n)
            .k(k)
            .sparsity(0.0f)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

TEST(${TEST_NAME}, n_gt_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${max(10, NR * 2)}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(n)
        .k(k)
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if NR > 1:
  TEST(${TEST_NAME}, n_div_${NR}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(${MR})
          .n(n)
          .k(k)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, m_lt_${MR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = ${1}; m < ${MR}; m++) {
    for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(m)
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, m_div_${MR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = ${MR * 2}; m <= ${MR * 3}; m += ${MR}) {
    for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(m)
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, m_gt_${MR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = ${MR + 1}; m < ${MR * 2}; m++) {
    for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(m)
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, qmin) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(0.0f)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, qmax) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(0.0f)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, half_sparse) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(0.5f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, zero_weights) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(1.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}
"""


def generate_test_cases(ukernel, mr, nr, k_block, is_pipelined, isa):
  """Generates all test cases for an SpMM micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    mr: MR parameter of the SpMM micro-kernel.
    nr: NR parameter of the SpMM micro-kernel.
    k_block: Number of K values processed per one iteration of the main loop of
             the micro-kernel.
    is_pipelined: Indicates if the micro-kernel is implemented with software
                  pipelining. Additional test cases are generated for software
                  pipelined micro-kernels to separately test prologue + epilogue
                  of the pipelined loop and iteration of the pipelined loop.
    isa: instruction set required to run the micro-kernel. Generated unit test
         will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.
  """
  _, test_name = ukernel.split("_", 1)
  _, datatype, ukernel_type, _ = ukernel.split("_", 3)
  test_args = [ukernel]
  # Portable kernels (no ISA, or psimd) are tested with the Scalar variant.
  if not isa or isa == "psimd":
    test_args.append("SpMMMicrokernelTester::Variant::Scalar")
  return xngen.preprocess(TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "UKERNEL_TYPE": ukernel_type.upper(),
      "DATATYPE": datatype,
      "MR": mr,
      "NR": nr,
      "KBLOCK": k_block,
      # Pipelined kernels consume two K blocks in prologue + epilogue, so the
      # "one full iteration" boundary used by k_lt/k_gt tests is doubled.
      "ADJKBLOCK": 2 * k_block if is_pipelined else k_block,
      "IS_PIPELINED": is_pipelined,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
  })


def main(args):
  """Writes the C++ unit-test source for every micro-kernel in a YAML spec.

  Args:
    args: command-line arguments without the program name; parsed with the
      module-level `parser` (requires --spec and --output).

  Raises:
    ValueError: if the spec does not contain a top-level YAML list.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/spmm.h>
#include "spmm-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  # Collect the header and each kernel's tests, then join once with blank-line
  # separators instead of repeatedly growing one string.
  parts = [header]
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    k_block = int(ukernel_spec["k-block"])
    pipelined = bool(ukernel_spec.get("pipelined", False))
    mr, nr, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, mr, nr, k_block, pipelined, isa)
    parts.append(xnncommon.postprocess_test_case(test_case, arch, isa))

  tests = "\n\n".join(parts)

  with codecs.open(options.output, "w", encoding="utf-8") as output_file:
    output_file.write(tests)


if __name__ == "__main__":
  main(sys.argv[1:])