• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/transform/num_workgroups_from_uniform.h"
16 
17 #include <utility>
18 
19 #include "src/transform/canonicalize_entry_point_io.h"
20 #include "src/transform/test_helper.h"
21 #include "src/transform/unshadow.h"
22 
23 namespace tint {
24 namespace transform {
25 namespace {
26 
27 using NumWorkgroupsFromUniformTest = TransformTest;
28 
// NumWorkgroupsFromUniform must report an error when its own
// NumWorkgroupsFromUniform::Config is absent from the DataMap. Only the
// CanonicalizeEntryPointIO::Config is supplied here, so the dependency is
// satisfied but the transform's own data is missing.
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
  auto* src = "";

  auto* expect =
      "error: missing transform data for "
      "tint::transform::NumWorkgroupsFromUniform";

  DataMap data;
  data.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
  // NumWorkgroupsFromUniform::Config is intentionally not added.
  auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
      src, data);

  EXPECT_EQ(expect, str(got));
}
44 
// NumWorkgroupsFromUniform depends on CanonicalizeEntryPointIO having run
// first; running it on its own must produce a dependency error.
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingCanonicalizeEntryPointIO) {
  auto* src = "";

  auto* expect =
      "error: tint::transform::NumWorkgroupsFromUniform depends on "
      "tint::transform::CanonicalizeEntryPointIO but the dependency was not "
      "run";

  // Only NumWorkgroupsFromUniform is run; no DataMap is needed to trigger the
  // dependency check.
  auto got = Run<NumWorkgroupsFromUniform>(src);

  EXPECT_EQ(expect, str(got));
}
57 
// A num_workgroups builtin declared directly as an entry-point parameter is
// replaced by a read from a [[block]] uniform buffer. That buffer is declared
// at the binding point passed in NumWorkgroupsFromUniform::Config
// (group 0, binding 30). The resulting entry point takes no parameters.
TEST_F(NumWorkgroupsFromUniformTest, Basic) {
  auto* src = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
  let groups_x = num_wgs.x;
  let groups_y = num_wgs.y;
  let groups_z = num_wgs.z;
}
)";

  auto* expect = R"(
[[block]]
struct tint_symbol_2 {
  num_workgroups : vec3<u32>;
};

[[group(0), binding(30)]] var<uniform> tint_symbol_3 : tint_symbol_2;

fn main_inner(num_wgs : vec3<u32>) {
  let groups_x = num_wgs.x;
  let groups_y = num_wgs.y;
  let groups_z = num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  main_inner(tint_symbol_3.num_workgroups);
}
)";

  DataMap data;
  data.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
  // Binding point for the generated uniform buffer.
  data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
  auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
      src, data);
  EXPECT_EQ(expect, str(got));
}
96 
// When num_workgroups is the only member of an entry-point input struct, the
// struct keeps the member but loses its builtin attribute. The generated entry
// point takes no parameters and builds the struct from the uniform buffer.
TEST_F(NumWorkgroupsFromUniformTest, StructOnlyMember) {
  auto* src = R"(
struct Builtins {
  [[builtin(num_workgroups)]] num_wgs : vec3<u32>;
};

[[stage(compute), workgroup_size(1)]]
fn main(in : Builtins) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}
)";

  auto* expect = R"(
[[block]]
struct tint_symbol_2 {
  num_workgroups : vec3<u32>;
};

[[group(0), binding(30)]] var<uniform> tint_symbol_3 : tint_symbol_2;

struct Builtins {
  num_wgs : vec3<u32>;
};

fn main_inner(in : Builtins) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  main_inner(Builtins(tint_symbol_3.num_workgroups));
}
)";

  DataMap data;
  data.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
  // Binding point for the generated uniform buffer.
  data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
  auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
      src, data);
  EXPECT_EQ(expect, str(got));
}
143 
// When num_workgroups shares an input struct with other builtins, only
// num_workgroups is sourced from the uniform buffer. The remaining builtins
// (global_invocation_id, workgroup_id) stay shader inputs in the generated
// tint_symbol_1 struct. The original struct is rebuilt in its original
// member order.
TEST_F(NumWorkgroupsFromUniformTest, StructMultipleMembers) {
  auto* src = R"(
struct Builtins {
  [[builtin(global_invocation_id)]] gid : vec3<u32>;
  [[builtin(num_workgroups)]] num_wgs : vec3<u32>;
  [[builtin(workgroup_id)]] wgid : vec3<u32>;
};

[[stage(compute), workgroup_size(1)]]
fn main(in : Builtins) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}
)";

  auto* expect = R"(
[[block]]
struct tint_symbol_2 {
  num_workgroups : vec3<u32>;
};

[[group(0), binding(30)]] var<uniform> tint_symbol_3 : tint_symbol_2;

