// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #include "src/transform/zero_init_workgroup_memory.h"
16 
17 #include <utility>
18 
19 #include "src/transform/test_helper.h"
20 
21 namespace tint {
22 namespace transform {
23 namespace {
24 
25 using ZeroInitWorkgroupMemoryTest = TransformTest;
26 
// An empty module has nothing to initialize and must be emitted unchanged.
TEST_F(ZeroInitWorkgroupMemoryTest, EmptyModule) {
  auto* src = "";
  auto* expect = src;

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// A module that declares no workgroup-storage variables (only a private one)
// is left unchanged.
TEST_F(ZeroInitWorkgroupMemoryTest, NoWorkgroupVars) {
  auto* src = R"(
var<private> v : i32;

fn f() {
  v = 1;
}
)";
  auto* expect = src;

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// Workgroup variables that the compute entry point never reaches are not
// zero-initialized: `a` is never used, and `b`/`c` are only used by a function
// the entry point does not call. The module is therefore left unchanged.
TEST_F(ZeroInitWorkgroupMemoryTest, UnreferencedWorkgroupVars) {
  auto* src = R"(
var<workgroup> a : i32;

var<workgroup> b : i32;

var<workgroup> c : i32;

fn unreferenced() {
  b = c;
}

[[stage(compute), workgroup_size(1)]]
fn f() {
}
)";
  auto* expect = src;

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// The entry point already declares a local_invocation_index parameter, so no
// parameter is injected. The zero-initialization of `v`, followed by a
// workgroupBarrier(), is inserted before the first statement.
TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_ExistingLocalIndex) {
  auto* src = R"(
var<workgroup> v : i32;

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_index)]] local_idx : u32) {
  ignore(v); // Initialization should be inserted above this statement
}
)";
  auto* expect = R"(
var<workgroup> v : i32;

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_index)]] local_idx : u32) {
  {
    v = i32();
  }
  workgroupBarrier();
  ignore(v);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// The local_invocation_index builtin is provided as a member of a struct
// parameter. No extra parameter is injected, and the zero-initialization plus
// workgroupBarrier() is inserted before the first statement.
TEST_F(ZeroInitWorkgroupMemoryTest,
       SingleWorkgroupVar_ExistingLocalIndexInStruct) {
  auto* src = R"(
var<workgroup> v : i32;

struct Params {
  [[builtin(local_invocation_index)]] local_idx : u32;
};

[[stage(compute), workgroup_size(1)]]
fn f(params : Params) {
  ignore(v); // Initialization should be inserted above this statement
}
)";
  auto* expect = R"(
var<workgroup> v : i32;

struct Params {
  [[builtin(local_invocation_index)]]
  local_idx : u32;
};

[[stage(compute), workgroup_size(1)]]
fn f(params : Params) {
  {
    v = i32();
  }
  workgroupBarrier();
  ignore(v);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// The entry point has no local_invocation_index builtin, so the transform adds
// a `[[builtin(local_invocation_index)]] local_invocation_index : u32`
// parameter before inserting the zero-initialization and barrier.
TEST_F(ZeroInitWorkgroupMemoryTest, SingleWorkgroupVar_InjectedLocalIndex) {
  auto* src = R"(
var<workgroup> v : i32;

[[stage(compute), workgroup_size(1)]]
fn f() {
  ignore(v); // Initialization should be inserted above this statement
}
)";
  auto* expect = R"(
var<workgroup> v : i32;

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
  {
    v = i32();
  }
  workgroupBarrier();
  ignore(v);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// With workgroup_size(1), scalar values (`a`, `b.x`) are zero-initialized in a
// single block. Each array is cleared by a loop starting at local_idx with a
// stride of 1u. The nested array `c[i].y` is flattened into one 256-iteration
// loop that recovers both indices with / and %.
TEST_F(ZeroInitWorkgroupMemoryTest,
       MultipleWorkgroupVar_ExistingLocalIndex_Size1) {
  auto* src = R"(
struct S {
  x : i32;
  y : array<i32, 8>;
};

var<workgroup> a : i32;

var<workgroup> b : S;

var<workgroup> c : array<S, 32>;

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_index)]] local_idx : u32) {
  ignore(a); // Initialization should be inserted above this statement
  ignore(b);
  ignore(c);
}
)";
  auto* expect = R"(
struct S {
  x : i32;
  y : array<i32, 8>;
};

var<workgroup> a : i32;

var<workgroup> b : S;

