// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15 #include "src/transform/num_workgroups_from_uniform.h"
16
17 #include <utility>
18
19 #include "src/transform/canonicalize_entry_point_io.h"
20 #include "src/transform/test_helper.h"
21 #include "src/transform/unshadow.h"
22
23 namespace tint {
24 namespace transform {
25 namespace {
26
27 using NumWorkgroupsFromUniformTest = TransformTest;
28
// Verifies the diagnostic produced when the transform is run without a
// NumWorkgroupsFromUniform::Config in the DataMap. Only the
// CanonicalizeEntryPointIO config is supplied, so the expected output is the
// "missing transform data" error rather than a transformed program.
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
  auto* src = "";

  auto* expect =
      "error: missing transform data for "
      "tint::transform::NumWorkgroupsFromUniform";

  // Deliberately omit NumWorkgroupsFromUniform::Config.
  DataMap data;
  data.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
  auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
      src, data);

  EXPECT_EQ(expect, str(got));
}
44
// Verifies the diagnostic produced when NumWorkgroupsFromUniform is run on its
// own, without CanonicalizeEntryPointIO in the transform list: the expected
// output is the "dependency was not run" error.
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingCanonicalizeEntryPointIO) {
  auto* src = "";

  auto* expect =
      "error: tint::transform::NumWorkgroupsFromUniform depends on "
      "tint::transform::CanonicalizeEntryPointIO but the dependency was not "
      "run";

  auto got = Run<NumWorkgroupsFromUniform>(src);

  EXPECT_EQ(expect, str(got));
}
57
// Verifies the simplest case: `num_workgroups` is declared directly as an
// entry point parameter. The expected output drops the parameter from the
// entry point and instead passes a member of a generated uniform buffer,
// bound at the configured binding point (group 0, binding 30), to the inner
// function.
TEST_F(NumWorkgroupsFromUniformTest, Basic) {
  auto* src = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
  let groups_x = num_wgs.x;
  let groups_y = num_wgs.y;
  let groups_z = num_wgs.z;
}
)";

  auto* expect = R"(
[[block]]
struct tint_symbol_2 {
  num_workgroups : vec3<u32>;
};

[[group(0), binding(30)]] var<uniform> tint_symbol_3 : tint_symbol_2;

fn main_inner(num_wgs : vec3<u32>) {
  let groups_x = num_wgs.x;
  let groups_y = num_wgs.y;
  let groups_z = num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  main_inner(tint_symbol_3.num_workgroups);
}
)";

  DataMap data;
  data.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
  // Binding point at which the generated uniform buffer is expected.
  data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
  auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
      src, data);
  EXPECT_EQ(expect, str(got));
}
96
// Verifies the case where `num_workgroups` is the only member of a struct
// passed to the entry point. The expected output leaves the entry point with
// no parameters and constructs the struct from the generated uniform buffer
// member before calling the inner function.
TEST_F(NumWorkgroupsFromUniformTest, StructOnlyMember) {
  auto* src = R"(
struct Builtins {
  [[builtin(num_workgroups)]] num_wgs : vec3<u32>;
};

[[stage(compute), workgroup_size(1)]]
fn main(in : Builtins) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}
)";

  auto* expect = R"(
[[block]]
struct tint_symbol_2 {
  num_workgroups : vec3<u32>;
};

[[group(0), binding(30)]] var<uniform> tint_symbol_3 : tint_symbol_2;

struct Builtins {
  num_wgs : vec3<u32>;
};

fn main_inner(in : Builtins) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  main_inner(Builtins(tint_symbol_3.num_workgroups));
}
)";

  DataMap data;
  data.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
  // Binding point at which the generated uniform buffer is expected.
  data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
  auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
      src, data);
  EXPECT_EQ(expect, str(got));
}
143
// Verifies the case where `num_workgroups` sits between other builtins in the
// entry point's input struct. The expected output keeps the other builtins
// (`global_invocation_id`, `workgroup_id`) in the generated entry point
// parameter struct, while the `num_workgroups` value is taken from the
// generated uniform buffer when the original struct is reconstructed.
TEST_F(NumWorkgroupsFromUniformTest, StructMultipleMembers) {
  auto* src = R"(
struct Builtins {
  [[builtin(global_invocation_id)]] gid : vec3<u32>;
  [[builtin(num_workgroups)]] num_wgs : vec3<u32>;
  [[builtin(workgroup_id)]] wgid : vec3<u32>;
};

[[stage(compute), workgroup_size(1)]]
fn main(in : Builtins) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}
)";

  auto* expect = R"(
[[block]]
struct tint_symbol_2 {
  num_workgroups : vec3<u32>;
};

