1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/transform/zero_init_workgroup_memory.h"
16
17 #include <utility>
18
19 #include "src/transform/test_helper.h"
20
21 namespace tint {
22 namespace transform {
23 namespace {
24
// Fixture alias: every TEST_F in this file runs on the TransformTest fixture
// under a name specific to the ZeroInitWorkgroupMemory transform.
25 using ZeroInitWorkgroupMemoryTest = TransformTest;
26
// An empty input module must come back from the transform as an empty module.
TEST_F(ZeroInitWorkgroupMemoryTest, EmptyModule) {
  const char* input = "";
  // Nothing to initialize, so the expected output is the input itself.
  const char* expected = input;

  auto output = Run<ZeroInitWorkgroupMemory>(input);

  EXPECT_EQ(expected, str(output));
}
35
// The module declares only a private-storage variable and no workgroup
// variables. The expected output is the unmodified source.
TEST_F(ZeroInitWorkgroupMemoryTest,NoWorkgroupVars)36 TEST_F(ZeroInitWorkgroupMemoryTest, NoWorkgroupVars) {
37 auto* src = R"(
38 var<private> v : i32;
39
40 fn f() {
41 v = 1;
42 }
43 )";
44 auto* expect = src;
45
46 auto got = Run<ZeroInitWorkgroupMemory>(src);
47
48 EXPECT_EQ(expect, str(got));
49 }
50
// Workgroup variables exist, but the compute entry point `f` has an empty body
// and the only function that touches them (`unreferenced`) is never called
// from it. The expected output is the unmodified source.
TEST_F(ZeroInitWorkgroupMemoryTest,UnreferencedWorkgroupVars)51 TEST_F(ZeroInitWorkgroupMemoryTest, UnreferencedWorkgroupVars) {
52 auto* src = R"(
53 var<workgroup> a : i32;
54
55 var<workgroup> b : i32;
56
57 var<workgroup> c : i32;
58
59 fn unreferenced() {
60 b = c;
61 }
62
63 [[stage(compute), workgroup_size(1)]]
64 fn f() {
65 }
66 )";
67 auto* expect = src;
68
69 auto got = Run<ZeroInitWorkgroupMemory>(src);
70
71 EXPECT_EQ(expect, str(got));
72 }
73
// One workgroup `i32` is used by an entry point that already declares a
// `local_invocation_index` builtin parameter. The expected output keeps the
// signature as-is and prepends a block assigning `i32()` to the variable,
// followed by `workgroupBarrier()`, ahead of the original statement.
TEST_F(ZeroInitWorkgroupMemoryTest,SingleWorkgroupVar_ExistingLocalIndex)74 TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndex) {
75 auto* src = R"(
76 var<workgroup> v : i32;
77
78 [[stage(compute), workgroup_size(1)]]
79 fn f([[builtin(local_invocation_index)]] local_idx : u32) {
80 ignore(v); // Initialization should be inserted above this statement
81 }
82 )";
83 auto* expect = R"(
84 var<workgroup> v : i32;
85
86 [[stage(compute), workgroup_size(1)]]
87 fn f([[builtin(local_invocation_index)]] local_idx : u32) {
88 {
89 v = i32();
90 }
91 workgroupBarrier();
92 ignore(v);
93 }
94 )";
95
96 auto got = Run<ZeroInitWorkgroupMemory>(src);
97
98 EXPECT_EQ(expect, str(got));
99 }
100
// Same as the single-variable case, except the `local_invocation_index`
// builtin is a member of the `Params` struct passed to the entry point.
// The expected output leaves the `f(params : Params)` signature unchanged and
// inserts the zeroing block plus `workgroupBarrier()` at the top of the body.
TEST_F(ZeroInitWorkgroupMemoryTest,SingleWorkgroupVar_ExistingLocalIndexInStruct)101 TEST_F(ZeroInitWorkgroupMemoryTest,
102 SingleWorkgroupVar_ExistingLocalIndexInStruct) {
103 auto* src = R"(
104 var<workgroup> v : i32;
105
106 struct Params {
107 [[builtin(local_invocation_index)]] local_idx : u32;
108 };
109
110 [[stage(compute), workgroup_size(1)]]
111 fn f(params : Params) {
112 ignore(v); // Initialization should be inserted above this statement
113 }
114 )";
115 auto* expect = R"(
116 var<workgroup> v : i32;
117
118 struct Params {
119 [[builtin(local_invocation_index)]]
120 local_idx : u32;
121 };
122
123 [[stage(compute), workgroup_size(1)]]
124 fn f(params : Params) {
125 {
126 v = i32();
127 }
128 workgroupBarrier();
129 ignore(v);
130 }
131 )";
132
133 auto got = Run<ZeroInitWorkgroupMemory>(src);
134
135 EXPECT_EQ(expect, str(got));
136 }
137
// The entry point takes no parameters. The expected output adds a
// `[[builtin(local_invocation_index)]] local_invocation_index : u32`
// parameter and inserts the zeroing block plus `workgroupBarrier()`.
TEST_F(ZeroInitWorkgroupMemoryTest,SingleWorkgroupVar_InjectedLocalIndex)138 TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_InjectedLocalIndex) {
139 auto* src = R"(
140 var<workgroup> v : i32;
141
142 [[stage(compute), workgroup_size(1)]]
143 fn f() {
144 ignore(v); // Initialization should be inserted above this statement
145 }
146 )";
147 auto* expect = R"(
148 var<workgroup> v : i32;
149
150 [[stage(compute), workgroup_size(1)]]
151 fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
152 {
153 v = i32();
154 }
155 workgroupBarrier();
156 ignore(v);
157 }
158 )";
159
160 auto got = Run<ZeroInitWorkgroupMemory>(src);
161
162 EXPECT_EQ(expect, str(got));
163 }
164
// A scalar, a struct and an array-of-struct workgroup variable are used by an
// entry point with `workgroup_size(1)` and an existing local index parameter.
// In the expected output the scalar members are zeroed in a plain block, and
// each array is zeroed by a `for` loop that starts at `local_idx` and steps
// by `1u`. The nested array `c[..].y[..]` is covered by one flat loop of 256
// iterations whose index is split with `/ 8u` and `% 8u`.
TEST_F(ZeroInitWorkgroupMemoryTest,MultipleWorkgroupVar_ExistingLocalIndex_Size1)165 TEST_F(ZeroInitWorkgroupMemoryTest,
166 MultipleWorkgroupVar_ExistingLocalIndex_Size1) {
167 auto* src = R"(
168 struct S {
169 x : i32;
170 y : array<i32, 8>;
171 };
172
173 var<workgroup> a : i32;
174
175 var<workgroup> b : S;
176
177 var<workgroup> c : array<S, 32>;
178
179 [[stage(compute), workgroup_size(1)]]
180 fn f([[builtin(local_invocation_index)]] local_idx : u32) {
181 ignore(a); // Initialization should be inserted above this statement
182 ignore(b);
183 ignore(c);
184 }
185 )";
186 auto* expect = R"(
187 struct S {
188 x : i32;
189 y : array<i32, 8>;
190 };
191
192 var<workgroup> a : i32;
193
194 var<workgroup> b : S;
195
196 var<workgroup> c : array<S, 32>;
197
198 [[stage(compute), workgroup_size(1)]]
199 fn f([[builtin(local_invocation_index)]] local_idx : u32) {
200 {
201 a = i32();
202 b.x = i32();
203 }
204 for(var idx : u32 = local_idx; (idx < 8u); idx = (idx + 1u)) {
205 let i : u32 = idx;
206 b.y[i] = i32();
207 }
208 for(var idx_1 : u32 = local_idx; (idx_1 < 32u); idx_1 = (idx_1 + 1u)) {
209 let i_1 : u32 = idx_1;
210 c[i_1].x = i32();
211 }
212 for(var idx_2 : u32 = local_idx; (idx_2 < 256u); idx_2 = (idx_2 + 1u)) {
213 let i_2 : u32 = (idx_2 / 8u);
214 let i : u32 = (idx_2 % 8u);
215 c[i_2].y[i] = i32();
216 }
217 workgroupBarrier();
218 ignore(a);
219 ignore(b);
220 ignore(c);
221 }
222 )";
223
224 auto got = Run<ZeroInitWorkgroupMemory>(src);
225
226 EXPECT_EQ(expect, str(got));
227 }
228
// Same variables as the Size1 case, but with `workgroup_size(2, 3)`.
// In the expected output the scalar zeroing is guarded by
// `if ((local_idx < 1u))`, and every array loop steps by `6u` (2 * 3).
TEST_F(ZeroInitWorkgroupMemoryTest,MultipleWorkgroupVar_ExistingLocalIndex_Size_2_3)229 TEST_F(ZeroInitWorkgroupMemoryTest,
230 MultipleWorkgroupVar_ExistingLocalIndex_Size_2_3) {
231 auto* src = R"(
232 struct S {
233 x : i32;
234 y : array<i32, 8>;
235 };
236
237 var<workgroup> a : i32;
238
239 var<workgroup> b : S;
240
241 var<workgroup> c : array<S, 32>;
242
243 [[stage(compute), workgroup_size(2, 3)]]
244 fn f([[builtin(local_invocation_index)]] local_idx : u32) {
245 ignore(a); // Initialization should be inserted above this statement
246 ignore(b);
247 ignore(c);
248 }
249 )";
250 auto* expect = R"(
251 struct S {
252 x : i32;
253 y : array<i32, 8>;
254 };
255
256 var<workgroup> a : i32;
257
258 var<workgroup> b : S;
259
260 var<workgroup> c : array<S, 32>;
261
262 [[stage(compute), workgroup_size(2, 3)]]
263 fn f([[builtin(local_invocation_index)]] local_idx : u32) {
264 if ((local_idx < 1u)) {
265 a = i32();
266 b.x = i32();
267 }
268 for(var idx : u32 = local_idx; (idx < 8u); idx = (idx + 6u)) {
269 let i : u32 = idx;
270 b.y[i] = i32();
271 }
272 for(var idx_1 : u32 = local_idx; (idx_1 < 32u); idx_1 = (idx_1 + 6u)) {
273 let i_1 : u32 = idx_1;
274 c[i_1].x = i32();
275 }
276 for(var idx_2 : u32 = local_idx; (idx_2 < 256u); idx_2 = (idx_2 + 6u)) {
277 let i_2 : u32 = (idx_2 / 8u);
278 let i : u32 = (idx_2 % 8u);
279 c[i_2].y[i] = i32();
280 }
281 workgroupBarrier();
282 ignore(a);
283 ignore(b);
284 ignore(c);
285 }
286 )";
287
288 auto got = Run<ZeroInitWorkgroupMemory>(src);
289
290 EXPECT_EQ(expect, str(got));
291 }
292
// The third workgroup dimension is the `i32` override constant `X`, giving
// `workgroup_size(2, 3, X)`. In the expected output every initializer,
// including the scalar one, is a `for` loop, and each loop steps by
// `(u32(X) * 6u)`.
TEST_F(ZeroInitWorkgroupMemoryTest,MultipleWorkgroupVar_ExistingLocalIndex_Size_2_3_X)293 TEST_F(ZeroInitWorkgroupMemoryTest,
294 MultipleWorkgroupVar_ExistingLocalIndex_Size_2_3_X) {
295 auto* src = R"(
296 struct S {
297 x : i32;
298 y : array<i32, 8>;
299 };
300
301 var<workgroup> a : i32;
302
303 var<workgroup> b : S;
304
305 var<workgroup> c : array<S, 32>;
306
307 [[override(1)]] let X : i32;
308
309 [[stage(compute), workgroup_size(2, 3, X)]]
310 fn f([[builtin(local_invocation_index)]] local_idx : u32) {
311 ignore(a); // Initialization should be inserted above this statement
312 ignore(b);
313 ignore(c);
314 }
315 )";
316 auto* expect =
317 R"(
318 struct S {
319 x : i32;
320 y : array<i32, 8>;
321 };
322
323 var<workgroup> a : i32;
324
325 var<workgroup> b : S;
326
327 var<workgroup> c : array<S, 32>;
328
329 [[override(1)]] let X : i32;
330
331 [[stage(compute), workgroup_size(2, 3, X)]]
332 fn f([[builtin(local_invocation_index)]] local_idx : u32) {
333 for(var idx : u32 = local_idx; (idx < 1u); idx = (idx + (u32(X) * 6u))) {
334 a = i32();
335 b.x = i32();
336 }
337 for(var idx_1 : u32 = local_idx; (idx_1 < 8u); idx_1 = (idx_1 + (u32(X) * 6u))) {
338 let i : u32 = idx_1;
339 b.y[i] = i32();
340 }
341 for(var idx_2 : u32 = local_idx; (idx_2 < 32u); idx_2 = (idx_2 + (u32(X) * 6u))) {
342 let i_1 : u32 = idx_2;
343 c[i_1].x = i32();
344 }
345 for(var idx_3 : u32 = local_idx; (idx_3 < 256u); idx_3 = (idx_3 + (u32(X) * 6u))) {
346 let i_2 : u32 = (idx_3 / 8u);
347 let i : u32 = (idx_3 % 8u);
348 c[i_2].y[i] = i32();
349 }
350 workgroupBarrier();
351 ignore(a);
352 ignore(b);
353 ignore(c);
354 }
355 )";
356
357 auto got = Run<ZeroInitWorkgroupMemory>(src);
358
359 EXPECT_EQ(expect, str(got));
360 }
361
// The struct holds nested arrays up to three levels deep, and the workgroup
// size is `(5u, X, 10u)` with `X` a `u32` override constant.
// In the expected output every loop steps by `(X * 50u)`. The loops appear in
// ascending order of iteration count (1, 8, 80, 256, 1600, 2560, 51200), and
// each flat loop index is split into per-dimension indices with `/` and `%`.
TEST_F(ZeroInitWorkgroupMemoryTest,MultipleWorkgroupVar_ExistingLocalIndex_Size_5u_X_10u)362 TEST_F(ZeroInitWorkgroupMemoryTest,
363 MultipleWorkgroupVar_ExistingLocalIndex_Size_5u_X_10u) {
364 auto* src = R"(
365 struct S {
366 x : array<array<i32, 8>, 10>;
367 y : array<i32, 8>;
368 z : array<array<array<i32, 8>, 10>, 20>;
369 };
370
371 var<workgroup> a : i32;
372
373 var<workgroup> b : S;
374
375 var<workgroup> c : array<S, 32>;
376
377 [[override(1)]] let X : u32;
378
379 [[stage(compute), workgroup_size(5u, X, 10u)]]
380 fn f([[builtin(local_invocation_index)]] local_idx : u32) {
381 ignore(a); // Initialization should be inserted above this statement
382 ignore(b);
383 ignore(c);
384 }
385 )";
386 auto* expect =
387 R"(
388 struct S {
389 x : array<array<i32, 8>, 10>;
390 y : array<i32, 8>;
391 z : array<array<array<i32, 8>, 10>, 20>;
392 };
393
394 var<workgroup> a : i32;
395
396 var<workgroup> b : S;
397
398 var<workgroup> c : array<S, 32>;
399
400 [[override(1)]] let X : u32;
401
402 [[stage(compute), workgroup_size(5u, X, 10u)]]
403 fn f([[builtin(local_invocation_index)]] local_idx : u32) {
404 for(var idx : u32 = local_idx; (idx < 1u); idx = (idx + (X * 50u))) {
405 a = i32();
406 }
407 for(var idx_1 : u32 = local_idx; (idx_1 < 8u); idx_1 = (idx_1 + (X * 50u))) {
408 let i_1 : u32 = idx_1;
409 b.y[i_1] = i32();
410 }
411 for(var idx_2 : u32 = local_idx; (idx_2 < 80u); idx_2 = (idx_2 + (X * 50u))) {
412 let i : u32 = (idx_2 / 8u);
413 let i_1 : u32 = (idx_2 % 8u);
414 b.x[i][i_1] = i32();
415 }
416 for(var idx_3 : u32 = local_idx; (idx_3 < 256u); idx_3 = (idx_3 + (X * 50u))) {
417 let i_4 : u32 = (idx_3 / 8u);
418 let i_1 : u32 = (idx_3 % 8u);
419 c[i_4].y[i_1] = i32();
420 }
421 for(var idx_4 : u32 = local_idx; (idx_4 < 1600u); idx_4 = (idx_4 + (X * 50u))) {
422 let i_2 : u32 = (idx_4 / 80u);
423 let i : u32 = ((idx_4 % 80u) / 8u);
424 let i_1 : u32 = (idx_4 % 8u);
425 b.z[i_2][i][i_1] = i32();
426 }
427 for(var idx_5 : u32 = local_idx; (idx_5 < 2560u); idx_5 = (idx_5 + (X * 50u))) {
428 let i_3 : u32 = (idx_5 / 80u);
429 let i : u32 = ((idx_5 % 80u) / 8u);
430 let i_1 : u32 = (idx_5 % 8u);
431 c[i_3].x[i][i_1] = i32();
432 }
433 for(var idx_6 : u32 = local_idx; (idx_6 < 51200u); idx_6 = (idx_6 + (X * 50u))) {
434 let i_5 : u32 = (idx_6 / 1600u);
435 let i_2 : u32 = ((idx_6 % 1600u) / 80u);
436 let i : u32 = ((idx_6 % 80u) / 8u);
437 let i_1 : u32 = (idx_6 % 8u);
438 c[i_5].z[i_2][i][i_1] = i32();
439 }
440 workgroupBarrier();
441 ignore(a);
442 ignore(b);
443 ignore(c);
444 }
445 )";
446
447 auto got = Run<ZeroInitWorkgroupMemory>(src);
448
449 EXPECT_EQ(expect, str(got));
450 }
451
// The entry point declares a `local_invocation_id` builtin parameter but no
// `local_invocation_index`. The expected output appends a
// `local_invocation_index` parameter after the existing one and uses it as
// the starting value of every zeroing loop.
TEST_F(ZeroInitWorkgroupMemoryTest,MultipleWorkgroupVar_InjectedLocalIndex)452 TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_InjectedLocalIndex) {
453 auto* src = R"(
454 struct S {
455 x : i32;
456 y : array<i32, 8>;
457 };
458
459 var<workgroup> a : i32;
460
461 var<workgroup> b : S;
462
463 var<workgroup> c : array<S, 32>;
464
465 [[stage(compute), workgroup_size(1)]]
466 fn f([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>) {
467 ignore(a); // Initialization should be inserted above this statement
468 ignore(b);
469 ignore(c);
470 }
471 )";
472 auto* expect = R"(
473 struct S {
474 x : i32;
475 y : array<i32, 8>;
476 };
477
478 var<workgroup> a : i32;
479
480 var<workgroup> b : S;
481
482 var<workgroup> c : array<S, 32>;
483
484 [[stage(compute), workgroup_size(1)]]
485 fn f([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>, [[builtin(local_invocation_index)]] local_invocation_index : u32) {
486 {
487 a = i32();
488 b.x = i32();
489 }
490 for(var idx : u32 = local_invocation_index; (idx < 8u); idx = (idx + 1u)) {
491 let i : u32 = idx;
492 b.y[i] = i32();
493 }
494 for(var idx_1 : u32 = local_invocation_index; (idx_1 < 32u); idx_1 = (idx_1 + 1u)) {
495 let i_1 : u32 = idx_1;
496 c[i_1].x = i32();
497 }
498 for(var idx_2 : u32 = local_invocation_index; (idx_2 < 256u); idx_2 = (idx_2 + 1u)) {
499 let i_2 : u32 = (idx_2 / 8u);
500 let i : u32 = (idx_2 % 8u);
501 c[i_2].y[i] = i32();
502 }
503 workgroupBarrier();
504 ignore(a);
505 ignore(b);
506 ignore(c);
507 }
508 )";
509
510 auto got = Run<ZeroInitWorkgroupMemory>(src);
511
512 EXPECT_EQ(expect, str(got));
513 }
514
// Three compute entry points with different workgroup sizes each reference a
// different subset of the workgroup variables. In the expected output:
//  - each entry point zeroes only the variables it references;
//  - the added index parameters are named `local_invocation_index`,
//    `local_invocation_index_1` and `local_invocation_index_2`;
//  - `f3` (4 * 5 * 6 = 120 invocations) zeroes the 32-element member with an
//    `if` guard rather than a loop, and steps by `120u` in the 256-element loop.
TEST_F(ZeroInitWorkgroupMemoryTest,MultipleWorkgroupVar_MultipleEntryPoints)515 TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_MultipleEntryPoints) {
516 auto* src = R"(
517 struct S {
518 x : i32;
519 y : array<i32, 8>;
520 };
521
522 var<workgroup> a : i32;
523
524 var<workgroup> b : S;
525
526 var<workgroup> c : array<S, 32>;
527
528 [[stage(compute), workgroup_size(1)]]
529 fn f1() {
530 ignore(a); // Initialization should be inserted above this statement
531 ignore(c);
532 }
533
534 [[stage(compute), workgroup_size(1, 2, 3)]]
535 fn f2([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>) {
536 ignore(b); // Initialization should be inserted above this statement
537 }
538
539 [[stage(compute), workgroup_size(4, 5, 6)]]
540 fn f3() {
541 ignore(c); // Initialization should be inserted above this statement
542 ignore(a);
543 }
544 )";
545 auto* expect = R"(
546 struct S {
547 x : i32;
548 y : array<i32, 8>;
549 };
550
551 var<workgroup> a : i32;
552
553 var<workgroup> b : S;
554
555 var<workgroup> c : array<S, 32>;
556
557 [[stage(compute), workgroup_size(1)]]
558 fn f1([[builtin(local_invocation_index)]] local_invocation_index : u32) {
559 {
560 a = i32();
561 }
562 for(var idx : u32 = local_invocation_index; (idx < 32u); idx = (idx + 1u)) {
563 let i : u32 = idx;
564 c[i].x = i32();
565 }
566 for(var idx_1 : u32 = local_invocation_index; (idx_1 < 256u); idx_1 = (idx_1 + 1u)) {
567 let i_1 : u32 = (idx_1 / 8u);
568 let i_2 : u32 = (idx_1 % 8u);
569 c[i_1].y[i_2] = i32();
570 }
571 workgroupBarrier();
572 ignore(a);
573 ignore(c);
574 }
575
576 [[stage(compute), workgroup_size(1, 2, 3)]]
577 fn f2([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>, [[builtin(local_invocation_index)]] local_invocation_index_1 : u32) {
578 if ((local_invocation_index_1 < 1u)) {
579 b.x = i32();
580 }
581 for(var idx_2 : u32 = local_invocation_index_1; (idx_2 < 8u); idx_2 = (idx_2 + 6u)) {
582 let i_3 : u32 = idx_2;
583 b.y[i_3] = i32();
584 }
585 workgroupBarrier();
586 ignore(b);
587 }
588
589 [[stage(compute), workgroup_size(4, 5, 6)]]
590 fn f3([[builtin(local_invocation_index)]] local_invocation_index_2 : u32) {
591 if ((local_invocation_index_2 < 1u)) {
592 a = i32();
593 }
594 if ((local_invocation_index_2 < 32u)) {
595 let i_4 : u32 = local_invocation_index_2;
596 c[i_4].x = i32();
597 }
598 for(var idx_3 : u32 = local_invocation_index_2; (idx_3 < 256u); idx_3 = (idx_3 + 120u)) {
599 let i_5 : u32 = (idx_3 / 8u);
600 let i_6 : u32 = (idx_3 % 8u);
601 c[i_5].y[i_6] = i32();
602 }
603 workgroupBarrier();
604 ignore(c);
605 ignore(a);
606 }
607 )";
608
609 auto got = Run<ZeroInitWorkgroupMemory>(src);
610
611 EXPECT_EQ(expect, str(got));
612 }
613
// The workgroup variable is never touched directly by the entry point; it is
// reached only through the call chain `f` -> `call_use_v` -> `use_v`.
// The expected output still places the zeroing block and `workgroupBarrier()`
// at the top of `f`, and leaves the helper functions unchanged.
TEST_F(ZeroInitWorkgroupMemoryTest,TransitiveUsage)614 TEST_F(ZeroInitWorkgroupMemoryTest, TransitiveUsage) {
615 auto* src = R"(
616 var<workgroup> v : i32;
617
618 fn use_v() {
619 ignore(v);
620 }
621
622 fn call_use_v() {
623 use_v();
624 }
625
626 [[stage(compute), workgroup_size(1)]]
627 fn f([[builtin(local_invocation_index)]] local_idx : u32) {
628 call_use_v(); // Initialization should be inserted above this statement
629 }
630 )";
631 auto* expect = R"(
632 var<workgroup> v : i32;
633
634 fn use_v() {
635 ignore(v);
636 }
637
638 fn call_use_v() {
639 use_v();
640 }
641
642 [[stage(compute), workgroup_size(1)]]
643 fn f([[builtin(local_invocation_index)]] local_idx : u32) {
644 {
645 v = i32();
646 }
647 workgroupBarrier();
648 call_use_v();
649 }
650 )";
651
652 auto got = Run<ZeroInitWorkgroupMemory>(src);
653
654 EXPECT_EQ(expect, str(got));
655 }
656
// Two workgroup variables of atomic type (`atomic<i32>`, `atomic<u32>`).
// The expected output zeroes each with `atomicStore(&(var), T())` instead of
// a plain assignment.
TEST_F(ZeroInitWorkgroupMemoryTest,WorkgroupAtomics)657 TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupAtomics) {
658 auto* src = R"(
659 var<workgroup> i : atomic<i32>;
660 var<workgroup> u : atomic<u32>;
661
662 [[stage(compute), workgroup_size(1)]]
663 fn f() {
664 ignore(i); // Initialization should be inserted above this statement
665 ignore(u);
666 }
667 )";
668 auto* expect = R"(
669 var<workgroup> i : atomic<i32>;
670
671 var<workgroup> u : atomic<u32>;
672
673 [[stage(compute), workgroup_size(1)]]
674 fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
675 {
676 atomicStore(&(i), i32());
677 atomicStore(&(u), u32());
678 }
679 workgroupBarrier();
680 ignore(i);
681 ignore(u);
682 }
683 )";
684
685 auto got = Run<ZeroInitWorkgroupMemory>(src);
686
687 EXPECT_EQ(expect, str(got));
688 }
689
// A workgroup struct that mixes plain and atomic members. The expected output
// zeroes the members one by one in declaration order: plain members by
// assignment, atomic members via `atomicStore`.
TEST_F(ZeroInitWorkgroupMemoryTest,WorkgroupStructOfAtomics)690 TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupStructOfAtomics) {
691 auto* src = R"(
692 struct S {
693 a : i32;
694 i : atomic<i32>;
695 b : f32;
696 u : atomic<u32>;
697 c : u32;
698 };
699
700 var<workgroup> w : S;
701
702 [[stage(compute), workgroup_size(1)]]
703 fn f() {
704 ignore(w); // Initialization should be inserted above this statement
705 }
706 )";
707 auto* expect = R"(
708 struct S {
709 a : i32;
710 i : atomic<i32>;
711 b : f32;
712 u : atomic<u32>;
713 c : u32;
714 };
715
716 var<workgroup> w : S;
717
718 [[stage(compute), workgroup_size(1)]]
719 fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
720 {
721 w.a = i32();
722 atomicStore(&(w.i), i32());
723 w.b = f32();
724 atomicStore(&(w.u), u32());
725 w.c = u32();
726 }
727 workgroupBarrier();
728 ignore(w);
729 }
730 )";
731
732 auto got = Run<ZeroInitWorkgroupMemory>(src);
733
734 EXPECT_EQ(expect, str(got));
735 }
736
// A workgroup array of four `atomic<u32>` elements. The expected output zeroes
// it with a single loop that calls `atomicStore(&(w[i]), u32())` per element.
TEST_F(ZeroInitWorkgroupMemoryTest,WorkgroupArrayOfAtomics)737 TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfAtomics) {
738 auto* src = R"(
739 var<workgroup> w : array<atomic<u32>, 4>;
740
741 [[stage(compute), workgroup_size(1)]]
742 fn f() {
743 ignore(w); // Initialization should be inserted above this statement
744 }
745 )";
746 auto* expect = R"(
747 var<workgroup> w : array<atomic<u32>, 4>;
748
749 [[stage(compute), workgroup_size(1)]]
750 fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
751 for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
752 let i : u32 = idx;
753 atomicStore(&(w[i]), u32());
754 }
755 workgroupBarrier();
756 ignore(w);
757 }
758 )";
759
760 auto got = Run<ZeroInitWorkgroupMemory>(src);
761
762 EXPECT_EQ(expect, str(got));
763 }
764
// A workgroup array of four structs, each mixing plain and atomic members.
// The expected output uses one loop over the four elements; each iteration
// zeroes all five members, using `atomicStore` for the atomic ones.
TEST_F(ZeroInitWorkgroupMemoryTest,WorkgroupArrayOfStructOfAtomics)765 TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfStructOfAtomics) {
766 auto* src = R"(
767 struct S {
768 a : i32;
769 i : atomic<i32>;
770 b : f32;
771 u : atomic<u32>;
772 c : u32;
773 };
774
775 var<workgroup> w : array<S, 4>;
776
777 [[stage(compute), workgroup_size(1)]]
778 fn f() {
779 ignore(w); // Initialization should be inserted above this statement
780 }
781 )";
782 auto* expect = R"(
783 struct S {
784 a : i32;
785 i : atomic<i32>;
786 b : f32;
787 u : atomic<u32>;
788 c : u32;
789 };
790
791 var<workgroup> w : array<S, 4>;
792
793 [[stage(compute), workgroup_size(1)]]
794 fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
795 for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
796 let i_1 : u32 = idx;
797 w[i_1].a = i32();
798 atomicStore(&(w[i_1].i), i32());
799 w[i_1].b = f32();
800 atomicStore(&(w[i_1].u), u32());
801 w[i_1].c = u32();
802 }
803 workgroupBarrier();
804 ignore(w);
805 }
806 )";
807
808 auto got = Run<ZeroInitWorkgroupMemory>(src);
809
810 EXPECT_EQ(expect, str(got));
811 }
812
813 } // namespace
814 } // namespace transform
815 } // namespace tint
816