# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.

# RUN: python %s > %t.ll
# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll

from itertools import product
from string import Template


def make_wmma_slice_ty(abcd, itype):
    """Return the list of LLVM element types that make up one wmma fragment.

    f16 fragments are built from "<2 x half>" elements, f32 fragments from
    "float" elements. Fragments "c" and "d" hold 4 elements when f16; every
    other combination holds 8.
    """
    elt_ty = "<2 x half>" if itype == "f16" else "float"
    # Explicit membership test; `abcd in "cd"` is a substring check and would
    # also accept "" or "cd".
    num_elts = 4 if abcd in ("c", "d") and itype == "f16" else 8
    return [elt_ty] * num_elts


def make_wmma_ld_ret_ty(abc, itype):
    """Return the LLVM struct type, e.g. "{float, ...}", for fragment `abc`."""
    return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))


# PTX state-space suffix (as spelled in the wmma instruction name) ->
# LLVM address space number used in the intrinsic's pointer argument.
_ADDRESS_SPACES = {
    ".global": 1,
    ".shared": 3,
    ".const": 4,
    ".local": 5,
    ".param": 101,
    "": 0,
    ".generic": 0,
}


def get_aspace(space):
    """Return the address space number for state-space suffix `space`.

    Raises KeyError for a suffix not listed in _ADDRESS_SPACES.
    """
    return _ADDRESS_SPACES[space]


def get_pspace(space):
    """Return the intrinsic name suffix for an i8 pointer, e.g. "p3i8"."""
    return "p%di8" % get_aspace(space)


# Convenient test patterns (FileCheck {{regex}} blocks matching register lists).
check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)

known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]


def _make_ldst_params(geom, abc, layout, space, stride, itype):
    """Return the template substitutions shared by load and store tests."""
    return {
        "abc": abc,
        "layout": layout,
        "space": space,
        "stride": stride,
        "itype": itype,
        "pspace": get_pspace(space),
        "as": "addrspace(%d)" % get_aspace(space),
        "geom": geom,
    }


def gen_wmma_load_tests():
    """Print a test for every llvm.nvvm.wmma.*.load.* intrinsic variant.

    Each variant gets a declaration plus two functions: one loading from the
    incoming pointer and one ("_o") loading at a +128 byte offset.
    """
    load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
"""
    intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
    instruction_template = "wmma.load.${abc}.sync.${layout}.${geom}${space}.${itype}"

    for geom, abc, layout, space, stride, itype in product(
            known_geoms,
            "abc",
            ["row", "col"],
            ["", ".shared", ".global"],
            ["", ".stride"],
            ["f16", "f32"]):

        # f32 variants are only generated for fragment "c".
        if itype == "f32" and abc != "c":
            continue

        params = _make_ldst_params(geom, abc, layout, space, stride, itype)
        # Copy so `params` stays the pure set of naming inputs.
        test_params = dict(params)
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
        if abc == "c":
            test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8
        else:
            test_params["check_result"] = check_f16_8

        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}}"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ""

        print(Template(load_template).substitute(test_params))


def make_wmma_slice_args(itype, abcd, prefix="v"):
    """Return the typed argument list "T %<prefix>0, T %<prefix>1, ..."."""
    return ", ".join(["%s %%%s%d" % (t, prefix, i) for i, t
                      in enumerate(make_wmma_slice_ty(abcd, itype))])


def gen_wmma_store_tests():
    """Print a test for every llvm.nvvm.wmma.*.store.d.* intrinsic variant.

    Each variant gets a declaration plus two functions: one storing to the
    incoming pointer and one ("_o") storing at a +128 byte offset.
    """
    # Register numbers are matched with {{[0-9]+}} (one or more digits), and
    # both functions check the braced value list the same way.
    store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}+128]
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
    intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
    instruction_template = "wmma.store.${abc}.sync.${layout}.${geom}${space}.${itype}"

    for geom, abc, layout, space, stride, itype in product(
            known_geoms,
            "d",
            ["row", "col"],
            ["", ".shared", ".global"],
            ["", ".stride"],
            ["f16", "f32"]):

        params = _make_ldst_params(geom, abc, layout, space, stride, itype)
        test_params = dict(params)
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
        if stride:
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}};"
        else:
            test_params["extra_args"] = ""
            # Still anchor on the terminating ';' of the store instruction.
            test_params["stride_pattern"] = ";"
        test_params["args"] = make_wmma_slice_args(itype, "d")

        print(Template(store_template).substitute(test_params))


def gen_wmma_mma_tests():
    """Print a test for every llvm.nvvm.wmma.*.mma.* intrinsic variant.

    Covers all combinations of geometry, a/b layouts, c/d element types and
    optional .satfinite; the a and b operands are always f16.
    """
    mma_template = """
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_ab}
; CHECK-NEXT: ${check_ab}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
"""
    intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
    instruction_template = "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"

    for geom, alayout, blayout, ctype, dtype, satf in product(
            known_geoms,
            ["row", "col"],
            ["row", "col"],
            ["f16", "f32"],
            ["f16", "f32"],
            [".satfinite", ""]):

        params = {
            "alayout": alayout,
            "blayout": blayout,
            "ctype": ctype,
            "dtype": dtype,
            "satf": satf,
            "geom": geom,
        }

        test_params = dict(params)
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
        test_params["check_ab"] = check_f16_8
        test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8
        test_params["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8
        # Arguments are the a, b and c fragments, named %a*, %b* and %c*.
        test_params["args"] = ",\n        ".join(
            make_wmma_slice_args(t, abcd, prefix=abcd)
            for abcd, t in (("a", "f16"),
                            ("b", "f16"),
                            ("c", ctype)))
        print(Template(mma_template).substitute(test_params))


def main():
    """Print the complete generated test (.ll) to stdout."""
    gen_wmma_load_tests()
    gen_wmma_store_tests()
    gen_wmma_mma_tests()


# `python %s` (the RUN line) executes this as __main__; importing the module
# no longer dumps the whole test.
if __name__ == "__main__":
    main()