[[group(0), binding(30)]] var<uniform> tint_symbol_3 : tint_symbol_2;

struct Builtins {
  gid : vec3<u32>;
  num_wgs : vec3<u32>;
  wgid : vec3<u32>;
};

struct tint_symbol_1 {
  [[builtin(global_invocation_id)]]
  gid : vec3<u32>;
  [[builtin(workgroup_id)]]
  wgid : vec3<u32>;
};

fn main_inner(in : Builtins) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main(tint_symbol : tint_symbol_1) {
  main_inner(Builtins(tint_symbol.gid, tint_symbol_3.num_workgroups, tint_symbol.wgid));
}
)";

  DataMap data;
  data.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
  // Binding point at which the generated uniform buffer is expected.
  data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
  auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
      src, data);
  EXPECT_EQ(expect, str(got));
}
201
// Verifies a module with three compute entry points that each consume
// `num_workgroups` in a different way (sole struct member, one of several
// struct members, and a plain parameter). The expected output declares a
// single uniform buffer, and all three entry points read from it.
TEST_F(NumWorkgroupsFromUniformTest, MultipleEntryPoints) {
  auto* src = R"(
struct Builtins1 {
  [[builtin(num_workgroups)]] num_wgs : vec3<u32>;
};

struct Builtins2 {
  [[builtin(global_invocation_id)]] gid : vec3<u32>;
  [[builtin(num_workgroups)]] num_wgs : vec3<u32>;
  [[builtin(workgroup_id)]] wgid : vec3<u32>;
};

[[stage(compute), workgroup_size(1)]]
fn main1(in : Builtins1) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main2(in : Builtins2) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main3([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
  let groups_x = num_wgs.x;
  let groups_y = num_wgs.y;
  let groups_z = num_wgs.z;
}
)";

  auto* expect = R"(
[[block]]
struct tint_symbol_6 {
  num_workgroups : vec3<u32>;
};

[[group(0), binding(30)]] var<uniform> tint_symbol_7 : tint_symbol_6;

struct Builtins1 {
  num_wgs : vec3<u32>;
};

struct Builtins2 {
  gid : vec3<u32>;
  num_wgs : vec3<u32>;
  wgid : vec3<u32>;
};

fn main1_inner(in : Builtins1) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main1() {
  main1_inner(Builtins1(tint_symbol_7.num_workgroups));
}

struct tint_symbol_3 {
  [[builtin(global_invocation_id)]]
  gid : vec3<u32>;
  [[builtin(workgroup_id)]]
  wgid : vec3<u32>;
};

fn main2_inner(in : Builtins2) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main2(tint_symbol_2 : tint_symbol_3) {
  main2_inner(Builtins2(tint_symbol_2.gid, tint_symbol_7.num_workgroups, tint_symbol_2.wgid));
}

fn main3_inner(num_wgs : vec3<u32>) {
  let groups_x = num_wgs.x;
  let groups_y = num_wgs.y;
  let groups_z = num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main3() {
  main3_inner(tint_symbol_7.num_workgroups);
}
)";

  DataMap data;
  data.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
  // Binding point at which the generated uniform buffer is expected.
  data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
  auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
      src, data);
  EXPECT_EQ(expect, str(got));
}
303
// Verifies that nothing is generated when the module never declares a
// `num_workgroups` builtin: the expected output contains only the entry point
// IO canonicalization results and no uniform buffer declaration, even though
// a NumWorkgroupsFromUniform::Config is supplied.
TEST_F(NumWorkgroupsFromUniformTest, NoUsages) {
  auto* src = R"(
struct Builtins {
  [[builtin(global_invocation_id)]] gid : vec3<u32>;
  [[builtin(workgroup_id)]] wgid : vec3<u32>;
};

[[stage(compute), workgroup_size(1)]]
fn main(in : Builtins) {
}
)";

  auto* expect = R"(
struct Builtins {
  gid : vec3<u32>;
  wgid : vec3<u32>;
};

struct tint_symbol_1 {
  [[builtin(global_invocation_id)]]
  gid : vec3<u32>;
  [[builtin(workgroup_id)]]
  wgid : vec3<u32>;
};

fn main_inner(in : Builtins) {
}

[[stage(compute), workgroup_size(1)]]
fn main(tint_symbol : tint_symbol_1) {
  main_inner(Builtins(tint_symbol.gid, tint_symbol.wgid));
}
)";

  DataMap data;
  data.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
  data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
  auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
      src, data);
  EXPECT_EQ(expect, str(got));
}
346
347 } // namespace
348 } // namespace transform
349 } // namespace tint
350