• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.
3
# RUN: python %s > %t.ll
# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll
6
7from itertools import product
8from string import Template
9
def make_wmma_slice_ty(abcd, itype):
  """Return the list of LLVM IR element types that make up one wmma fragment.

  abcd:  fragment name, one of "a", "b", "c" or "d".
  itype: element type name, "f16" or "f32".

  f16 fragments are built from "<2 x half>" elements, anything else from
  "float" elements. f16 "c"/"d" fragments have 4 elements; every other
  combination has 8.
  """
  elt_ty = "<2 x half>" if itype == "f16" else "float"
  # Exact membership test: the substring form `abcd in "cd"` would also
  # accept "" and "cd" as if they were c/d fragments.
  num_elts = 4 if abcd in ("c", "d") and itype == "f16" else 8
  return [elt_ty] * num_elts
14
def make_wmma_ld_ret_ty(abc, itype):
  """Return the IR literal struct type "{T, T, ...}" holding one fragment.

  The member types come from make_wmma_slice_ty(abc, itype).
  """
  members = make_wmma_slice_ty(abc, itype)
  return "{" + ", ".join(members) + "}"
17
18# returns address space
19def get_aspace(space):
20  space_map = {
21      ".global" : 1,
22      ".shared" : 3,
23      ".const"  : 4,
24      ".local"  : 5,
25      ".param"  : 101,
26      ""        : 0,
27      ".generic": 0
28  }
29  return space_map[space];
30
def get_pspace(space):
  """Return the pointer-type suffix (e.g. "p3i8") used in intrinsic names."""
  return "p{}i8".format(get_aspace(space))
33
34# Convenient test patterns.
35check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
36check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
37check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
38
39known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]
40
def gen_wmma_load_tests():
  """Print IR and FileCheck patterns for every wmma.load variant to stdout.

  Iterates over geometry x fragment (a/b/c) x layout x state space x
  stride/no-stride x element type, skipping f32 for the a/b fragments.
  For each combination it prints an intrinsic declaration and two test
  functions: one loading from %src and one from %src+128.
  """
  load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