var<workgroup> c : array<S, 32>;

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_index)]] local_idx : u32) {
  {
    a = i32();
    b.x = i32();
  }
  for(var idx : u32 = local_idx; (idx < 8u); idx = (idx + 1u)) {
    let i : u32 = idx;
    b.y[i] = i32();
  }
  for(var idx_1 : u32 = local_idx; (idx_1 < 32u); idx_1 = (idx_1 + 1u)) {
    let i_1 : u32 = idx_1;
    c[i_1].x = i32();
  }
  for(var idx_2 : u32 = local_idx; (idx_2 < 256u); idx_2 = (idx_2 + 1u)) {
    let i_2 : u32 = (idx_2 / 8u);
    let i : u32 = (idx_2 % 8u);
    c[i_2].y[i] = i32();
  }
  workgroupBarrier();
  ignore(a);
  ignore(b);
  ignore(c);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// workgroup_size(2, 3) gives 6 invocations. The scalar initialization is
// guarded by `local_idx < 1u` so only one invocation performs it, and every
// array loop strides by 6u.
TEST_F(ZeroInitWorkgroupMemoryTest,
       MultipleWorkgroupVar_ExistingLocalIndex_Size_2_3) {
  auto* src = R"(
struct S {
  x : i32;
  y : array<i32, 8>;
};

var<workgroup> a : i32;

var<workgroup> b : S;

var<workgroup> c : array<S, 32>;

[[stage(compute), workgroup_size(2, 3)]]
fn f([[builtin(local_invocation_index)]] local_idx : u32) {
  ignore(a); // Initialization should be inserted above this statement
  ignore(b);
  ignore(c);
}
)";
  auto* expect = R"(
struct S {
  x : i32;
  y : array<i32, 8>;
};

var<workgroup> a : i32;

var<workgroup> b : S;

var<workgroup> c : array<S, 32>;

[[stage(compute), workgroup_size(2, 3)]]
fn f([[builtin(local_invocation_index)]] local_idx : u32) {
  if ((local_idx < 1u)) {
    a = i32();
    b.x = i32();
  }
  for(var idx : u32 = local_idx; (idx < 8u); idx = (idx + 6u)) {
    let i : u32 = idx;
    b.y[i] = i32();
  }
  for(var idx_1 : u32 = local_idx; (idx_1 < 32u); idx_1 = (idx_1 + 6u)) {
    let i_1 : u32 = idx_1;
    c[i_1].x = i32();
  }
  for(var idx_2 : u32 = local_idx; (idx_2 < 256u); idx_2 = (idx_2 + 6u)) {
    let i_2 : u32 = (idx_2 / 8u);
    let i : u32 = (idx_2 % 8u);
    c[i_2].y[i] = i32();
  }
  workgroupBarrier();
  ignore(a);
  ignore(b);
  ignore(c);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// The workgroup size depends on the i32 pipeline-overridable constant `X`.
// The invocation count is only known as (u32(X) * 6u), so even the scalar
// initialization becomes a loop, and every loop strides by that expression.
TEST_F(ZeroInitWorkgroupMemoryTest,
       MultipleWorkgroupVar_ExistingLocalIndex_Size_2_3_X) {
  auto* src = R"(
struct S {
  x : i32;
  y : array<i32, 8>;
};

var<workgroup> a : i32;

var<workgroup> b : S;

var<workgroup> c : array<S, 32>;

[[override(1)]] let X : i32;

[[stage(compute), workgroup_size(2, 3, X)]]
fn f([[builtin(local_invocation_index)]] local_idx : u32) {
  ignore(a); // Initialization should be inserted above this statement
  ignore(b);
  ignore(c);
}
)";
  auto* expect =
      R"(
struct S {
  x : i32;
  y : array<i32, 8>;
};

var<workgroup> a : i32;

var<workgroup> b : S;

var<workgroup> c : array<S, 32>;

[[override(1)]] let X : i32;

[[stage(compute), workgroup_size(2, 3, X)]]
fn f([[builtin(local_invocation_index)]] local_idx : u32) {
  for(var idx : u32 = local_idx; (idx < 1u); idx = (idx + (u32(X) * 6u))) {
    a = i32();
    b.x = i32();
  }
  for(var idx_1 : u32 = local_idx; (idx_1 < 8u); idx_1 = (idx_1 + (u32(X) * 6u))) {
    let i : u32 = idx_1;
    b.y[i] = i32();
  }
  for(var idx_2 : u32 = local_idx; (idx_2 < 32u); idx_2 = (idx_2 + (u32(X) * 6u))) {
    let i_1 : u32 = idx_2;
    c[i_1].x = i32();
  }
  for(var idx_3 : u32 = local_idx; (idx_3 < 256u); idx_3 = (idx_3 + (u32(X) * 6u))) {
    let i_2 : u32 = (idx_3 / 8u);
    let i : u32 = (idx_3 % 8u);
    c[i_2].y[i] = i32();
  }
  workgroupBarrier();
  ignore(a);
  ignore(b);
  ignore(c);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// The workgroup size mixes u32 literals with the u32 override `X`. The stride
// is (X * 50u), with no conversion needed. Arrays nested up to three levels
// deep (b.x, b.z, c[i].x, c[i].z) are each flattened into a single loop, with
// the per-dimension indices recovered by / and %.
TEST_F(ZeroInitWorkgroupMemoryTest,
       MultipleWorkgroupVar_ExistingLocalIndex_Size_5u_X_10u) {
  auto* src = R"(
struct S {
  x : array<array<i32, 8>, 10>;
  y : array<i32, 8>;
  z : array<array<array<i32, 8>, 10>, 20>;
};

var<workgroup> a : i32;

var<workgroup> b : S;

var<workgroup> c : array<S, 32>;

[[override(1)]] let X : u32;

[[stage(compute), workgroup_size(5u, X, 10u)]]
fn f([[builtin(local_invocation_index)]] local_idx : u32) {
  ignore(a); // Initialization should be inserted above this statement
  ignore(b);
  ignore(c);
}
)";
  auto* expect =
      R"(
struct S {
  x : array<array<i32, 8>, 10>;
  y : array<i32, 8>;
  z : array<array<array<i32, 8>, 10>, 20>;
};

var<workgroup> a : i32;

var<workgroup> b : S;

var<workgroup> c : array<S, 32>;

[[override(1)]] let X : u32;

[[stage(compute), workgroup_size(5u, X, 10u)]]
fn f([[builtin(local_invocation_index)]] local_idx : u32) {
  for(var idx : u32 = local_idx; (idx < 1u); idx = (idx + (X * 50u))) {
    a = i32();
  }
  for(var idx_1 : u32 = local_idx; (idx_1 < 8u); idx_1 = (idx_1 + (X * 50u))) {
    let i_1 : u32 = idx_1;
    b.y[i_1] = i32();
  }
  for(var idx_2 : u32 = local_idx; (idx_2 < 80u); idx_2 = (idx_2 + (X * 50u))) {
    let i : u32 = (idx_2 / 8u);
    let i_1 : u32 = (idx_2 % 8u);
    b.x[i][i_1] = i32();
  }
  for(var idx_3 : u32 = local_idx; (idx_3 < 256u); idx_3 = (idx_3 + (X * 50u))) {
    let i_4 : u32 = (idx_3 / 8u);
    let i_1 : u32 = (idx_3 % 8u);
    c[i_4].y[i_1] = i32();
  }
  for(var idx_4 : u32 = local_idx; (idx_4 < 1600u); idx_4 = (idx_4 + (X * 50u))) {
    let i_2 : u32 = (idx_4 / 80u);
    let i : u32 = ((idx_4 % 80u) / 8u);
    let i_1 : u32 = (idx_4 % 8u);
    b.z[i_2][i][i_1] = i32();
  }
  for(var idx_5 : u32 = local_idx; (idx_5 < 2560u); idx_5 = (idx_5 + (X * 50u))) {
    let i_3 : u32 = (idx_5 / 80u);
    let i : u32 = ((idx_5 % 80u) / 8u);
    let i_1 : u32 = (idx_5 % 8u);
    c[i_3].x[i][i_1] = i32();
  }
  for(var idx_6 : u32 = local_idx; (idx_6 < 51200u); idx_6 = (idx_6 + (X * 50u))) {
    let i_5 : u32 = (idx_6 / 1600u);
    let i_2 : u32 = ((idx_6 % 1600u) / 80u);
    let i : u32 = ((idx_6 % 80u) / 8u);
    let i_1 : u32 = (idx_6 % 8u);
    c[i_5].z[i_2][i][i_1] = i32();
  }
  workgroupBarrier();
  ignore(a);
  ignore(b);
  ignore(c);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// The entry point has a local_invocation_id parameter but no
// local_invocation_index. The transform keeps the existing parameter, appends
// a local_invocation_index parameter, and drives the initialization loops
// from it.
TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_InjectedLocalIndex) {
  auto* src = R"(
struct S {
  x : i32;
  y : array<i32, 8>;
};

var<workgroup> a : i32;

var<workgroup> b : S;

var<workgroup> c : array<S, 32>;

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>) {
  ignore(a); // Initialization should be inserted above this statement
  ignore(b);
  ignore(c);
}
)";
  auto* expect = R"(
struct S {
  x : i32;
  y : array<i32, 8>;
};

var<workgroup> a : i32;

var<workgroup> b : S;

var<workgroup> c : array<S, 32>;

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>, [[builtin(local_invocation_index)]] local_invocation_index : u32) {
  {
    a = i32();
    b.x = i32();
  }
  for(var idx : u32 = local_invocation_index; (idx < 8u); idx = (idx + 1u)) {
    let i : u32 = idx;
    b.y[i] = i32();
  }
  for(var idx_1 : u32 = local_invocation_index; (idx_1 < 32u); idx_1 = (idx_1 + 1u)) {
    let i_1 : u32 = idx_1;
    c[i_1].x = i32();
  }
  for(var idx_2 : u32 = local_invocation_index; (idx_2 < 256u); idx_2 = (idx_2 + 1u)) {
    let i_2 : u32 = (idx_2 / 8u);
    let i : u32 = (idx_2 % 8u);
    c[i_2].y[i] = i32();
  }
  workgroupBarrier();
  ignore(a);
  ignore(b);
  ignore(c);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// Each entry point zero-initializes only the workgroup variables it uses.
// Guards and loop strides are derived from that entry point's own workgroup
// size (1, 6, and 120 invocations respectively). Injected parameters and
// generated locals receive module-unique names (local_invocation_index_1,
// idx_2, i_3, ...).
TEST_F(ZeroInitWorkgroupMemoryTest, MultipleWorkgroupVar_MultipleEntryPoints) {
  auto* src = R"(
struct S {
  x : i32;
  y : array<i32, 8>;
};

var<workgroup> a : i32;

var<workgroup> b : S;

var<workgroup> c : array<S, 32>;

[[stage(compute), workgroup_size(1)]]
fn f1() {
  ignore(a); // Initialization should be inserted above this statement
  ignore(c);
}

[[stage(compute), workgroup_size(1, 2, 3)]]
fn f2([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>) {
  ignore(b); // Initialization should be inserted above this statement
}

[[stage(compute), workgroup_size(4, 5, 6)]]
fn f3() {
  ignore(c); // Initialization should be inserted above this statement
  ignore(a);
}
)";
  auto* expect = R"(
struct S {
  x : i32;
  y : array<i32, 8>;
};

var<workgroup> a : i32;

var<workgroup> b : S;

var<workgroup> c : array<S, 32>;

[[stage(compute), workgroup_size(1)]]
fn f1([[builtin(local_invocation_index)]] local_invocation_index : u32) {
  {
    a = i32();
  }
  for(var idx : u32 = local_invocation_index; (idx < 32u); idx = (idx + 1u)) {
    let i : u32 = idx;
    c[i].x = i32();
  }
  for(var idx_1 : u32 = local_invocation_index; (idx_1 < 256u); idx_1 = (idx_1 + 1u)) {
    let i_1 : u32 = (idx_1 / 8u);
    let i_2 : u32 = (idx_1 % 8u);
    c[i_1].y[i_2] = i32();
  }
  workgroupBarrier();
  ignore(a);
  ignore(c);
}

[[stage(compute), workgroup_size(1, 2, 3)]]
fn f2([[builtin(local_invocation_id)]] local_invocation_id : vec3<u32>, [[builtin(local_invocation_index)]] local_invocation_index_1 : u32) {
  if ((local_invocation_index_1 < 1u)) {
    b.x = i32();
  }
  for(var idx_2 : u32 = local_invocation_index_1; (idx_2 < 8u); idx_2 = (idx_2 + 6u)) {
    let i_3 : u32 = idx_2;
    b.y[i_3] = i32();
  }
  workgroupBarrier();
  ignore(b);
}

[[stage(compute), workgroup_size(4, 5, 6)]]
fn f3([[builtin(local_invocation_index)]] local_invocation_index_2 : u32) {
  if ((local_invocation_index_2 < 1u)) {
    a = i32();
  }
  if ((local_invocation_index_2 < 32u)) {
    let i_4 : u32 = local_invocation_index_2;
    c[i_4].x = i32();
  }
  for(var idx_3 : u32 = local_invocation_index_2; (idx_3 < 256u); idx_3 = (idx_3 + 120u)) {
    let i_5 : u32 = (idx_3 / 8u);
    let i_6 : u32 = (idx_3 % 8u);
    c[i_5].y[i_6] = i32();
  }
  workgroupBarrier();
  ignore(c);
  ignore(a);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// `v` is only used indirectly (f -> call_use_v -> use_v), yet the entry point
// still zero-initializes it before its first statement.
TEST_F(ZeroInitWorkgroupMemoryTest, TransitiveUsage) {
  auto* src = R"(
var<workgroup> v : i32;

fn use_v() {
  ignore(v);
}

fn call_use_v() {
  use_v();
}

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_index)]] local_idx : u32) {
  call_use_v(); // Initialization should be inserted above this statement
}
)";
  auto* expect = R"(
