1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/ast/fallthrough_statement.h"
16 #include "src/writer/spirv/spv_dump.h"
17 #include "src/writer/spirv/test_helper.h"
18
19 namespace tint {
20 namespace writer {
21 namespace spirv {
22 namespace {
23
24 using BuilderTest = TestHelper;
25
TEST_F(BuilderTest,Switch_Empty)26 TEST_F(BuilderTest, Switch_Empty) {
27 // switch (1) {
28 // default: {}
29 // }
30
31 auto* expr = Switch(1, DefaultCase());
32 WrapInFunction(expr);
33
34 spirv::Builder& b = Build();
35
36 b.push_function(Function{});
37
38 EXPECT_TRUE(b.GenerateSwitchStatement(expr)) << b.error();
39 EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
40 %3 = OpConstant %2 1
41 )");
42 EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
43 R"(OpSelectionMerge %1 None
44 OpSwitch %3 %4
45 %4 = OpLabel
46 OpBranch %1
47 %1 = OpLabel
48 )");
49 }
50
TEST_F(BuilderTest,Switch_WithCase)51 TEST_F(BuilderTest, Switch_WithCase) {
52 // switch(a) {
53 // case 1:
54 // v = 1;
55 // case 2:
56 // v = 2;
57 // default: {}
58 // }
59
60 auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate);
61 auto* a = Global("a", ty.i32(), ast::StorageClass::kPrivate);
62
63 auto* func = Func("a_func", {}, ty.void_(),
64 {
65 Switch("a", //
66 Case(Expr(1), Block(Assign("v", 1))), //
67 Case(Expr(2), Block(Assign("v", 2))), //
68 DefaultCase()),
69 });
70
71 spirv::Builder& b = Build();
72
73 ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
74 ASSERT_TRUE(b.GenerateGlobalVariable(a)) << b.error();
75 ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
76
77 EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v"
78 OpName %5 "a"
79 OpName %8 "a_func"
80 %3 = OpTypeInt 32 1
81 %2 = OpTypePointer Private %3
82 %4 = OpConstantNull %3
83 %1 = OpVariable %2 Private %4
84 %5 = OpVariable %2 Private %4
85 %7 = OpTypeVoid
86 %6 = OpTypeFunction %7
87 %15 = OpConstant %3 1
88 %16 = OpConstant %3 2
89 %8 = OpFunction %7 None %6
90 %9 = OpLabel
91 %11 = OpLoad %3 %5
92 OpSelectionMerge %10 None
93 OpSwitch %11 %12 1 %13 2 %14
94 %13 = OpLabel
95 OpStore %1 %15
96 OpBranch %10
97 %14 = OpLabel
98 OpStore %1 %16
99 OpBranch %10
100 %12 = OpLabel
101 OpBranch %10
102 %10 = OpLabel
103 OpReturn
104 OpFunctionEnd
105 )");
106 }
107
TEST_F(BuilderTest,Switch_WithCase_Unsigned)108 TEST_F(BuilderTest, Switch_WithCase_Unsigned) {
109 // switch(a) {
110 // case 1u:
111 // v = 1;
112 // case 2u:
113 // v = 2;
114 // default: {}
115 // }
116
117 auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate);
118 auto* a = Global("a", ty.u32(), ast::StorageClass::kPrivate);
119
120 auto* func = Func("a_func", {}, ty.void_(),
121 {
122 Switch("a", //
123 Case(Expr(1u), Block(Assign("v", 1))), //
124 Case(Expr(2u), Block(Assign("v", 2))), //
125 DefaultCase()),
126 });
127
128 spirv::Builder& b = Build();
129
130 ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
131 ASSERT_TRUE(b.GenerateGlobalVariable(a)) << b.error();
132 ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
133
134 EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v"
135 OpName %5 "a"
136 OpName %11 "a_func"
137 %3 = OpTypeInt 32 1
138 %2 = OpTypePointer Private %3
139 %4 = OpConstantNull %3
140 %1 = OpVariable %2 Private %4
141 %7 = OpTypeInt 32 0
142 %6 = OpTypePointer Private %7
143 %8 = OpConstantNull %7
144 %5 = OpVariable %6 Private %8
145 %10 = OpTypeVoid
146 %9 = OpTypeFunction %10
147 %18 = OpConstant %3 1
148 %19 = OpConstant %3 2
149 %11 = OpFunction %10 None %9
150 %12 = OpLabel
151 %14 = OpLoad %7 %5
152 OpSelectionMerge %13 None
153 OpSwitch %14 %15 1 %16 2 %17
154 %16 = OpLabel
155 OpStore %1 %18
156 OpBranch %13
157 %17 = OpLabel
158 OpStore %1 %19
159 OpBranch %13
160 %15 = OpLabel
161 OpBranch %13
162 %13 = OpLabel
163 OpReturn
164 OpFunctionEnd
165 )");
166 }
167
TEST_F(BuilderTest,Switch_WithDefault)168 TEST_F(BuilderTest, Switch_WithDefault) {
169 // switch(true) {
170 // default: {}
171 // v = 1;
172 // }
173
174 auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate);
175 auto* a = Global("a", ty.i32(), ast::StorageClass::kPrivate);
176
177 auto* func = Func("a_func", {}, ty.void_(),
178 {
179 Switch("a", //
180 DefaultCase(Block(Assign("v", 1)))), //
181 });
182
183 spirv::Builder& b = Build();
184
185 ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
186 ASSERT_TRUE(b.GenerateGlobalVariable(a)) << b.error();
187 ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
188
189 EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v"
190 OpName %5 "a"
191 OpName %8 "a_func"
192 %3 = OpTypeInt 32 1
193 %2 = OpTypePointer Private %3
194 %4 = OpConstantNull %3
195 %1 = OpVariable %2 Private %4
196 %5 = OpVariable %2 Private %4
197 %7 = OpTypeVoid
198 %6 = OpTypeFunction %7
199 %13 = OpConstant %3 1
200 %8 = OpFunction %7 None %6
201 %9 = OpLabel
202 %11 = OpLoad %3 %5
203 OpSelectionMerge %10 None
204 OpSwitch %11 %12
205 %12 = OpLabel
206 OpStore %1 %13
207 OpBranch %10
208 %10 = OpLabel
209 OpReturn
210 OpFunctionEnd
211 )");
212 }
213
TEST_F(BuilderTest,Switch_WithCaseAndDefault)214 TEST_F(BuilderTest, Switch_WithCaseAndDefault) {
215 // switch(a) {
216 // case 1:
217 // v = 1;
218 // case 2, 3:
219 // v = 2;
220 // default: {}
221 // v = 3;
222 // }
223
224 auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate);
225 auto* a = Global("a", ty.i32(), ast::StorageClass::kPrivate);
226
227 auto* func = Func("a_func", {}, ty.void_(),
228 {
229 Switch(Expr("a"), //
230 Case(Expr(1), //
231 Block(Assign("v", 1))), //
232 Case({Expr(2), Expr(3)}, //
233 Block(Assign("v", 2))), //
234 DefaultCase(Block(Assign("v", 3)))),
235 });
236
237 spirv::Builder& b = Build();
238
239 ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
240 ASSERT_TRUE(b.GenerateGlobalVariable(a)) << b.error();
241 ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
242
243 EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v"
244 OpName %5 "a"
245 OpName %8 "a_func"
246 %3 = OpTypeInt 32 1
247 %2 = OpTypePointer Private %3
248 %4 = OpConstantNull %3
249 %1 = OpVariable %2 Private %4
250 %5 = OpVariable %2 Private %4
251 %7 = OpTypeVoid
252 %6 = OpTypeFunction %7
253 %15 = OpConstant %3 1
254 %16 = OpConstant %3 2
255 %17 = OpConstant %3 3
256 %8 = OpFunction %7 None %6
257 %9 = OpLabel
258 %11 = OpLoad %3 %5
259 OpSelectionMerge %10 None
260 OpSwitch %11 %12 1 %13 2 %14 3 %14
261 %13 = OpLabel
262 OpStore %1 %15
263 OpBranch %10
264 %14 = OpLabel
265 OpStore %1 %16
266 OpBranch %10
267 %12 = OpLabel
268 OpStore %1 %17
269 OpBranch %10
270 %10 = OpLabel
271 OpReturn
272 OpFunctionEnd
273 )");
274 }
275
TEST_F(BuilderTest,Switch_CaseWithFallthrough)276 TEST_F(BuilderTest, Switch_CaseWithFallthrough) {
277 // switch(a) {
278 // case 1:
279 // v = 1;
280 // fallthrough;
281 // case 2:
282 // v = 2;
283 // default: {}
284 // v = 3;
285 // }
286
287 auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate);
288 auto* a = Global("a", ty.i32(), ast::StorageClass::kPrivate);
289
290 auto* func = Func("a_func", {}, ty.void_(),
291 {
292 Switch(Expr("a"), //
293 Case(Expr(1), //
294 Block(Assign("v", 1), Fallthrough())), //
295 Case(Expr(2), //
296 Block(Assign("v", 2))), //
297 DefaultCase(Block(Assign("v", 3)))),
298 });
299
300 spirv::Builder& b = Build();
301
302 ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
303 ASSERT_TRUE(b.GenerateGlobalVariable(a)) << b.error();
304 ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
305
306 EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v"
307 OpName %5 "a"
308 OpName %8 "a_func"
309 %3 = OpTypeInt 32 1
310 %2 = OpTypePointer Private %3
311 %4 = OpConstantNull %3
312 %1 = OpVariable %2 Private %4
313 %5 = OpVariable %2 Private %4
314 %7 = OpTypeVoid
315 %6 = OpTypeFunction %7
316 %15 = OpConstant %3 1
317 %16 = OpConstant %3 2
318 %17 = OpConstant %3 3
319 %8 = OpFunction %7 None %6
320 %9 = OpLabel
321 %11 = OpLoad %3 %5
322 OpSelectionMerge %10 None
323 OpSwitch %11 %12 1 %13 2 %14
324 %13 = OpLabel
325 OpStore %1 %15
326 OpBranch %14
327 %14 = OpLabel
328 OpStore %1 %16
329 OpBranch %10
330 %12 = OpLabel
331 OpStore %1 %17
332 OpBranch %10
333 %10 = OpLabel
334 OpReturn
335 OpFunctionEnd
336 )");
337 }
338
TEST_F(BuilderTest,Switch_WithNestedBreak)339 TEST_F(BuilderTest, Switch_WithNestedBreak) {
340 // switch (a) {
341 // case 1:
342 // if (true) {
343 // break;
344 // }
345 // v = 1;
346 // default: {}
347 // }
348
349 auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate);
350 auto* a = Global("a", ty.i32(), ast::StorageClass::kPrivate);
351
352 auto* func = Func(
353 "a_func", {}, ty.void_(),
354 {
355 Switch("a", //
356 Case(Expr(1), //
357 Block( //
358 If(Expr(true), Block(create<ast::BreakStatement>())),
359 Assign("v", 1))),
360 DefaultCase()),
361 });
362
363 spirv::Builder& b = Build();
364
365 ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
366 ASSERT_TRUE(b.GenerateGlobalVariable(a)) << b.error();
367 ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
368
369 EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "v"
370 OpName %5 "a"
371 OpName %8 "a_func"
372 %3 = OpTypeInt 32 1
373 %2 = OpTypePointer Private %3
374 %4 = OpConstantNull %3
375 %1 = OpVariable %2 Private %4
376 %5 = OpVariable %2 Private %4
377 %7 = OpTypeVoid
378 %6 = OpTypeFunction %7
379 %14 = OpTypeBool
380 %15 = OpConstantTrue %14
381 %18 = OpConstant %3 1
382 %8 = OpFunction %7 None %6
383 %9 = OpLabel
384 %11 = OpLoad %3 %5
385 OpSelectionMerge %10 None
386 OpSwitch %11 %12 1 %13
387 %13 = OpLabel
388 OpSelectionMerge %16 None
389 OpBranchConditional %15 %17 %16
390 %17 = OpLabel
391 OpBranch %10
392 %16 = OpLabel
393 OpStore %1 %18
394 OpBranch %10
395 %12 = OpLabel
396 OpBranch %10
397 %10 = OpLabel
398 OpReturn
399 OpFunctionEnd
400 )");
401 }
402
TEST_F(BuilderTest,Switch_AllReturn)403 TEST_F(BuilderTest, Switch_AllReturn) {
404 // switch (1) {
405 // case 1: {
406 // return 1;
407 // }
408 // case 2: {
409 // fallthrough;
410 // }
411 // default: {
412 // return 3;
413 // }
414 // }
415
416 auto* fn = Func("f", {}, ty.i32(),
417 {
418 Switch(1, //
419 Case(Expr(1), Block(Return(1))), //
420 Case(Expr(2), Block(Fallthrough())), //
421 DefaultCase(Block(Return(3)))),
422 });
423
424 spirv::Builder& b = Build();
425
426 EXPECT_TRUE(b.GenerateFunction(fn)) << b.error();
427 EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "f"
428 %2 = OpTypeInt 32 1
429 %1 = OpTypeFunction %2
430 %6 = OpConstant %2 1
431 %10 = OpConstant %2 3
432 %11 = OpConstantNull %2
433 %3 = OpFunction %2 None %1
434 %4 = OpLabel
435 OpSelectionMerge %5 None
436 OpSwitch %6 %7 1 %8 2 %9
437 %8 = OpLabel
438 OpReturnValue %6
439 %9 = OpLabel
440 OpBranch %7
441 %7 = OpLabel
442 OpReturnValue %10
443 %5 = OpLabel
444 OpReturnValue %11
445 OpFunctionEnd
446 )");
447 }
448
449 } // namespace
450 } // namespace spirv
451 } // namespace writer
452 } // namespace tint
453