1# This script generates all variants of wmma builtins, verifies that clang calls 2# correct LLVM instrinsics, and checks that availability of specific builtins is 3# constrained by the correct PTX version and the target GPU variant. 4 5# Dummy test run to avoid lit warnings. 6# RUN: echo "This is not a real test. It's a generator for builtins-nvpts-mma.cu" >/dev/null 7 8from __future__ import print_function 9 10import argparse 11from collections import defaultdict 12from itertools import product 13from string import Template 14 15class MMAFrag: 16 def __init__(self, geom, frag, ptx_elt_type): 17 self.geom = geom 18 self.frag = frag 19 self.ptx_type = ptx_elt_type; 20 21 def __repr__(self): 22 return "%s:%s:%s" % (self.geom, self.frag, self.ptx_type) 23 24class MMAOp: 25 def __init__(self, a, b, c, d): 26 self.a = a 27 self.b = b 28 self.c = c 29 self.d = d 30 31 def __repr__(self): 32 return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d )) 33 34def make_mma_ops(geoms, types_a, types_b, types_c, types_d): 35 ops = [] 36 for geom, type_a, type_c in product( geoms, types_a, types_c): 37 for type_b, type_d in product(types_b if types_b else [type_a], 38 types_d if types_d else [type_c]): 39 ops.append(MMAOp(MMAFrag(geom, "a", type_a), 40 MMAFrag(geom, "b", type_b), 41 MMAFrag(geom, "c", type_c), 42 MMAFrag(geom, "d", type_d))) 43 return ops 44 45def make_ldst_ops(geoms, frags, types): 46 return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type) 47 in product(geoms, frags, types)] 48 49def get_mma_ops(): 50 return (make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], 51 ["f16"], [], ["f16", "f32"], ["f16", "f32"]) + 52 make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"], 53 ["s8", "u8"], [], ["s32"], []) + 54 make_mma_ops(["m8n8k32"], 55 ["s4", "u4"], [], ["s32"], []) + 56 make_mma_ops(["m8n8k128"], 57 ["b1"], [], ["s32"], [])) 58def get_ldst_ops(): 59 return (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"], 60 ["a", "b"], ["f16", "u8", "s8"]) + 61 make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"], 62 ["c", "d"], ["f16", "f32", "s32"]) + 63 make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) + 64 make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) + 65 make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"])) 66 67def is_geom_supported(geom): 68 # geometries for FP and ints. 69 if geom in ["m8n32k16", "m32n8k16"]: 70 return ptx_version >= 61 71 # geometries for sub-ints. 72 if geom in ["m8n8k32", "m8n8k128"]: 73 return ptx_version >= 63 and gpu_arch >= 75 74 if geom == "m16n16k16": 75 return ptx_version >= 60 76 assert(False) # Unexpected geometry. 77 78def is_type_supported(ptx_type): 79 if ptx_type in ["s8", "u8", "s32"]: 80 return ptx_version >= 63 and gpu_arch >= 72 81 if ptx_type in ["s4", "u4", "b1"]: 82 return ptx_version >= 63 and gpu_arch >= 75 83 return ptx_version >= 60 and gpu_arch >= 70 84 85def is_mma_variant_supported(op, layout_a, layout_b, satf): 86 if not (is_type_supported(op.a.ptx_type) 87 and is_geom_supported(op.a.geom)): 88 return False 89 # sub-integer require row/col layout, and no satf. 90 if op.a.ptx_type in ["s4", "u4", "b1"]: 91 if op.a.ptx_type == "b1" and satf: 92 return False 93 return layout_a == "row" and layout_b == "col" 94 return True 95 96def is_ldst_variant_supported(frag, layout): 97 if not (is_type_supported(frag.ptx_type) 98 and is_geom_supported(frag.geom)): 99 return False 100 if frag.ptx_type in ["s4", "u4", "b1"]: 101 # sub-integer require sm_75 and ptx63, row/col layout for a/b. 102 return ((frag.frag == "a" and layout == "row") 103 or (frag.frag == "b" and layout == "col") 104 or frag.frag in ["c", "d"]) 105 return True 106 107def get_builtin_prefix(frag): 108 prefix = None 109 if frag.geom in ["m16n16k16", "m32n8k16", "m8n32k16"]: 110 if frag.ptx_type in ["f16", "f32"]: 111 prefix = "__hmma" 112 else: 113 prefix = "__imma" 114 elif frag.geom == "m8n8k32": 115 prefix = "__imma" # sub-integers 116 elif frag.geom == "m8n8k128": 117 prefix = "__bmma" 118 assert prefix 119 return prefix 120 121def get_ldst_builtin_name(frag): 122 prefix = get_builtin_prefix(frag) 123 124 if prefix == "__hmma": 125 suffix = "" if frag.frag in ["a","b"] else frag.ptx_type 126 elif prefix in ["__imma", "__bmma"]: 127 suffix = "" if frag.frag in ["c"] else frag.ptx_type 128 if suffix == "s32": 129 suffix = "i32" 130 if frag.frag == "d": 131 ifrag = "c" 132 op = "st" 133 else: 134 ifrag = frag.frag 135 op = "ld" 136 137 name = "%s_%s_%s_%s%s" % (prefix, frag.geom, op, ifrag, 138 "_" + suffix if suffix else "") 139 return name 140 141def get_mma_builtin_name(op): 142 prefix = get_builtin_prefix(op.a) 143 144 if prefix == "__hmma": 145 suffix = op.d.ptx_type + op.c.ptx_type 146 else: 147 suffix = op.a.ptx_type 148 149 name = "%s_%s_mma%s_%s" % (prefix, op.a.geom, 150 "_xor_popc" if op.a.ptx_type == "b1" else "", 151 suffix) 152 return name 153 154 155def get_required_sm(frag): 156 if frag.ptx_type in ["u4", "s4", "b1"]: 157 return 75 158 if frag.ptx_type in ["s8", "u8"]: 159 return 72 160 if frag.ptx_type == "s32": 161 if frag.geom in ["m8n8k32", "m8n8k128"]: # s4/u4/b1 162 return 75 163 else: # s8/u8 164 return 72 165 if frag.ptx_type in ["f16", "f32"]: 166 return 70 167 assert(False) 168 169def get_required_ptx(frag): 170 if frag.ptx_type in ["f16", "f32"]: 171 return 60 if frag.geom == "m16n16k16" else 61 172 return 63 173 174def gen_wmma_ldst_tests(results): 175 load_template = """ 176 // CHECK${check_suffix}: call {{.*}} @${intrinsic} 177 // expected-error-re@+1 {{'${builtin}' needs target feature sm_${min_sm}{{.*}},ptx${min_ptx}{{.*}}}} 178 ${builtin}(${dst}, ${src}, ldm, ${blayout}); 179""".rstrip() 180 intrinsic_template = "llvm.nvvm.wmma.${geom}.${op}.${frag}.${ilayout}.stride.${itype}" 181 182 for frag, layout in sorted(product(get_ldst_ops(), ["row","col"]), key=str): 183 184 if not is_ldst_variant_supported(frag, layout): 185 continue 186 187 is_fp = frag.ptx_type == "f32" 188 min_sm = get_required_sm(frag) 189 min_ptx = get_required_ptx(frag) 190 params = { 191 "check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm), 192 "builtin" : get_ldst_builtin_name(frag), 193 "min_ptx" : min_ptx, 194 "min_sm" : min_sm, 195 "dst": "fdst" if is_fp else "dst", 196 "src": "fsrc" if is_fp else "src", 197 "blayout" : 0 if layout == "row" else 1, 198 "intrinsic" : Template(intrinsic_template).substitute({ 199 "frag" : frag.frag, 200 "geom" : frag.geom, 201 "ilayout" : layout, 202 "itype" : frag.ptx_type, 203 "op" : "store" if frag.frag == "d" else "load", 204 }) 205 } 206 results[(min_ptx,min_sm)] += Template(load_template).substitute(params) 207 208 return results 209 210def mma_signature(op): 211 if op.a.ptx_type in ["s8", "u8", "s4", "u4", "b1"]: 212 # int and sub-int ops are identified by input type. 213 return op.a.ptx_type 214 else: 215 # the rest are FP ops identified by accumulator & result type. 216 return "%s.%s" % (op.d.ptx_type, op.c.ptx_type) 217 218# Get numeric value for rowcol parameter of the builtin 219# AFAICT it uses the encoding accepted by NVVM intrinsics: 220# https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#nvvm-intrin-warp-level-matrix-mma 221def get_ilayout(a, b): 222 return { 223 "row.row" : 0, 224 "row.col" : 1, 225 "col.row" : 2, 226 "col.col" : 3 227 }[a + "." + b] 228 229def gen_wmma_mma_tests(results): 230 mma_template = """ 231 // CHECK${check_suffix}: call {{.*}} @${intrinsic} 232 // expected-error-re@+1 {{'${builtin}' needs target feature sm_${min_sm}{{.*}},ptx${min_ptx}{{.*}}}} 233 ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_isatf}); 234""".rstrip() 235 intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}" 236 237 for op, alayout, blayout, satf in sorted(product( get_mma_ops(), 238 ["row","col"], 239 ["row","col"], 240 [".satfinite", ""]), 241 key=str): 242 243 if not is_mma_variant_supported(op, alayout, blayout, satf): 244 continue 245 246 a_is_fp = op.a.ptx_type == "f32" 247 c_is_fp = op.c.ptx_type == "f32" 248 d_is_fp = op.d.ptx_type == "f32" 249 min_sm = get_required_sm(op.a) 250 min_ptx = get_required_ptx(op.a) 251 if op.a.ptx_type == "b1": # .b1 MMA has no satf argument. 252 isatf_arg = "" 253 else: 254 isatf_arg = ", 1" if satf else ", 0" 255 params = { 256 "check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm), 257 "builtin" : get_mma_builtin_name(op), 258 "min_ptx" : min_ptx, 259 "min_sm" : min_sm, 260 "dst": "fdst" if d_is_fp else "dst", 261 "asrc": "fsrc" if a_is_fp else "src", 262 "csrc": "fsrc" if c_is_fp else "src", 263 "ilayout" : get_ilayout(alayout, blayout), 264 "maybe_isatf" : isatf_arg, 265 "intrinsic" : Template(intrinsic_template).substitute({ 266 "geom" : op.a.geom, 267 "alayout" : alayout, 268 "blayout" : blayout, 269 "intrinsic_signature" : mma_signature(op), 270 "satf" : satf, 271 }) 272 } 273 results[(min_ptx, min_sm)] += Template(mma_template).substitute(params) 274 275 return results 276 277def gen_tests(): 278 results = gen_wmma_ldst_tests(defaultdict(str)) 279 results = gen_wmma_mma_tests(results) 280 281 run_template = r""" 282// 283// *** DO NOT EDIT *** 284// 285// This test has been automatically generated by 286// builtins-nvtx-mma.py --ptx=${ptx} --gpu-arch=${sm} 287// 288// Make sure we can handle all builtins available on sm_${sm} with PTX${ptx} 289// ${run}: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_${sm} \ 290// ${run}: -fcuda-is-device -target-feature +ptx${ptx} \ 291// ${run}: -DPTX=${ptx} -DSM=${sm} \ 292// ${run}: -S -emit-llvm -o - -x cuda %s \ 293// ${run}: | FileCheck -check-prefixes=${check_labels} %s 294// Verify that all builtins have correct constraints. 295// ${run}: %clang_cc1 -triple nvptx-unknown-unknown \ 296// ${run}: -target-cpu sm_60 -target-feature +ptx42 \ 297// ${run}: -DPTX=${ptx} -DSM=${sm} -fcuda-is-device -S -o /dev/null -x cuda \ 298// ${run}: -verify %s 299""" 300 def supported_variants(ptx, sm, results): 301 return [(ptx_, sm_) for ptx_, sm_ in results if ptx_ <= ptx and sm_ <= sm] 302 303 print(Template(run_template).substitute({ 304 "run" : "RUN", # To avoid lit misinterpreting the template 305 "ptx" : ptx_version, 306 "sm" : gpu_arch, 307 "check_labels" : ",".join(["CHECK_PTX%d_SM%d" % (ptx_, sm_) 308 for ptx_, sm_ 309 in supported_variants(ptx_version, gpu_arch, 310 results)]) 311 })) 312 313 print(""" 314#if !defined(CUDA_VERSION) 315#define __device__ __attribute__((device)) 316#define __global__ __attribute__((global)) 317#define __shared__ __attribute__((shared)) 318#define __constant__ __attribute__((constant)) 319 320typedef unsigned long long uint64_t; 321#endif 322 323// CHECK-LABEL: test_wmma_buitins 324__device__ void test_wmma_buitins(int *src, int *dst, 325 float *fsrc, float *fdst, int ldm) { 326"""); 327 328 for (ptx, sm), tests in sorted(results.items()): 329 print() 330 print("#if (PTX >= %d) && (SM >= %d)" % (ptx, sm)) 331 print(tests) 332 print("#endif // (PTX >= %d) && (SM >= %d) "% (ptx, sm)) 333 334 print("}") 335 336parser = argparse.ArgumentParser() 337parser.add_argument("--ptx", type=int, default=60) 338parser.add_argument("--gpu-arch", type=int, default=70) 339args = parser.parse_args() 340ptx_version = args.ptx 341gpu_arch = args.gpu_arch 342 343gen_tests() 344