var<workgroup> v : i32;

fn use_v() {
  ignore(v);
}

fn call_use_v() {
  use_v();
}

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_index)]] local_idx : u32) {
  {
    v = i32();
  }
  workgroupBarrier();
  call_use_v();
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// Atomic workgroup variables are zero-initialized with atomicStore() rather
// than plain assignment.
TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupAtomics) {
  auto* src = R"(
var<workgroup> i : atomic<i32>;
var<workgroup> u : atomic<u32>;

[[stage(compute), workgroup_size(1)]]
fn f() {
  ignore(i); // Initialization should be inserted above this statement
  ignore(u);
}
)";
  auto* expect = R"(
var<workgroup> i : atomic<i32>;

var<workgroup> u : atomic<u32>;

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
  {
    atomicStore(&(i), i32());
    atomicStore(&(u), u32());
  }
  workgroupBarrier();
  ignore(i);
  ignore(u);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// A struct mixing atomic and non-atomic members is initialized member by
// member. Non-atomic members use plain assignment and atomic members use
// atomicStore(), in declaration order.
TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupStructOfAtomics) {
  auto* src = R"(
struct S {
  a : i32;
  i : atomic<i32>;
  b : f32;
  u : atomic<u32>;
  c : u32;
};

var<workgroup> w : S;

[[stage(compute), workgroup_size(1)]]
fn f() {
  ignore(w); // Initialization should be inserted above this statement
}
)";
  auto* expect = R"(
