1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <memory>
16 #include <string>
17 #include <unordered_set>
18 #include <vector>
19
20 #include "gmock/gmock.h"
21 #include "source/opt/iterator.h"
22 #include "source/opt/loop_descriptor.h"
23 #include "source/opt/pass.h"
24 #include "source/opt/tree_iterator.h"
25 #include "test/opt/assembly_builder.h"
26 #include "test/opt/function_utils.h"
27 #include "test/opt/pass_fixture.h"
28 #include "test/opt/pass_utils.h"
29
30 namespace spvtools {
31 namespace opt {
32 namespace {
33
34 using ::testing::UnorderedElementsAre;
35
Validate(const std::vector<uint32_t> & bin)36 bool Validate(const std::vector<uint32_t>& bin) {
37 spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
38 spv_context spvContext = spvContextCreate(target_env);
39 spv_diagnostic diagnostic = nullptr;
40 spv_const_binary_t binary = {bin.data(), bin.size()};
41 spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
42 if (error != 0) spvDiagnosticPrint(diagnostic);
43 spvDiagnosticDestroy(diagnostic);
44 spvContextDestroy(spvContext);
45 return error == 0;
46 }
47
48 using PassClassTest = PassTest<::testing::Test>;
49
50 /*
51 Generated from the following GLSL
52 #version 330 core
53 layout(location = 0) out vec4 c;
54 void main() {
55 int i = 0;
56 for (; i < 10; ++i) {
57 int j = 0;
58 int k = 0;
59 for (; j < 11; ++j) {}
60 for (; k < 12; ++k) {}
61 }
62 }
63 */
TEST_F(PassClassTest, BasicVisitFromEntryPoint) {
  // Checks loop discovery, nesting, depth and header/latch/merge blocks for
  // one outer loop (%21) containing two sibling inner loops (%28 and %37).
  // clang-format off
  const std::string text = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main" %3
OpExecutionMode %2 OriginUpperLeft
OpSource GLSL 330
OpName %2 "main"
OpName %4 "i"
OpName %5 "j"
OpName %6 "k"
OpName %3 "c"
OpDecorate %3 Location 0
%7 = OpTypeVoid
%8 = OpTypeFunction %7
%9 = OpTypeInt 32 1
%10 = OpTypePointer Function %9
%11 = OpConstant %9 0
%12 = OpConstant %9 10
%13 = OpTypeBool
%14 = OpConstant %9 11
%15 = OpConstant %9 1
%16 = OpConstant %9 12
%17 = OpTypeFloat 32
%18 = OpTypeVector %17 4
%19 = OpTypePointer Output %18
%3 = OpVariable %19 Output
%2 = OpFunction %7 None %8
%20 = OpLabel
%4 = OpVariable %10 Function
%5 = OpVariable %10 Function
%6 = OpVariable %10 Function
OpStore %4 %11
OpBranch %21
%21 = OpLabel
OpLoopMerge %22 %23 None
OpBranch %24
%24 = OpLabel
%25 = OpLoad %9 %4
%26 = OpSLessThan %13 %25 %12
OpBranchConditional %26 %27 %22
%27 = OpLabel
OpStore %5 %11
OpStore %6 %11
OpBranch %28
%28 = OpLabel
OpLoopMerge %29 %30 None
OpBranch %31
%31 = OpLabel
%32 = OpLoad %9 %5
%33 = OpSLessThan %13 %32 %14
OpBranchConditional %33 %34 %29
%34 = OpLabel
OpBranch %30
%30 = OpLabel
%35 = OpLoad %9 %5
%36 = OpIAdd %9 %35 %15
OpStore %5 %36
OpBranch %28
%29 = OpLabel
OpBranch %37
%37 = OpLabel
OpLoopMerge %38 %39 None
OpBranch %40
%40 = OpLabel
%41 = OpLoad %9 %6
%42 = OpSLessThan %13 %41 %16
OpBranchConditional %42 %43 %38
%43 = OpLabel
OpBranch %39
%39 = OpLabel
%44 = OpLoad %9 %6
%45 = OpIAdd %9 %44 %15
OpStore %6 %45
OpBranch %37
%38 = OpLabel
OpBranch %23
%23 = OpLabel
%46 = OpLoad %9 %4
%47 = OpIAdd %9 %46 %15
OpStore %4 %47
OpBranch %21
%22 = OpLabel
OpReturn
OpFunctionEnd
)";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Check the context itself before calling context->module(): a failed build
  // must be reported, not crash the test. ASSERT stops the test on failure.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ASSERT_NE(nullptr, f);
  LoopDescriptor& ld = *context->GetLoopDescriptor(f);

  EXPECT_EQ(ld.NumLoops(), 3u);

  // Invalid basic block id.
  EXPECT_EQ(ld[0u], nullptr);
  // Entry block, not inside any loop.
  EXPECT_EQ(ld[20], nullptr);

  // Each loop is dereferenced below, so its lookup must succeed.
  ASSERT_NE(ld[21], nullptr);
  Loop& parent_loop = *ld[21];
  EXPECT_TRUE(parent_loop.HasNestedLoops());
  EXPECT_FALSE(parent_loop.IsNested());
  EXPECT_EQ(parent_loop.GetDepth(), 1u);
  EXPECT_EQ(std::distance(parent_loop.begin(), parent_loop.end()), 2u);
  EXPECT_EQ(parent_loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 21));
  EXPECT_EQ(parent_loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 23));
  EXPECT_EQ(parent_loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 22));

  ASSERT_NE(ld[28], nullptr);
  Loop& child_loop_1 = *ld[28];
  EXPECT_FALSE(child_loop_1.HasNestedLoops());
  EXPECT_TRUE(child_loop_1.IsNested());
  EXPECT_EQ(child_loop_1.GetDepth(), 2u);
  EXPECT_EQ(std::distance(child_loop_1.begin(), child_loop_1.end()), 0u);
  EXPECT_EQ(child_loop_1.GetHeaderBlock(), spvtest::GetBasicBlock(f, 28));
  EXPECT_EQ(child_loop_1.GetLatchBlock(), spvtest::GetBasicBlock(f, 30));
  EXPECT_EQ(child_loop_1.GetMergeBlock(), spvtest::GetBasicBlock(f, 29));

  ASSERT_NE(ld[37], nullptr);
  Loop& child_loop_2 = *ld[37];
  EXPECT_FALSE(child_loop_2.HasNestedLoops());
  EXPECT_TRUE(child_loop_2.IsNested());
  EXPECT_EQ(child_loop_2.GetDepth(), 2u);
  EXPECT_EQ(std::distance(child_loop_2.begin(), child_loop_2.end()), 0u);
  EXPECT_EQ(child_loop_2.GetHeaderBlock(), spvtest::GetBasicBlock(f, 37));
  EXPECT_EQ(child_loop_2.GetLatchBlock(), spvtest::GetBasicBlock(f, 39));
  EXPECT_EQ(child_loop_2.GetMergeBlock(), spvtest::GetBasicBlock(f, 38));
}
195
CheckLoopBlocks(Loop * loop,std::unordered_set<uint32_t> * expected_ids)196 static void CheckLoopBlocks(Loop* loop,
197 std::unordered_set<uint32_t>* expected_ids) {
198 SCOPED_TRACE("Check loop " + std::to_string(loop->GetHeaderBlock()->id()));
199 for (uint32_t bb_id : loop->GetBlocks()) {
200 EXPECT_EQ(expected_ids->count(bb_id), 1u);
201 expected_ids->erase(bb_id);
202 }
203 EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
204 EXPECT_EQ(expected_ids->size(), 0u);
205 }
206
207 /*
208 Generated from the following GLSL
209 #version 330 core
210 layout(location = 0) out vec4 c;
211 void main() {
212 int i = 0;
213 for (; i < 10; ++i) {
214 for (int j = 0; j < 11; ++j) {
215 if (j < 5) {
216 for (int k = 0; k < 12; ++k) {}
217 }
218 else {}
219 for (int k = 0; k < 12; ++k) {}
220 }
221 }
222 }*/
TEST_F(PassClassTest, TripleNestedLoop) {
  // Checks block membership, nesting, depth and header/latch/merge/preheader
  // blocks for a three-level loop nest. It also checks that the descriptor
  // maps each block to its innermost enclosing loop.
  // clang-format off
  const std::string text = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main" %3
OpExecutionMode %2 OriginUpperLeft
OpSource GLSL 330
OpName %2 "main"
OpName %4 "i"
OpName %5 "j"
OpName %6 "k"
OpName %7 "k"
OpName %3 "c"
OpDecorate %3 Location 0
%8 = OpTypeVoid
%9 = OpTypeFunction %8
%10 = OpTypeInt 32 1
%11 = OpTypePointer Function %10
%12 = OpConstant %10 0
%13 = OpConstant %10 10
%14 = OpTypeBool
%15 = OpConstant %10 11
%16 = OpConstant %10 5
%17 = OpConstant %10 12
%18 = OpConstant %10 1
%19 = OpTypeFloat 32
%20 = OpTypeVector %19 4
%21 = OpTypePointer Output %20
%3 = OpVariable %21 Output
%2 = OpFunction %8 None %9
%22 = OpLabel
%4 = OpVariable %11 Function
%5 = OpVariable %11 Function
%6 = OpVariable %11 Function
%7 = OpVariable %11 Function
OpStore %4 %12
OpBranch %23
%23 = OpLabel
OpLoopMerge %24 %25 None
OpBranch %26
%26 = OpLabel
%27 = OpLoad %10 %4
%28 = OpSLessThan %14 %27 %13
OpBranchConditional %28 %29 %24
%29 = OpLabel
OpStore %5 %12
OpBranch %30
%30 = OpLabel
OpLoopMerge %31 %32 None
OpBranch %33
%33 = OpLabel
%34 = OpLoad %10 %5
%35 = OpSLessThan %14 %34 %15
OpBranchConditional %35 %36 %31
%36 = OpLabel
%37 = OpLoad %10 %5
%38 = OpSLessThan %14 %37 %16
OpSelectionMerge %39 None
OpBranchConditional %38 %40 %39
%40 = OpLabel
OpStore %6 %12
OpBranch %41
%41 = OpLabel
OpLoopMerge %42 %43 None
OpBranch %44
%44 = OpLabel
%45 = OpLoad %10 %6
%46 = OpSLessThan %14 %45 %17
OpBranchConditional %46 %47 %42
%47 = OpLabel
OpBranch %43
%43 = OpLabel
%48 = OpLoad %10 %6
%49 = OpIAdd %10 %48 %18
OpStore %6 %49
OpBranch %41
%42 = OpLabel
OpBranch %39
%39 = OpLabel
OpStore %7 %12
OpBranch %50
%50 = OpLabel
OpLoopMerge %51 %52 None
OpBranch %53
%53 = OpLabel
%54 = OpLoad %10 %7
%55 = OpSLessThan %14 %54 %17
OpBranchConditional %55 %56 %51
%56 = OpLabel
OpBranch %52
%52 = OpLabel
%57 = OpLoad %10 %7
%58 = OpIAdd %10 %57 %18
OpStore %7 %58
OpBranch %50
%51 = OpLabel
OpBranch %32
%32 = OpLabel
%59 = OpLoad %10 %5
%60 = OpIAdd %10 %59 %18
OpStore %5 %60
OpBranch %30
%31 = OpLabel
OpBranch %25
%25 = OpLabel
%61 = OpLoad %10 %4
%62 = OpIAdd %10 %61 %18
OpStore %4 %62
OpBranch %23
%24 = OpLabel
OpReturn
OpFunctionEnd
)";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Check the context itself before calling context->module(): a failed build
  // must be reported, not crash the test. ASSERT stops the test on failure.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ASSERT_NE(nullptr, f);
  LoopDescriptor& ld = *context->GetLoopDescriptor(f);

  EXPECT_EQ(ld.NumLoops(), 4u);

  // Invalid basic block id.
  EXPECT_EQ(ld[0u], nullptr);

  // Check that we can map basic block to the correct loop.
  // The following block ids do not belong to a loop.
  for (uint32_t bb_id : {22u, 24u}) EXPECT_EQ(ld[bb_id], nullptr);

  // In each scope below the loop pointer and its preheader are dereferenced,
  // so those lookups are fatal checks.
  {
    std::unordered_set<uint32_t> basic_block_in_loop = {
        {23, 26, 29, 30, 33, 36, 40, 41, 44, 47, 43,
         42, 39, 50, 53, 56, 52, 51, 32, 31, 25}};
    Loop* loop = ld[23];
    ASSERT_NE(loop, nullptr);
    CheckLoopBlocks(loop, &basic_block_in_loop);

    EXPECT_TRUE(loop->HasNestedLoops());
    EXPECT_FALSE(loop->IsNested());
    EXPECT_EQ(loop->GetDepth(), 1u);
    EXPECT_EQ(std::distance(loop->begin(), loop->end()), 1u);
    ASSERT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 22));
    EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 23));
    EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 25));
    EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 24));
    EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
    EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
  }

  {
    std::unordered_set<uint32_t> basic_block_in_loop = {
        {30, 33, 36, 40, 41, 44, 47, 43, 42, 39, 50, 53, 56, 52, 51, 32}};
    Loop* loop = ld[30];
    ASSERT_NE(loop, nullptr);
    CheckLoopBlocks(loop, &basic_block_in_loop);

    EXPECT_TRUE(loop->HasNestedLoops());
    EXPECT_TRUE(loop->IsNested());
    EXPECT_EQ(loop->GetDepth(), 2u);
    EXPECT_EQ(std::distance(loop->begin(), loop->end()), 2u);
    ASSERT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 29));
    EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 30));
    EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 32));
    EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 31));
    EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
    EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
  }

  {
    std::unordered_set<uint32_t> basic_block_in_loop = {{41, 44, 47, 43}};
    Loop* loop = ld[41];
    ASSERT_NE(loop, nullptr);
    CheckLoopBlocks(loop, &basic_block_in_loop);

    EXPECT_FALSE(loop->HasNestedLoops());
    EXPECT_TRUE(loop->IsNested());
    EXPECT_EQ(loop->GetDepth(), 3u);
    EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u);
    ASSERT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 40));
    EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 41));
    EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 43));
    EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 42));
    EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
    EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
  }

  {
    std::unordered_set<uint32_t> basic_block_in_loop = {{50, 53, 56, 52}};
    Loop* loop = ld[50];
    ASSERT_NE(loop, nullptr);
    CheckLoopBlocks(loop, &basic_block_in_loop);

    EXPECT_FALSE(loop->HasNestedLoops());
    EXPECT_TRUE(loop->IsNested());
    EXPECT_EQ(loop->GetDepth(), 3u);
    EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u);
    ASSERT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 39));
    EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 50));
    EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 52));
    EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 51));
    EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
    EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
  }

  // Make sure LoopDescriptor gives us the inner most loop when we query for
  // loops: no loop nested below the returned one may contain the block.
  for (const BasicBlock& bb : *f) {
    if (Loop* loop = ld[&bb]) {
      for (Loop& sub_loop :
           make_range(++TreeDFIterator<Loop>(loop), TreeDFIterator<Loop>())) {
        EXPECT_FALSE(sub_loop.IsInsideLoop(bb.id()));
      }
    }
  }
}
440
441 /*
442 Generated from the following GLSL
443 #version 330 core
444 layout(location = 0) out vec4 c;
445 void main() {
446 for (int i = 0; i < 10; ++i) {
447 for (int j = 0; j < 11; ++j) {
448 for (int k = 0; k < 11; ++k) {}
449 }
450 for (int k = 0; k < 12; ++k) {}
451 }
452 }
453 */
TEST_F(PassClassTest, LoopParentTest) {
  // Checks that every loop reports the correct parent loop and depth.
  // clang-format off
  const std::string text = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main" %3
OpExecutionMode %2 OriginUpperLeft
OpSource GLSL 330
OpName %2 "main"
OpName %4 "i"
OpName %5 "j"
OpName %6 "k"
OpName %7 "k"
OpName %3 "c"
OpDecorate %3 Location 0
%8 = OpTypeVoid
%9 = OpTypeFunction %8
%10 = OpTypeInt 32 1
%11 = OpTypePointer Function %10
%12 = OpConstant %10 0
%13 = OpConstant %10 10
%14 = OpTypeBool
%15 = OpConstant %10 11
%16 = OpConstant %10 1
%17 = OpConstant %10 12
%18 = OpTypeFloat 32
%19 = OpTypeVector %18 4
%20 = OpTypePointer Output %19
%3 = OpVariable %20 Output
%2 = OpFunction %8 None %9
%21 = OpLabel
%4 = OpVariable %11 Function
%5 = OpVariable %11 Function
%6 = OpVariable %11 Function
%7 = OpVariable %11 Function
OpStore %4 %12
OpBranch %22
%22 = OpLabel
OpLoopMerge %23 %24 None
OpBranch %25
%25 = OpLabel
%26 = OpLoad %10 %4
%27 = OpSLessThan %14 %26 %13
OpBranchConditional %27 %28 %23
%28 = OpLabel
OpStore %5 %12
OpBranch %29
%29 = OpLabel
OpLoopMerge %30 %31 None
OpBranch %32
%32 = OpLabel
%33 = OpLoad %10 %5
%34 = OpSLessThan %14 %33 %15
OpBranchConditional %34 %35 %30
%35 = OpLabel
OpStore %6 %12
OpBranch %36
%36 = OpLabel
OpLoopMerge %37 %38 None
OpBranch %39
%39 = OpLabel
%40 = OpLoad %10 %6
%41 = OpSLessThan %14 %40 %15
OpBranchConditional %41 %42 %37
%42 = OpLabel
OpBranch %38
%38 = OpLabel
%43 = OpLoad %10 %6
%44 = OpIAdd %10 %43 %16
OpStore %6 %44
OpBranch %36
%37 = OpLabel
OpBranch %31
%31 = OpLabel
%45 = OpLoad %10 %5
%46 = OpIAdd %10 %45 %16
OpStore %5 %46
OpBranch %29
%30 = OpLabel
OpStore %7 %12
OpBranch %47
%47 = OpLabel
OpLoopMerge %48 %49 None
OpBranch %50
%50 = OpLabel
%51 = OpLoad %10 %7
%52 = OpSLessThan %14 %51 %17
OpBranchConditional %52 %53 %48
%53 = OpLabel
OpBranch %49
%49 = OpLabel
%54 = OpLoad %10 %7
%55 = OpIAdd %10 %54 %16
OpStore %7 %55
OpBranch %47
%48 = OpLabel
OpBranch %24
%24 = OpLabel
%56 = OpLoad %10 %4
%57 = OpIAdd %10 %56 %16
OpStore %4 %57
OpBranch %22
%23 = OpLabel
OpReturn
OpFunctionEnd
)";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Check the context itself before calling context->module(): a failed build
  // must be reported, not crash the test. ASSERT stops the test on failure.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ASSERT_NE(nullptr, f);
  LoopDescriptor& ld = *context->GetLoopDescriptor(f);

  EXPECT_EQ(ld.NumLoops(), 4u);

  // Each loop is dereferenced, so its lookup must succeed first.
  {
    ASSERT_NE(ld[22], nullptr);
    Loop& loop = *ld[22];
    EXPECT_TRUE(loop.HasNestedLoops());
    EXPECT_FALSE(loop.IsNested());
    EXPECT_EQ(loop.GetDepth(), 1u);
    EXPECT_EQ(loop.GetParent(), nullptr);
  }

  {
    ASSERT_NE(ld[29], nullptr);
    Loop& loop = *ld[29];
    EXPECT_TRUE(loop.HasNestedLoops());
    EXPECT_TRUE(loop.IsNested());
    EXPECT_EQ(loop.GetDepth(), 2u);
    EXPECT_EQ(loop.GetParent(), ld[22]);
  }

  {
    ASSERT_NE(ld[36], nullptr);
    Loop& loop = *ld[36];
    EXPECT_FALSE(loop.HasNestedLoops());
    EXPECT_TRUE(loop.IsNested());
    EXPECT_EQ(loop.GetDepth(), 3u);
    EXPECT_EQ(loop.GetParent(), ld[29]);
  }

  {
    ASSERT_NE(ld[47], nullptr);
    Loop& loop = *ld[47];
    EXPECT_FALSE(loop.HasNestedLoops());
    EXPECT_TRUE(loop.IsNested());
    EXPECT_EQ(loop.GetDepth(), 2u);
    EXPECT_EQ(loop.GetParent(), ld[22]);
  }
}
604
605 /*
Generated from the following GLSL, then processed with
--eliminate-local-multi-store. The preheaders of loops %33 and %41 were removed as well.
608
609 #version 330 core
610 void main() {
611 int a = 0;
612 for (int i = 0; i < 10; ++i) {
613 if (i == 0) {
614 a = 1;
615 } else {
616 a = 2;
617 }
618 for (int j = 0; j < 11; ++j) {
619 a++;
620 }
621 }
622 for (int k = 0; k < 12; ++k) {}
623 }
624 */
TEST_F(PassClassTest, CreatePreheaderTest) {
  // Loops %33 and %41 start without a preheader (checked below).
  // GetOrCreatePreHeaderBlock() must create one while keeping the CFG, the phi
  // instructions and the loop descriptor consistent, and the resulting module
  // must still validate.
  // clang-format off
  const std::string text = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main"
OpExecutionMode %2 OriginUpperLeft
OpSource GLSL 330
OpName %2 "main"
%3 = OpTypeVoid
%4 = OpTypeFunction %3
%5 = OpTypeInt 32 1
%6 = OpTypePointer Function %5
%7 = OpConstant %5 0
%8 = OpConstant %5 10
%9 = OpTypeBool
%10 = OpConstant %5 1
%11 = OpConstant %5 2
%12 = OpConstant %5 11
%13 = OpConstant %5 12
%14 = OpUndef %5
%2 = OpFunction %3 None %4
%15 = OpLabel
OpBranch %16
%16 = OpLabel
%17 = OpPhi %5 %7 %15 %18 %19
%20 = OpPhi %5 %7 %15 %21 %19
%22 = OpPhi %5 %14 %15 %23 %19
OpLoopMerge %41 %19 None
OpBranch %25
%25 = OpLabel
%26 = OpSLessThan %9 %20 %8
OpBranchConditional %26 %27 %41
%27 = OpLabel
%28 = OpIEqual %9 %20 %7
OpSelectionMerge %33 None
OpBranchConditional %28 %30 %31
%30 = OpLabel
OpBranch %33
%31 = OpLabel
OpBranch %33
%33 = OpLabel
%18 = OpPhi %5 %10 %30 %11 %31 %34 %35
%23 = OpPhi %5 %7 %30 %7 %31 %36 %35
OpLoopMerge %37 %35 None
OpBranch %38
%38 = OpLabel
%39 = OpSLessThan %9 %23 %12
OpBranchConditional %39 %40 %37
%40 = OpLabel
%34 = OpIAdd %5 %18 %10
OpBranch %35
%35 = OpLabel
%36 = OpIAdd %5 %23 %10
OpBranch %33
%37 = OpLabel
OpBranch %19
%19 = OpLabel
%21 = OpIAdd %5 %20 %10
OpBranch %16
%41 = OpLabel
%42 = OpPhi %5 %7 %25 %43 %44
OpLoopMerge %45 %44 None
OpBranch %46
%46 = OpLabel
%47 = OpSLessThan %9 %42 %13
OpBranchConditional %47 %48 %45
%48 = OpLabel
OpBranch %44
%44 = OpLabel
%43 = OpIAdd %5 %42 %10
OpBranch %41
%45 = OpLabel
OpReturn
OpFunctionEnd
)";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Check the context itself before calling context->module(): a failed build
  // must be reported, not crash the test. ASSERT stops the test on failure.
  ASSERT_NE(nullptr, context.get()) << "Assembling failed for shader:\n"
                                    << text << std::endl;
  Module* module = context->module();
  ASSERT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ASSERT_NE(nullptr, f);
  LoopDescriptor& ld = *context->GetLoopDescriptor(f);
  // No invalidation of the cfg should occur during this test.
  CFG* cfg = context->cfg();

  EXPECT_EQ(ld.NumLoops(), 3u);

  {
    ASSERT_NE(ld[16], nullptr);
    Loop& loop = *ld[16];
    EXPECT_TRUE(loop.HasNestedLoops());
    EXPECT_FALSE(loop.IsNested());
    EXPECT_EQ(loop.GetDepth(), 1u);
    EXPECT_EQ(loop.GetParent(), nullptr);
  }

  {
    ASSERT_NE(ld[33], nullptr);
    Loop& loop = *ld[33];
    EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr);
    // The created preheader is dereferenced below, so creation must succeed.
    ASSERT_NE(loop.GetOrCreatePreHeaderBlock(), nullptr);
    // Make sure the loop descriptor was properly updated: the new preheader
    // belongs to the enclosing loop %16.
    EXPECT_EQ(ld[loop.GetPreHeaderBlock()], ld[16]);
    {
      const std::vector<uint32_t>& preds =
          cfg->preds(loop.GetPreHeaderBlock()->id());
      std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
      EXPECT_EQ(pred_set.size(), 2u);
      EXPECT_TRUE(pred_set.count(30));
      EXPECT_TRUE(pred_set.count(31));
      // Check the phi instructions: every incoming block must be a predecessor.
      loop.GetPreHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) {
        for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
          EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
        }
      });
    }
    {
      const std::vector<uint32_t>& preds =
          cfg->preds(loop.GetHeaderBlock()->id());
      std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
      EXPECT_EQ(pred_set.size(), 2u);
      EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id()));
      EXPECT_TRUE(pred_set.count(35));
      // Check the phi instructions: every incoming block must be a predecessor.
      loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) {
        for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
          EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
        }
      });
    }
  }

  {
    ASSERT_NE(ld[41], nullptr);
    Loop& loop = *ld[41];
    EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr);
    // The created preheader is dereferenced below, so creation must succeed.
    ASSERT_NE(loop.GetOrCreatePreHeaderBlock(), nullptr);
    // %41 is a top-level loop, so its new preheader is outside every loop.
    EXPECT_EQ(ld[loop.GetPreHeaderBlock()], nullptr);
    const std::vector<uint32_t>& preheader_preds =
        cfg->preds(loop.GetPreHeaderBlock()->id());
    // Fatal, so that indexing element 0 below is always in bounds.
    ASSERT_EQ(preheader_preds.size(), 1u);
    EXPECT_EQ(preheader_preds[0], 25u);
    // Check the phi instructions.
    loop.GetPreHeaderBlock()->ForEachPhiInst([](Instruction* phi) {
      EXPECT_EQ(phi->NumInOperands(), 2u);
      EXPECT_EQ(phi->GetSingleWordInOperand(1), 25u);
    });
    {
      const std::vector<uint32_t>& preds =
          cfg->preds(loop.GetHeaderBlock()->id());
      std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
      EXPECT_EQ(pred_set.size(), 2u);
      EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id()));
      EXPECT_TRUE(pred_set.count(44));
      // Check the phi instructions: every incoming block must be a predecessor.
      loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) {
        for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
          EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
        }
      });
    }
  }

  // Make sure pre-header insertion leaves the module valid.
  std::vector<uint32_t> bin;
  context->module()->ToBinary(&bin, true);
  EXPECT_TRUE(Validate(bin));
}
792
793 } // namespace
794 } // namespace opt
795 } // namespace spvtools
796