struct Builtins {
  gid : vec3<u32>;
  num_wgs : vec3<u32>;
  wgid : vec3<u32>;
};

struct tint_symbol_1 {
  [[builtin(global_invocation_id)]]
  gid : vec3<u32>;
  [[builtin(workgroup_id)]]
  wgid : vec3<u32>;
};

fn main_inner(in : Builtins) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main(tint_symbol : tint_symbol_1) {
  main_inner(Builtins(tint_symbol.gid, tint_symbol_3.num_workgroups, tint_symbol.wgid));
}
)";

  DataMap data;
  data.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
  // Binding point for the generated uniform buffer.
  data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
  auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
      src, data);
  EXPECT_EQ(expect, str(got));
}
201 
// Several entry points use num_workgroups in different ways:
//   - main1: as the only member of an input struct.
//   - main2: as one member among other builtins in an input struct.
//   - main3: as a direct parameter.
// All of them must read from one shared uniform buffer declaration
// (tint_symbol_7) rather than each getting its own.
TEST_F(NumWorkgroupsFromUniformTest, MultipleEntryPoints) {
  auto* src = R"(
struct Builtins1 {
  [[builtin(num_workgroups)]] num_wgs : vec3<u32>;
};

struct Builtins2 {
  [[builtin(global_invocation_id)]] gid : vec3<u32>;
  [[builtin(num_workgroups)]] num_wgs : vec3<u32>;
  [[builtin(workgroup_id)]] wgid : vec3<u32>;
};

[[stage(compute), workgroup_size(1)]]
fn main1(in : Builtins1) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main2(in : Builtins2) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main3([[builtin(num_workgroups)]] num_wgs : vec3<u32>) {
  let groups_x = num_wgs.x;
  let groups_y = num_wgs.y;
  let groups_z = num_wgs.z;
}
)";

  auto* expect = R"(
[[block]]
struct tint_symbol_6 {
  num_workgroups : vec3<u32>;
};

[[group(0), binding(30)]] var<uniform> tint_symbol_7 : tint_symbol_6;

struct Builtins1 {
  num_wgs : vec3<u32>;
};

struct Builtins2 {
  gid : vec3<u32>;
  num_wgs : vec3<u32>;
  wgid : vec3<u32>;
};

fn main1_inner(in : Builtins1) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main1() {
  main1_inner(Builtins1(tint_symbol_7.num_workgroups));
}

struct tint_symbol_3 {
  [[builtin(global_invocation_id)]]
  gid : vec3<u32>;
  [[builtin(workgroup_id)]]
  wgid : vec3<u32>;
};

fn main2_inner(in : Builtins2) {
  let groups_x = in.num_wgs.x;
  let groups_y = in.num_wgs.y;
  let groups_z = in.num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main2(tint_symbol_2 : tint_symbol_3) {
  main2_inner(Builtins2(tint_symbol_2.gid, tint_symbol_7.num_workgroups, tint_symbol_2.wgid));
}

fn main3_inner(num_wgs : vec3<u32>) {
  let groups_x = num_wgs.x;
  let groups_y = num_wgs.y;
  let groups_z = num_wgs.z;
}

[[stage(compute), workgroup_size(1)]]
fn main3() {
  main3_inner(tint_symbol_7.num_workgroups);
}
)";

  DataMap data;
  data.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
  // Binding point for the generated uniform buffer.
  data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
  auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
      src, data);
  EXPECT_EQ(expect, str(got));
}
303 
// If no entry point uses the num_workgroups builtin, no uniform buffer is
// emitted, even though a binding point is configured. The output matches
// what CanonicalizeEntryPointIO alone would produce.
TEST_F(NumWorkgroupsFromUniformTest, NoUsages) {
  auto* src = R"(
struct Builtins {
  [[builtin(global_invocation_id)]] gid : vec3<u32>;
  [[builtin(workgroup_id)]] wgid : vec3<u32>;
};

[[stage(compute), workgroup_size(1)]]
fn main(in : Builtins) {
}
)";

  auto* expect = R"(
struct Builtins {
  gid : vec3<u32>;
  wgid : vec3<u32>;
};

struct tint_symbol_1 {
  [[builtin(global_invocation_id)]]
  gid : vec3<u32>;
  [[builtin(workgroup_id)]]
  wgid : vec3<u32>;
};

fn main_inner(in : Builtins) {
}

[[stage(compute), workgroup_size(1)]]
fn main(tint_symbol : tint_symbol_1) {
  main_inner(Builtins(tint_symbol.gid, tint_symbol.wgid));
}
)";

  DataMap data;
  data.Add<CanonicalizeEntryPointIO::Config>(
      CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
  // A binding point is configured but must go unused.
  data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
  auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(
      src, data);
  EXPECT_EQ(expect, str(got));
}
346 
347 }  // namespace
348 }  // namespace transform
349 }  // namespace tint
350