struct S {
  a : i32;
  i : atomic<i32>;
  b : f32;
  u : atomic<u32>;
  c : u32;
};

var<workgroup> w : S;

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
  {
    w.a = i32();
    atomicStore(&(w.i), i32());
    w.b = f32();
    atomicStore(&(w.u), u32());
    w.c = u32();
  }
  workgroupBarrier();
  ignore(w);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// An array of atomics is cleared by a loop that atomicStore()s each element.
TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfAtomics) {
  auto* src = R"(
var<workgroup> w : array<atomic<u32>, 4>;

[[stage(compute), workgroup_size(1)]]
fn f() {
  ignore(w); // Initialization should be inserted above this statement
}
)";
  auto* expect = R"(
var<workgroup> w : array<atomic<u32>, 4>;

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
  for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
    let i : u32 = idx;
    atomicStore(&(w[i]), u32());
  }
  workgroupBarrier();
  ignore(w);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

// An array of structs containing atomics is cleared by a single loop over the
// elements. Each member is initialized individually, using atomicStore() for
// atomic members. The generated index is named `i_1` because the struct member
// `i` already uses the name `i`.
TEST_F(ZeroInitWorkgroupMemoryTest, WorkgroupArrayOfStructOfAtomics) {
  auto* src = R"(
struct S {
  a : i32;
  i : atomic<i32>;
  b : f32;
  u : atomic<u32>;
  c : u32;
};

var<workgroup> w : array<S, 4>;

[[stage(compute), workgroup_size(1)]]
fn f() {
  ignore(w); // Initialization should be inserted above this statement
}
)";
  auto* expect = R"(
struct S {
  a : i32;
  i : atomic<i32>;
  b : f32;
  u : atomic<u32>;
  c : u32;
};

var<workgroup> w : array<S, 4>;

[[stage(compute), workgroup_size(1)]]
fn f([[builtin(local_invocation_index)]] local_invocation_index : u32) {
  for(var idx : u32 = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
    let i_1 : u32 = idx;
    w[i_1].a = i32();
    atomicStore(&(w[i_1].i), i32());
    w[i_1].b = f32();
    atomicStore(&(w[i_1].u), u32());
    w[i_1].c = u32();
  }
  workgroupBarrier();
  ignore(w);
}
)";

  auto got = Run<ZeroInitWorkgroupMemory>(src);

  EXPECT_EQ(expect, str(got));
}

813 }  // namespace
814 }  // namespace transform
815 }  // namespace tint
816