// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15 #include "src/transform/calculate_array_length.h"
16
17 #include "src/transform/simplify_pointers.h"
18 #include "src/transform/test_helper.h"
19 #include "src/transform/unshadow.h"
20
21 namespace tint {
22 namespace transform {
23 namespace {
24
25 using CalculateArrayLengthTest = TransformTest;
26
// Running CalculateArrayLength on its own (here on an empty module) must fail
// with a dependency error, because it requires SimplifyPointers to have run
// first.
// NOTE(review): despite the test name, the dependency reported as missing is
// SimplifyPointers, not CalculateArrayLength.
TEST_F(CalculateArrayLengthTest, Error_MissingCalculateArrayLength) {
  auto* src = "";

  auto* expect =
      "error: tint::transform::CalculateArrayLength depends on "
      "tint::transform::SimplifyPointers but the dependency was not run";

  auto got = Run<CalculateArrayLength>(src);

  EXPECT_EQ(expect, str(got));
}
38
// A single arrayLength() on a runtime-sized array member is replaced by:
//   1. a call to an internal buffer-size intrinsic declared after the struct,
//   2. the array length computed from the buffer size.
// `arr` follows one i32 member, so the expected output uses an offset of 4u
// and an element stride of 4u: (size - 4u) / 4u.
TEST_F(CalculateArrayLengthTest, Basic) {
  auto* src = R"(
[[block]]
struct SB {
  x : i32;
  arr : array<i32>;
};

[[group(0), binding(0)]] var<storage, read> sb : SB;

[[stage(compute), workgroup_size(1)]]
fn main() {
  var len : u32 = arrayLength(&sb.arr);
}
)";

  auto* expect = R"(
[[block]]
struct SB {
  x : i32;
  arr : array<i32>;
};

[[internal(intrinsic_buffer_size)]]
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, result : ptr<function, u32>)

[[group(0), binding(0)]] var<storage, read> sb : SB;

[[stage(compute), workgroup_size(1)]]
fn main() {
  var tint_symbol_1 : u32 = 0u;
  tint_symbol(sb, &(tint_symbol_1));
  let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
  var len : u32 = tint_symbol_2;
}
)";

  auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);

  EXPECT_EQ(expect, str(got));
}
80
// Several arrayLength() calls on the same buffer within one block share a
// single buffer-size query: all three locals are initialized from the same
// `tint_symbol_2`.
TEST_F(CalculateArrayLengthTest, InSameBlock) {
  auto* src = R"(
[[block]]
struct SB {
  x : i32;
  arr : array<i32>;
};

[[group(0), binding(0)]] var<storage, read> sb : SB;

[[stage(compute), workgroup_size(1)]]
fn main() {
  var a : u32 = arrayLength(&sb.arr);
  var b : u32 = arrayLength(&sb.arr);
  var c : u32 = arrayLength(&sb.arr);
}
)";

  auto* expect = R"(
[[block]]
struct SB {
  x : i32;
  arr : array<i32>;
};

[[internal(intrinsic_buffer_size)]]
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, result : ptr<function, u32>)

[[group(0), binding(0)]] var<storage, read> sb : SB;

[[stage(compute), workgroup_size(1)]]
fn main() {
  var tint_symbol_1 : u32 = 0u;
  tint_symbol(sb, &(tint_symbol_1));
  let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
  var a : u32 = tint_symbol_2;
  var b : u32 = tint_symbol_2;
  var c : u32 = tint_symbol_2;
}
)";

  auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);

  EXPECT_EQ(expect, str(got));
}
126
// An explicit [[stride(64)]] on the runtime array is honoured as the divisor.
// `arr` follows two 4-byte members, so the expected offset is 8u:
// (size - 8u) / 64u.
TEST_F(CalculateArrayLengthTest, WithStride) {
  auto* src = R"(
[[block]]
struct SB {
  x : i32;
  y : f32;
  arr : [[stride(64)]] array<i32>;
};

[[group(0), binding(0)]] var<storage, read> sb : SB;

[[stage(compute), workgroup_size(1)]]
fn main() {
  var len : u32 = arrayLength(&sb.arr);
}
)";

  auto* expect = R"(
[[block]]
struct SB {
  x : i32;
  y : f32;
  arr : [[stride(64)]] array<i32>;
};

[[internal(intrinsic_buffer_size)]]
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, result : ptr<function, u32>)

[[group(0), binding(0)]] var<storage, read> sb : SB;

[[stage(compute), workgroup_size(1)]]
fn main() {
  var tint_symbol_1 : u32 = 0u;
  tint_symbol(sb, &(tint_symbol_1));
  let tint_symbol_2 : u32 = ((tint_symbol_1 - 8u) / 64u);
  var len : u32 = tint_symbol_2;
}
)";

  auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);

  EXPECT_EQ(expect, str(got));
}
170
// arrayLength() calls in different (nested) blocks each get their own
// buffer-size query, inserted into the block that contains the call. A single
// intrinsic declaration is still shared because only one struct type is used.
TEST_F(CalculateArrayLengthTest, Nested) {
  auto* src = R"(
[[block]]
struct SB {
  x : i32;
  arr : array<i32>;
};

[[group(0), binding(0)]] var<storage, read> sb : SB;

[[stage(compute), workgroup_size(1)]]
fn main() {
  if (true) {
    var len : u32 = arrayLength(&sb.arr);
  } else {
    if (true) {
      var len : u32 = arrayLength(&sb.arr);
    }
  }
}
)";

  auto* expect = R"(
