1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/transform/module_scope_var_to_entry_point_param.h"
16
17 #include <utility>
18
19 #include "src/transform/test_helper.h"
20
21 namespace tint {
22 namespace transform {
23 namespace {
24
// Test fixture for the ModuleScopeVarToEntryPointParam transform. It reuses the
// shared TransformTest helpers (Run<> to apply the transform, str() to print
// the resulting program as WGSL).
using ModuleScopeVarToEntryPointParamTest = TransformTest;
26
// A private and a workgroup module-scope variable referenced only by the entry
// point are redeclared as function-scope `var`s at the top of that entry point,
// each tagged with internal(disable_validation__ignore_storage_class). The
// original module-scope declarations are removed.
TEST_F(ModuleScopeVarToEntryPointParamTest, Basic) {
  const char* const input = R"(
var<private> p : f32;
var<workgroup> w : f32;

[[stage(compute), workgroup_size(1)]]
fn main() {
  w = p;
}
)";

  const char* const expected = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
  [[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol : f32;
  [[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_1 : f32;
  tint_symbol = tint_symbol_1;
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
51
// Helper functions that (transitively) touch private/workgroup variables gain
// trailing pointer parameters for them. Call sites forward those pointers, and
// the entry point declares the variables and passes their addresses. A function
// that uses none of them (no_uses) keeps its original signature.
TEST_F(ModuleScopeVarToEntryPointParamTest, FunctionCalls) {
  const char* const input = R"(
var<private> p : f32;
var<workgroup> w : f32;

fn no_uses() {
}

fn bar(a : f32, b : f32) {
  p = a;
  w = b;
}

fn foo(a : f32) {
  let b : f32 = 2.0;
  bar(a, b);
  no_uses();
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  foo(1.0);
}
)";

  const char* const expected = R"(
fn no_uses() {
}

fn bar(a : f32, b : f32, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_1 : ptr<workgroup, f32>) {
  *(tint_symbol) = a;
  *(tint_symbol_1) = b;
}

fn foo(a : f32, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_2 : ptr<private, f32>, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_3 : ptr<workgroup, f32>) {
  let b : f32 = 2.0;
  bar(a, b, tint_symbol_2, tint_symbol_3);
  no_uses();
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  [[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_4 : f32;
  [[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_5 : f32;
  foo(1.0, &(tint_symbol_4), &(tint_symbol_5));
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
104
// Initializers on module-scope private variables (a literal and a type
// constructor) are carried over onto the function-scope redeclarations.
TEST_F(ModuleScopeVarToEntryPointParamTest, Constructors) {
  const char* const input = R"(
var<private> a : f32 = 1.0;
var<private> b : f32 = f32();

[[stage(compute), workgroup_size(1)]]
fn main() {
  let x : f32 = a + b;
}
)";

  const char* const expected = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
  [[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol : f32 = 1.0;
  [[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_1 : f32 = f32();
  let x : f32 = (tint_symbol + tint_symbol_1);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
129
// Taking the address of a module-scope variable inside the entry point becomes
// taking the address of its function-scope replacement. Existing pointer
// dereferences are left pointing at the same lets.
TEST_F(ModuleScopeVarToEntryPointParamTest, Pointers) {
  const char* const input = R"(
var<private> p : f32;
var<workgroup> w : f32;

[[stage(compute), workgroup_size(1)]]
fn main() {
  let p_ptr : ptr<private, f32> = &p;
  let w_ptr : ptr<workgroup, f32> = &w;
  let x : f32 = *p_ptr + *w_ptr;
  *p_ptr = x;
}
)";

  const char* const expected = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
  [[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol : f32;
  [[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_1 : f32;
  let p_ptr : ptr<private, f32> = &(tint_symbol);
  let w_ptr : ptr<workgroup, f32> = &(tint_symbol_1);
  let x : f32 = (*(p_ptr) + *(w_ptr));
  *(p_ptr) = x;
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
160
// In `foo`, `&v` would become `&*param` once `v` is replaced by a pointer
// parameter. The expected output passes the pointer parameter straight through
// (`bar(tint_symbol)`), i.e. the address-of/deref pair is folded away.
TEST_F(ModuleScopeVarToEntryPointParamTest, FoldAddressOfDeref) {
  const char* const input = R"(
var<private> v : f32;

fn bar(p : ptr<private, f32>) {
  (*p) = 0.0;
}

fn foo() {
  bar(&v);
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  foo();
}
)";

  const char* const expected = R"(
fn bar(p : ptr<private, f32>) {
  *(p) = 0.0;
}

fn foo([[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>) {
  bar(tint_symbol);
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  [[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_1 : f32;
  foo(&(tint_symbol_1));
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
199
// Uniform and storage buffer variables become pointer parameters of the entry
// point. They keep their group/binding attributes and gain
// disable_validation__entry_point_parameter and
// disable_validation__ignore_storage_class. Reads become dereferences.
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_Basic) {
  const char* const input = R"(
[[block]]
struct S {
  a : f32;
};

[[group(0), binding(0)]]
var<uniform> u : S;
[[group(0), binding(1)]]
var<storage> s : S;

[[stage(compute), workgroup_size(1)]]
fn main() {
  _ = u;
  _ = s;
}
)";

  const char* const expected = R"(
[[block]]
struct S {
  a : f32;
};

[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol : ptr<uniform, S>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_1 : ptr<storage, S>) {
  _ = *(tint_symbol);
  _ = *(tint_symbol_1);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
236
// Buffer pointers received by the entry point are threaded through every helper
// that (transitively) reads a buffer. The helpers take them as pointer
// parameters tagged with the ignore_storage_class and
// ignore_invalid_pointer_argument internal attributes.
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_FunctionCalls) {
  const char* const input = R"(
[[block]]
struct S {
  a : f32;
};

[[group(0), binding(0)]]
var<uniform> u : S;
[[group(0), binding(1)]]
var<storage> s : S;

fn no_uses() {
}

fn bar(a : f32, b : f32) {
  _ = u;
  _ = s;
}

fn foo(a : f32) {
  let b : f32 = 2.0;
  _ = u;
  bar(a, b);
  no_uses();
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  foo(1.0);
}
)";

  const char* const expected = R"(
[[block]]
struct S {
  a : f32;
};

fn no_uses() {
}

fn bar(a : f32, b : f32, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<uniform, S>, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_1 : ptr<storage, S>) {
  _ = *(tint_symbol);
  _ = *(tint_symbol_1);
}

fn foo(a : f32, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_2 : ptr<uniform, S>, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_3 : ptr<storage, S>) {
  let b : f32 = 2.0;
  _ = *(tint_symbol_2);
  bar(a, b, tint_symbol_2, tint_symbol_3);
  no_uses();
}

[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_4 : ptr<uniform, S>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_5 : ptr<storage, S>) {
  foo(1.0, tint_symbol_4, tint_symbol_5);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
301
// Texture and sampler variables become by-value entry point parameters (not
// pointers). They keep group/binding and gain
// disable_validation__entry_point_parameter.
TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_Basic) {
  const char* const input = R"(
[[group(0), binding(0)]] var t : texture_2d<f32>;
[[group(0), binding(1)]] var s : sampler;

[[stage(compute), workgroup_size(1)]]
fn main() {
  _ = t;
  _ = s;
}
)";

  const char* const expected = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : sampler) {
  _ = tint_symbol;
  _ = tint_symbol_1;
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
326
// Handle-typed variables are forwarded to helpers as plain by-value parameters,
// which carry no internal attributes. Only the entry point's parameters keep
// group/binding.
TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_FunctionCalls) {
  const char* const input = R"(
[[group(0), binding(0)]] var t : texture_2d<f32>;
[[group(0), binding(1)]] var s : sampler;

fn no_uses() {
}

fn bar(a : f32, b : f32) {
  _ = t;
  _ = s;
}

fn foo(a : f32) {
  let b : f32 = 2.0;
  _ = t;
  bar(a, b);
  no_uses();
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  foo(1.0);
}
)";

  const char* const expected = R"(
fn no_uses() {
}

fn bar(a : f32, b : f32, tint_symbol : texture_2d<f32>, tint_symbol_1 : sampler) {
  _ = tint_symbol;
  _ = tint_symbol_1;
}

fn foo(a : f32, tint_symbol_2 : texture_2d<f32>, tint_symbol_3 : sampler) {
  let b : f32 = 2.0;
  _ = tint_symbol_2;
  bar(a, b, tint_symbol_2, tint_symbol_3);
  no_uses();
}

[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol_4 : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_5 : sampler) {
  foo(1.0, tint_symbol_4, tint_symbol_5);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
379
// A workgroup variable of matrix type is not redeclared locally. Instead it is
// wrapped as a member of a generated struct, which the entry point receives
// through a single workgroup pointer parameter. Uses go through a `let` pointer
// to that member.
TEST_F(ModuleScopeVarToEntryPointParamTest, Matrix) {
  const char* const input = R"(
var<workgroup> m : mat2x2<f32>;

[[stage(compute), workgroup_size(1)]]
fn main() {
  let x = m;
}
)";

  const char* const expected = R"(
struct tint_symbol_2 {
  m : mat2x2<f32>;
};

[[stage(compute), workgroup_size(1)]]
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
  let tint_symbol : ptr<workgroup, mat2x2<f32>> = &((*(tint_symbol_1)).m);
  let x = *(tint_symbol);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
406
// Same wrapping as the Matrix case, triggered here by a matrix nested inside an
// array of structs. The array size is printed as the unsigned literal `4u` in
// the output.
TEST_F(ModuleScopeVarToEntryPointParamTest, NestedMatrix) {
  const char* const input = R"(
struct S1 {
  m : mat2x2<f32>;
};
struct S2 {
  s : S1;
};
var<workgroup> m : array<S2, 4>;

[[stage(compute), workgroup_size(1)]]
fn main() {
  let x = m;
}
)";

  const char* const expected = R"(
struct S1 {
  m : mat2x2<f32>;
};

struct S2 {
  s : S1;
};

struct tint_symbol_2 {
  m : array<S2, 4u>;
};

[[stage(compute), workgroup_size(1)]]
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
  let tint_symbol : ptr<workgroup, array<S2, 4u>> = &((*(tint_symbol_1)).m);
  let x = *(tint_symbol);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
447
448 // Test that we do not duplicate a struct type used by multiple workgroup
449 // variables that are promoted to threadgroup memory arguments.
TEST_F(ModuleScopeVarToEntryPointParamTest, DuplicateThreadgroupArgumentTypes) {
  // Two workgroup variables share struct type S. The output declares S exactly
  // once and packs both variables into one generated wrapper struct.
  const char* const input = R"(
struct S {
  m : mat2x2<f32>;
};

var<workgroup> a : S;

var<workgroup> b : S;

[[stage(compute), workgroup_size(1)]]
fn main() {
  let x = a;
  let y = b;
}
)";

  const char* const expected = R"(
struct S {
  m : mat2x2<f32>;
};

struct tint_symbol_3 {
  a : S;
  b : S;
};

[[stage(compute), workgroup_size(1)]]
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_3>) {
  let tint_symbol : ptr<workgroup, S> = &((*(tint_symbol_1)).a);
  let tint_symbol_2 : ptr<workgroup, S> = &((*(tint_symbol_1)).b);
  let x = *(tint_symbol);
  let y = *(tint_symbol_2);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
490
// Module-scope variables of every kind (private, workgroup, uniform, storage,
// texture, sampler) that no entry point references are dropped entirely. The
// entry point is left unchanged and the struct type declaration is kept.
TEST_F(ModuleScopeVarToEntryPointParamTest, UnusedVariables) {
  const char* const input = R"(
[[block]]
struct S {
  a : f32;
};

var<private> p : f32;
var<workgroup> w : f32;

[[group(0), binding(0)]]
var<uniform> ub : S;
[[group(0), binding(1)]]
var<storage> sb : S;

[[group(0), binding(2)]] var t : texture_2d<f32>;
[[group(0), binding(3)]] var s : sampler;

[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";

  const char* const expected = R"(
[[block]]
struct S {
  a : f32;
};

[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
529
// An empty module passes through the transform unchanged (empty in, empty out).
// NOTE(review): the test name is misspelled ("Emtpy"); it is left as-is in case
// test filters or expectation files reference it.
TEST_F(ModuleScopeVarToEntryPointParamTest, EmtpyModule) {
  const char* const input = "";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(input, str(output));
}
537
538 } // namespace
539 } // namespace transform
540 } // namespace tint
541