• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/transform/module_scope_var_to_entry_point_param.h"
16 
17 #include <utility>
18 
19 #include "src/transform/test_helper.h"
20 
21 namespace tint {
22 namespace transform {
23 namespace {
24 
25 using ModuleScopeVarToEntryPointParamTest = TransformTest;
26 
TEST_F(ModuleScopeVarToEntryPointParamTest,Basic)27 TEST_F(ModuleScopeVarToEntryPointParamTest, Basic) {
28   auto* src = R"(
29 var<private> p : f32;
30 var<workgroup> w : f32;
31 
32 [[stage(compute), workgroup_size(1)]]
33 fn main() {
34   w = p;
35 }
36 )";
37 
38   auto* expect = R"(
39 [[stage(compute), workgroup_size(1)]]
40 fn main() {
41   [[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol : f32;
42   [[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_1 : f32;
43   tint_symbol = tint_symbol_1;
44 }
45 )";
46 
47   auto got = Run<ModuleScopeVarToEntryPointParam>(src);
48 
49   EXPECT_EQ(expect, str(got));
50 }
51 
// Helpers that touch module-scope variables gain trailing pointer parameters.
// The entry point declares the variables and passes their addresses down the
// call chain. A helper that uses no module-scope variable (no_uses) is left
// unchanged.
TEST_F(ModuleScopeVarToEntryPointParamTest, FunctionCalls) {
  const char* input = R"(
var<private> p : f32;
var<workgroup> w : f32;

fn no_uses() {
}

fn bar(a : f32, b : f32) {
  p = a;
  w = b;
}

fn foo(a : f32) {
  let b : f32 = 2.0;
  bar(a, b);
  no_uses();
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  foo(1.0);
}
)";

  const char* expected = R"(
fn no_uses() {
}

fn bar(a : f32, b : f32, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_1 : ptr<workgroup, f32>) {
  *(tint_symbol) = a;
  *(tint_symbol_1) = b;
}

fn foo(a : f32, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_2 : ptr<private, f32>, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_3 : ptr<workgroup, f32>) {
  let b : f32 = 2.0;
  bar(a, b, tint_symbol_2, tint_symbol_3);
  no_uses();
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  [[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_4 : f32;
  [[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_5 : f32;
  foo(1.0, &(tint_symbol_4), &(tint_symbol_5));
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
104 
// Initializers on module-scope variables (a literal and a type constructor)
// are carried over to the redeclared function-scope variables.
TEST_F(ModuleScopeVarToEntryPointParamTest, Constructors) {
  const char* input = R"(
var<private> a : f32 = 1.0;
var<private> b : f32 = f32();

[[stage(compute), workgroup_size(1)]]
fn main() {
  let x : f32 = a + b;
}
)";

  const char* expected = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
  [[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol : f32 = 1.0;
  [[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_1 : f32 = f32();
  let x : f32 = (tint_symbol + tint_symbol_1);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
129 
// Taking the address of a module-scope variable is rewritten to take the
// address of its function-scope replacement; pointer uses are untouched.
TEST_F(ModuleScopeVarToEntryPointParamTest, Pointers) {
  const char* input = R"(
var<private> p : f32;
var<workgroup> w : f32;

[[stage(compute), workgroup_size(1)]]
fn main() {
  let p_ptr : ptr<private, f32> = &p;
  let w_ptr : ptr<workgroup, f32> = &w;
  let x : f32 = *p_ptr + *w_ptr;
  *p_ptr = x;
}
)";

  const char* expected = R"(
[[stage(compute), workgroup_size(1)]]
fn main() {
  [[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol : f32;
  [[internal(disable_validation__ignore_storage_class)]] var<workgroup> tint_symbol_1 : f32;
  let p_ptr : ptr<private, f32> = &(tint_symbol);
  let w_ptr : ptr<workgroup, f32> = &(tint_symbol_1);
  let x : f32 = (*(p_ptr) + *(w_ptr));
  *(p_ptr) = x;
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
160 
// Inside a helper the variable is already a pointer parameter, so `&v` there
// becomes the parameter itself rather than `&(*param)`.
TEST_F(ModuleScopeVarToEntryPointParamTest, FoldAddressOfDeref) {
  const char* input = R"(
var<private> v : f32;

fn bar(p : ptr<private, f32>) {
  (*p) = 0.0;
}

fn foo() {
  bar(&v);
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  foo();
}
)";

  const char* expected = R"(
fn bar(p : ptr<private, f32>) {
  *(p) = 0.0;
}

fn foo([[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>) {
  bar(tint_symbol);
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  [[internal(disable_validation__ignore_storage_class)]] var<private> tint_symbol_1 : f32;
  foo(&(tint_symbol_1));
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
199 
// Uniform and storage buffers become entry-point pointer parameters that keep
// their group/binding attributes. Each use becomes a dereference, and the
// struct declaration is preserved.
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_Basic) {
  const char* input = R"(
[[block]]
struct S {
  a : f32;
};

[[group(0), binding(0)]]
var<uniform> u : S;
[[group(0), binding(1)]]
var<storage> s : S;

[[stage(compute), workgroup_size(1)]]
fn main() {
  _ = u;
  _ = s;
}
)";

  const char* expected = R"(
[[block]]
struct S {
  a : f32;
};

[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol : ptr<uniform, S>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_1 : ptr<storage, S>) {
  _ = *(tint_symbol);
  _ = *(tint_symbol_1);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
236 
// Buffer pointers received by the entry point are forwarded unchanged through
// helper calls as extra trailing pointer parameters.
TEST_F(ModuleScopeVarToEntryPointParamTest, Buffers_FunctionCalls) {
  const char* input = R"(
[[block]]
struct S {
  a : f32;
};

[[group(0), binding(0)]]
var<uniform> u : S;
[[group(0), binding(1)]]
var<storage> s : S;

fn no_uses() {
}

fn bar(a : f32, b : f32) {
  _ = u;
  _ = s;
}

fn foo(a : f32) {
  let b : f32 = 2.0;
  _ = u;
  bar(a, b);
  no_uses();
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  foo(1.0);
}
)";

  const char* expected = R"(
[[block]]
struct S {
  a : f32;
};

fn no_uses() {
}

fn bar(a : f32, b : f32, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<uniform, S>, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_1 : ptr<storage, S>) {
  _ = *(tint_symbol);
  _ = *(tint_symbol_1);
}

fn foo(a : f32, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_2 : ptr<uniform, S>, [[internal(disable_validation__ignore_storage_class), internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_3 : ptr<storage, S>) {
  let b : f32 = 2.0;
  _ = *(tint_symbol_2);
  bar(a, b, tint_symbol_2, tint_symbol_3);
  no_uses();
}

[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_4 : ptr<uniform, S>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter), internal(disable_validation__ignore_storage_class)]] tint_symbol_5 : ptr<storage, S>) {
  foo(1.0, tint_symbol_4, tint_symbol_5);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
301 
// Textures and samplers become entry-point parameters passed by value (not
// through pointers). They keep their group/binding attributes.
TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_Basic) {
  const char* input = R"(
[[group(0), binding(0)]] var t : texture_2d<f32>;
[[group(0), binding(1)]] var s : sampler;

[[stage(compute), workgroup_size(1)]]
fn main() {
  _ = t;
  _ = s;
}
)";

  const char* expected = R"(
[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : sampler) {
  _ = tint_symbol;
  _ = tint_symbol_1;
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
326 
// Helpers receive texture/sampler handles as plain by-value parameters with no
// internal attributes. The entry point forwards its own handle parameters.
TEST_F(ModuleScopeVarToEntryPointParamTest, HandleTypes_FunctionCalls) {
  const char* input = R"(
[[group(0), binding(0)]] var t : texture_2d<f32>;
[[group(0), binding(1)]] var s : sampler;

fn no_uses() {
}

fn bar(a : f32, b : f32) {
  _ = t;
  _ = s;
}

fn foo(a : f32) {
  let b : f32 = 2.0;
  _ = t;
  bar(a, b);
  no_uses();
}

[[stage(compute), workgroup_size(1)]]
fn main() {
  foo(1.0);
}
)";

  const char* expected = R"(
fn no_uses() {
}

fn bar(a : f32, b : f32, tint_symbol : texture_2d<f32>, tint_symbol_1 : sampler) {
  _ = tint_symbol;
  _ = tint_symbol_1;
}

fn foo(a : f32, tint_symbol_2 : texture_2d<f32>, tint_symbol_3 : sampler) {
  let b : f32 = 2.0;
  _ = tint_symbol_2;
  bar(a, b, tint_symbol_2, tint_symbol_3);
  no_uses();
}

[[stage(compute), workgroup_size(1)]]
fn main([[group(0), binding(0), internal(disable_validation__entry_point_parameter)]] tint_symbol_4 : texture_2d<f32>, [[group(0), binding(1), internal(disable_validation__entry_point_parameter)]] tint_symbol_5 : sampler) {
  foo(1.0, tint_symbol_4, tint_symbol_5);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
379 
// A workgroup matrix variable is wrapped in a generated struct passed as one
// workgroup pointer parameter. The entry point binds a pointer to the struct
// member and dereferences it at each use.
TEST_F(ModuleScopeVarToEntryPointParamTest, Matrix) {
  const char* input = R"(
var<workgroup> m : mat2x2<f32>;

[[stage(compute), workgroup_size(1)]]
fn main() {
  let x = m;
}
)";

  const char* expected = R"(
struct tint_symbol_2 {
  m : mat2x2<f32>;
};

[[stage(compute), workgroup_size(1)]]
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
  let tint_symbol : ptr<workgroup, mat2x2<f32>> = &((*(tint_symbol_1)).m);
  let x = *(tint_symbol);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
406 
// Same struct-wrapping as the Matrix test, but for a workgroup array whose
// element structs contain a matrix. The printed array size carries the `u`
// suffix (`4u`).
TEST_F(ModuleScopeVarToEntryPointParamTest, NestedMatrix) {
  const char* input = R"(
struct S1 {
  m : mat2x2<f32>;
};
struct S2 {
  s : S1;
};
var<workgroup> m : array<S2, 4>;

[[stage(compute), workgroup_size(1)]]
fn main() {
  let x = m;
}
)";

  const char* expected = R"(
struct S1 {
  m : mat2x2<f32>;
};

struct S2 {
  s : S1;
};

struct tint_symbol_2 {
  m : array<S2, 4u>;
};

[[stage(compute), workgroup_size(1)]]
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
  let tint_symbol : ptr<workgroup, array<S2, 4u>> = &((*(tint_symbol_1)).m);
  let x = *(tint_symbol);
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
447 
448 // Test that we do not duplicate a struct type used by multiple workgroup
449 // variables that are promoted to threadgroup memory arguments.
TEST_F(ModuleScopeVarToEntryPointParamTest,DuplicateThreadgroupArgumentTypes)450 TEST_F(ModuleScopeVarToEntryPointParamTest, DuplicateThreadgroupArgumentTypes) {
451   auto* src = R"(
452 struct S {
453   m : mat2x2<f32>;
454 };
455 
456 var<workgroup> a : S;
457 
458 var<workgroup> b : S;
459 
460 [[stage(compute), workgroup_size(1)]]
461 fn main() {
462   let x = a;
463   let y = b;
464 }
465 )";
466 
467   auto* expect = R"(
468 struct S {
469   m : mat2x2<f32>;
470 };
471 
472 struct tint_symbol_3 {
473   a : S;
474   b : S;
475 };
476 
477 [[stage(compute), workgroup_size(1)]]
478 fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_3>) {
479   let tint_symbol : ptr<workgroup, S> = &((*(tint_symbol_1)).a);
480   let tint_symbol_2 : ptr<workgroup, S> = &((*(tint_symbol_1)).b);
481   let x = *(tint_symbol);
482   let y = *(tint_symbol_2);
483 }
484 )";
485 
486   auto got = Run<ModuleScopeVarToEntryPointParam>(src);
487 
488   EXPECT_EQ(expect, str(got));
489 }
490 
// Module-scope variables that the entry point never references are dropped
// outright, and no parameters or locals are added. Struct declarations remain.
TEST_F(ModuleScopeVarToEntryPointParamTest, UnusedVariables) {
  const char* input = R"(
[[block]]
struct S {
  a : f32;
};

var<private> p : f32;
var<workgroup> w : f32;

[[group(0), binding(0)]]
var<uniform> ub : S;
[[group(0), binding(1)]]
var<storage> sb : S;

[[group(0), binding(2)]] var t : texture_2d<f32>;
[[group(0), binding(3)]] var s : sampler;

[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";

  const char* expected = R"(
[[block]]
struct S {
  a : f32;
};

[[stage(compute), workgroup_size(1)]]
fn main() {
}
)";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(expected, str(output));
}
529 
// An empty module passes through the transform unchanged.
// NOTE: the test name misspells "Empty"; it is kept as-is so existing
// --gtest_filter patterns keep matching.
TEST_F(ModuleScopeVarToEntryPointParamTest, EmtpyModule) {
  const char* input = "";

  auto output = Run<ModuleScopeVarToEntryPointParam>(input);
  EXPECT_EQ(input, str(output));
}
537 
538 }  // namespace
539 }  // namespace transform
540 }  // namespace tint
541