"""
  intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
  instruction_template = "wmma.load.${abc}.sync.${layout}.${geom}${space}.${itype}"

  variants = product(known_geoms, "abc", ["row", "col"],
                     ["", ".shared", ".global"], ["", ".stride"],
                     ["f16", "f32"])
  for geom, frag, layout, space, stride, itype in variants:
    # Only the "c" fragment is generated with f32 elements.
    if itype == "f32" and frag != "c":
      continue

    subst = {
        "abc": frag,
        "layout": layout,
        "space": space,
        "stride": stride,
        "itype": itype,
        "pspace": get_pspace(space),
        "as": "addrspace(%d)" % get_aspace(space),
        "geom": geom,
    }
    intrinsic = Template(intrinsic_template).substitute(subst)
    subst["intrinsic"] = intrinsic
    subst["function"] = intrinsic.replace(".", "_")
    subst["instruction"] = Template(instruction_template).substitute(subst)
    subst["ret_ty"] = make_wmma_ld_ret_ty(frag, itype)

    if frag != "c":
      subst["check_result"] = check_f16_8
    elif itype == "f16":
      subst["check_result"] = check_f16_4
    else:
      subst["check_result"] = check_f32_8

    # The strided variants take the stride as an extra i32 argument, which
    # shows up as a trailing register operand in the PTX.
    subst["extra_args"] = ", i32 %stride" if stride else ""
    subst["stride_pattern"] = ", %r{{[0-9]+}}" if stride else ""

    print(Template(load_template).substitute(subst))
107
def make_wmma_slice_args(itype, abcd, prefix="v"):
  """Return an IR argument list "T %<prefix>0, T %<prefix>1, ..." for a fragment.

  The element types come from make_wmma_slice_ty(abcd, itype), and the
  arguments are numbered from 0.
  """
  pieces = []
  for idx, elt_ty in enumerate(make_wmma_slice_ty(abcd, itype)):
    pieces.append("{} %{}{}".format(elt_ty, prefix, idx))
  return ", ".join(pieces)
111
def gen_wmma_store_tests():
  """Print IR and FileCheck patterns for every wmma.store variant to stdout.

  Only the "d" fragment is stored. For each geometry x layout x state space
  x stride/no-stride x element type combination, this prints an intrinsic
  declaration and two test functions: one storing to %src and one to
  %src+128.
  """
  # Address operands are matched with {{[0-9]+}}. A character class like
  # {{[0-9+]}} matches only a single character, so "[%rd12+128]" would fail.
  store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}+128]
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
  intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
  instruction_template = "wmma.store.${abc}.sync.${layout}.${geom}${space}.${itype}"

  for geom, abc, layout, space, stride, itype in product(
      known_geoms,
      "d",
      ["row", "col"],
      ["", ".shared", ".global"],
      ["", ".stride"],
      ["f16", "f32"]):

    test_params = {
        "abc": abc,
        "layout": layout,
        "space": space,
        "stride": stride,
        "itype": itype,
        "pspace": get_pspace(space),
        "as": "addrspace(%d)" % get_aspace(space),
        "geom": geom,
    }
    test_params["intrinsic"] = Template(intrinsic_template).substitute(test_params)
    test_params["function"] = test_params["intrinsic"].replace(".", "_")
    test_params["instruction"] = Template(instruction_template).substitute(test_params)
    test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
    if stride:
      test_params["extra_args"] = ", i32 %stride"
      test_params["stride_pattern"] = ", %r{{[0-9]+}};"
    else:
      # With no stride operand, only the terminating ";" is checked.
      test_params["extra_args"] = ""
      test_params["stride_pattern"] = ";"
    test_params["args"] = make_wmma_slice_args(itype, "d")

    print(Template(store_template).substitute(test_params))
172
def gen_wmma_mma_tests():
  """Print IR and FileCheck patterns for every wmma.mma variant to stdout.

  Iterates over geometry x a-layout x b-layout x c type x d type x
  satfinite/no-satfinite. Each variant takes the a and b fragments as f16
  and the c fragment as `ctype`, and returns a d fragment of `dtype`.
  """
  mma_template = """
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_ab}
; CHECK-NEXT: ${check_ab}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
"""
  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
  instruction_template = "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"

  layouts = ["row", "col"]
  elt_types = ["f16", "f32"]
  # Register-list pattern for a c/d fragment of the given element type.
  cd_pattern = {"f16": check_f16_4, "f32": check_f32_8}

  for geom, alayout, blayout, ctype, dtype, satf in product(
      known_geoms, layouts, layouts, elt_types, elt_types,
      [".satfinite", ""]):
    subst = {
        "alayout": alayout,
        "blayout": blayout,
        "ctype": ctype,
        "dtype": dtype,
        "satf": satf,
        "geom": geom,
    }
    intrinsic = Template(intrinsic_template).substitute(subst)
    subst["intrinsic"] = intrinsic
    subst["function"] = intrinsic.replace(".", "_")
    subst["instruction"] = Template(instruction_template).substitute(subst)
    subst["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
    subst["check_ab"] = check_f16_8
    subst["check_c"] = cd_pattern[ctype]
    subst["check_d"] = cd_pattern[dtype]

    fragments = (("a", "f16"), ("b", "f16"), ("c", ctype))
    arg_lists = [make_wmma_slice_args(elt, frag, prefix=frag)
                 for frag, elt in fragments]
    subst["args"] = ",\n        ".join(arg_lists)

    print(Template(mma_template).substitute(subst))
225
def main():
  """Print all generated wmma tests (loads, stores, then mma) to stdout."""
  gen_wmma_load_tests()
  gen_wmma_store_tests()
  gen_wmma_mma_tests()


# lit runs this file as a script (`python %s`), so the guard keeps that path
# unchanged while making the module importable without side effects.
if __name__ == "__main__":
  main()
232