[[block]]
struct SB {
  x : i32;
  arr : array<i32>;
};

[[internal(intrinsic_buffer_size)]]
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, result : ptr<function, u32>)

[[group(0), binding(0)]] var<storage, read> sb : SB;

[[stage(compute), workgroup_size(1)]]
fn main() {
  if (true) {
    var tint_symbol_1 : u32 = 0u;
    tint_symbol(sb, &(tint_symbol_1));
    let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
    var len : u32 = tint_symbol_2;
  } else {
    if (true) {
      var tint_symbol_3 : u32 = 0u;
      tint_symbol(sb, &(tint_symbol_3));
      let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
      var len : u32 = tint_symbol_4;
    }
  }
}
)";

  auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);

  EXPECT_EQ(expect, str(got));
}
227
// Each distinct buffer struct type gets its own intrinsic declaration, emitted
// directly after that struct (tint_symbol for SB1, tint_symbol_3 for SB2).
// For SB2 the array<vec4<f32>> member sits at offset 16u with a 16u stride.
TEST_F(CalculateArrayLengthTest, MultipleStorageBuffers) {
  auto* src = R"(
[[block]]
struct SB1 {
  x : i32;
  arr1 : array<i32>;
};

[[block]]
struct SB2 {
  x : i32;
  arr2 : array<vec4<f32>>;
};

[[group(0), binding(0)]] var<storage, read> sb1 : SB1;

[[group(0), binding(1)]] var<storage, read> sb2 : SB2;

[[stage(compute), workgroup_size(1)]]
fn main() {
  var len1 : u32 = arrayLength(&(sb1.arr1));
  var len2 : u32 = arrayLength(&(sb2.arr2));
  var x : u32 = (len1 + len2);
}
)";

  auto* expect = R"(
[[block]]
struct SB1 {
  x : i32;
  arr1 : array<i32>;
};

[[internal(intrinsic_buffer_size)]]
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB1, result : ptr<function, u32>)

[[block]]
struct SB2 {
  x : i32;
  arr2 : array<vec4<f32>>;
};

[[internal(intrinsic_buffer_size)]]
fn tint_symbol_3([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB2, result : ptr<function, u32>)

[[group(0), binding(0)]] var<storage, read> sb1 : SB1;

[[group(0), binding(1)]] var<storage, read> sb2 : SB2;

[[stage(compute), workgroup_size(1)]]
fn main() {
  var tint_symbol_1 : u32 = 0u;
  tint_symbol(sb1, &(tint_symbol_1));
  let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
  var tint_symbol_4 : u32 = 0u;
  tint_symbol_3(sb2, &(tint_symbol_4));
  let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u);
  var len1 : u32 = tint_symbol_2;
  var len2 : u32 = tint_symbol_5;
  var x : u32 = (len1 + len2);
}
)";

  auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);

  EXPECT_EQ(expect, str(got));
}
295
// Locals that shadow the module-scope buffers `a` and `b` are renamed to a_1
// and b_1 by Unshadow. The pointer `x` (= &a) is removed by SimplifyPointers,
// so the arrayLength through `(*x).arr` queries the global buffer `a`.
TEST_F(CalculateArrayLengthTest, Shadowing) {
  auto* src = R"(
[[block]]
struct SB {
  x : i32;
  arr : array<i32>;
};

[[group(0), binding(0)]] var<storage, read> a : SB;
[[group(0), binding(1)]] var<storage, read> b : SB;

[[stage(compute), workgroup_size(1)]]
fn main() {
  let x = &a;
  var a : u32 = arrayLength(&a.arr);
  {
    var b : u32 = arrayLength(&((*x).arr));
  }
}
)";

  auto* expect =
      R"(
[[block]]
struct SB {
  x : i32;
  arr : array<i32>;
};

[[internal(intrinsic_buffer_size)]]
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, result : ptr<function, u32>)

[[group(0), binding(0)]] var<storage, read> a : SB;

[[group(0), binding(1)]] var<storage, read> b : SB;

[[stage(compute), workgroup_size(1)]]
fn main() {
  var tint_symbol_1 : u32 = 0u;
  tint_symbol(a, &(tint_symbol_1));
  let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
  var a_1 : u32 = tint_symbol_2;
  {
    var tint_symbol_3 : u32 = 0u;
    tint_symbol(a, &(tint_symbol_3));
    let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
    var b_1 : u32 = tint_symbol_4;
  }
}
)";

  auto got = Run<Unshadow, SimplifyPointers, CalculateArrayLength>(src);

  EXPECT_EQ(expect, str(got));
}
351
352 } // namespace
353 } // namespace transform
354 } // namespace tint
355