• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <memory>
16 #include <string>
17 #include <unordered_set>
18 #include <vector>
19 
20 #include "gmock/gmock.h"
21 #include "source/opt/iterator.h"
22 #include "source/opt/loop_descriptor.h"
23 #include "source/opt/pass.h"
24 #include "source/opt/tree_iterator.h"
25 #include "test/opt/assembly_builder.h"
26 #include "test/opt/function_utils.h"
27 #include "test/opt/pass_fixture.h"
28 #include "test/opt/pass_utils.h"
29 
30 namespace spvtools {
31 namespace opt {
32 namespace {
33 
34 using ::testing::UnorderedElementsAre;
35 
Validate(const std::vector<uint32_t> & bin)36 bool Validate(const std::vector<uint32_t>& bin) {
37   spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
38   spv_context spvContext = spvContextCreate(target_env);
39   spv_diagnostic diagnostic = nullptr;
40   spv_const_binary_t binary = {bin.data(), bin.size()};
41   spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
42   if (error != 0) spvDiagnosticPrint(diagnostic);
43   spvDiagnosticDestroy(diagnostic);
44   spvContextDestroy(spvContext);
45   return error == 0;
46 }
47 
48 using PassClassTest = PassTest<::testing::Test>;
49 
50 /*
51 Generated from the following GLSL
52 #version 330 core
53 layout(location = 0) out vec4 c;
54 void main() {
55   int i = 0;
56   for (; i < 10; ++i) {
57     int j = 0;
58     int k = 0;
59     for (; j < 11; ++j) {}
60     for (; k < 12; ++k) {}
61   }
62 }
63 */
TEST_F(PassClassTest,BasicVisitFromEntryPoint)64 TEST_F(PassClassTest, BasicVisitFromEntryPoint) {
65   const std::string text = R"(
66                OpCapability Shader
67           %1 = OpExtInstImport "GLSL.std.450"
68                OpMemoryModel Logical GLSL450
69                OpEntryPoint Fragment %2 "main" %3
70                OpExecutionMode %2 OriginUpperLeft
71                OpSource GLSL 330
72                OpName %2 "main"
73                OpName %4 "i"
74                OpName %5 "j"
75                OpName %6 "k"
76                OpName %3 "c"
77                OpDecorate %3 Location 0
78           %7 = OpTypeVoid
79           %8 = OpTypeFunction %7
80           %9 = OpTypeInt 32 1
81          %10 = OpTypePointer Function %9
82          %11 = OpConstant %9 0
83          %12 = OpConstant %9 10
84          %13 = OpTypeBool
85          %14 = OpConstant %9 11
86          %15 = OpConstant %9 1
87          %16 = OpConstant %9 12
88          %17 = OpTypeFloat 32
89          %18 = OpTypeVector %17 4
90          %19 = OpTypePointer Output %18
91           %3 = OpVariable %19 Output
92           %2 = OpFunction %7 None %8
93          %20 = OpLabel
94           %4 = OpVariable %10 Function
95           %5 = OpVariable %10 Function
96           %6 = OpVariable %10 Function
97                OpStore %4 %11
98                OpBranch %21
99          %21 = OpLabel
100                OpLoopMerge %22 %23 None
101                OpBranch %24
102          %24 = OpLabel
103          %25 = OpLoad %9 %4
104          %26 = OpSLessThan %13 %25 %12
105                OpBranchConditional %26 %27 %22
106          %27 = OpLabel
107                OpStore %5 %11
108                OpStore %6 %11
109                OpBranch %28
110          %28 = OpLabel
111                OpLoopMerge %29 %30 None
112                OpBranch %31
113          %31 = OpLabel
114          %32 = OpLoad %9 %5
115          %33 = OpSLessThan %13 %32 %14
116                OpBranchConditional %33 %34 %29
117          %34 = OpLabel
118                OpBranch %30
119          %30 = OpLabel
120          %35 = OpLoad %9 %5
121          %36 = OpIAdd %9 %35 %15
122                OpStore %5 %36
123                OpBranch %28
124          %29 = OpLabel
125                OpBranch %37
126          %37 = OpLabel
127                OpLoopMerge %38 %39 None
128                OpBranch %40
129          %40 = OpLabel
130          %41 = OpLoad %9 %6
131          %42 = OpSLessThan %13 %41 %16
132                OpBranchConditional %42 %43 %38
133          %43 = OpLabel
134                OpBranch %39
135          %39 = OpLabel
136          %44 = OpLoad %9 %6
137          %45 = OpIAdd %9 %44 %15
138                OpStore %6 %45
139                OpBranch %37
140          %38 = OpLabel
141                OpBranch %23
142          %23 = OpLabel
143          %46 = OpLoad %9 %4
144          %47 = OpIAdd %9 %46 %15
145                OpStore %4 %47
146                OpBranch %21
147          %22 = OpLabel
148                OpReturn
149                OpFunctionEnd
150   )";
151   // clang-format on
152   std::unique_ptr<IRContext> context =
153       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
154                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
155   Module* module = context->module();
156   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
157                              << text << std::endl;
158   const Function* f = spvtest::GetFunction(module, 2);
159   LoopDescriptor& ld = *context->GetLoopDescriptor(f);
160 
161   EXPECT_EQ(ld.NumLoops(), 3u);
162 
163   // Invalid basic block id.
164   EXPECT_EQ(ld[0u], nullptr);
165   // Not a loop header.
166   EXPECT_EQ(ld[20], nullptr);
167 
168   Loop& parent_loop = *ld[21];
169   EXPECT_TRUE(parent_loop.HasNestedLoops());
170   EXPECT_FALSE(parent_loop.IsNested());
171   EXPECT_EQ(parent_loop.GetDepth(), 1u);
172   EXPECT_EQ(std::distance(parent_loop.begin(), parent_loop.end()), 2u);
173   EXPECT_EQ(parent_loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 21));
174   EXPECT_EQ(parent_loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 23));
175   EXPECT_EQ(parent_loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 22));
176 
177   Loop& child_loop_1 = *ld[28];
178   EXPECT_FALSE(child_loop_1.HasNestedLoops());
179   EXPECT_TRUE(child_loop_1.IsNested());
180   EXPECT_EQ(child_loop_1.GetDepth(), 2u);
181   EXPECT_EQ(std::distance(child_loop_1.begin(), child_loop_1.end()), 0u);
182   EXPECT_EQ(child_loop_1.GetHeaderBlock(), spvtest::GetBasicBlock(f, 28));
183   EXPECT_EQ(child_loop_1.GetLatchBlock(), spvtest::GetBasicBlock(f, 30));
184   EXPECT_EQ(child_loop_1.GetMergeBlock(), spvtest::GetBasicBlock(f, 29));
185 
186   Loop& child_loop_2 = *ld[37];
187   EXPECT_FALSE(child_loop_2.HasNestedLoops());
188   EXPECT_TRUE(child_loop_2.IsNested());
189   EXPECT_EQ(child_loop_2.GetDepth(), 2u);
190   EXPECT_EQ(std::distance(child_loop_2.begin(), child_loop_2.end()), 0u);
191   EXPECT_EQ(child_loop_2.GetHeaderBlock(), spvtest::GetBasicBlock(f, 37));
192   EXPECT_EQ(child_loop_2.GetLatchBlock(), spvtest::GetBasicBlock(f, 39));
193   EXPECT_EQ(child_loop_2.GetMergeBlock(), spvtest::GetBasicBlock(f, 38));
194 }
195 
CheckLoopBlocks(Loop * loop,std::unordered_set<uint32_t> * expected_ids)196 static void CheckLoopBlocks(Loop* loop,
197                             std::unordered_set<uint32_t>* expected_ids) {
198   SCOPED_TRACE("Check loop " + std::to_string(loop->GetHeaderBlock()->id()));
199   for (uint32_t bb_id : loop->GetBlocks()) {
200     EXPECT_EQ(expected_ids->count(bb_id), 1u);
201     expected_ids->erase(bb_id);
202   }
203   EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
204   EXPECT_EQ(expected_ids->size(), 0u);
205 }
206 
207 /*
208 Generated from the following GLSL
209 #version 330 core
210 layout(location = 0) out vec4 c;
211 void main() {
212   int i = 0;
213   for (; i < 10; ++i) {
214     for (int j = 0; j < 11; ++j) {
215       if (j < 5) {
216         for (int k = 0; k < 12; ++k) {}
217       }
218       else {}
219       for (int k = 0; k < 12; ++k) {}
220     }
221   }
222 }*/
TEST_F(PassClassTest,TripleNestedLoop)223 TEST_F(PassClassTest, TripleNestedLoop) {
224   const std::string text = R"(
225                OpCapability Shader
226           %1 = OpExtInstImport "GLSL.std.450"
227                OpMemoryModel Logical GLSL450
228                OpEntryPoint Fragment %2 "main" %3
229                OpExecutionMode %2 OriginUpperLeft
230                OpSource GLSL 330
231                OpName %2 "main"
232                OpName %4 "i"
233                OpName %5 "j"
234                OpName %6 "k"
235                OpName %7 "k"
236                OpName %3 "c"
237                OpDecorate %3 Location 0
238           %8 = OpTypeVoid
239           %9 = OpTypeFunction %8
240          %10 = OpTypeInt 32 1
241          %11 = OpTypePointer Function %10
242          %12 = OpConstant %10 0
243          %13 = OpConstant %10 10
244          %14 = OpTypeBool
245          %15 = OpConstant %10 11
246          %16 = OpConstant %10 5
247          %17 = OpConstant %10 12
248          %18 = OpConstant %10 1
249          %19 = OpTypeFloat 32
250          %20 = OpTypeVector %19 4
251          %21 = OpTypePointer Output %20
252           %3 = OpVariable %21 Output
253           %2 = OpFunction %8 None %9
254          %22 = OpLabel
255           %4 = OpVariable %11 Function
256           %5 = OpVariable %11 Function
257           %6 = OpVariable %11 Function
258           %7 = OpVariable %11 Function
259                OpStore %4 %12
260                OpBranch %23
261          %23 = OpLabel
262                OpLoopMerge %24 %25 None
263                OpBranch %26
264          %26 = OpLabel
265          %27 = OpLoad %10 %4
266          %28 = OpSLessThan %14 %27 %13
267                OpBranchConditional %28 %29 %24
268          %29 = OpLabel
269                OpStore %5 %12
270                OpBranch %30
271          %30 = OpLabel
272                OpLoopMerge %31 %32 None
273                OpBranch %33
274          %33 = OpLabel
275          %34 = OpLoad %10 %5
276          %35 = OpSLessThan %14 %34 %15
277                OpBranchConditional %35 %36 %31
278          %36 = OpLabel
279          %37 = OpLoad %10 %5
280          %38 = OpSLessThan %14 %37 %16
281                OpSelectionMerge %39 None
282                OpBranchConditional %38 %40 %39
283          %40 = OpLabel
284                OpStore %6 %12
285                OpBranch %41
286          %41 = OpLabel
287                OpLoopMerge %42 %43 None
288                OpBranch %44
289          %44 = OpLabel
290          %45 = OpLoad %10 %6
291          %46 = OpSLessThan %14 %45 %17
292                OpBranchConditional %46 %47 %42
293          %47 = OpLabel
294                OpBranch %43
295          %43 = OpLabel
296          %48 = OpLoad %10 %6
297          %49 = OpIAdd %10 %48 %18
298                OpStore %6 %49
299                OpBranch %41
300          %42 = OpLabel
301                OpBranch %39
302          %39 = OpLabel
303                OpStore %7 %12
304                OpBranch %50
305          %50 = OpLabel
306                OpLoopMerge %51 %52 None
307                OpBranch %53
308          %53 = OpLabel
309          %54 = OpLoad %10 %7
310          %55 = OpSLessThan %14 %54 %17
311                OpBranchConditional %55 %56 %51
312          %56 = OpLabel
313                OpBranch %52
314          %52 = OpLabel
315          %57 = OpLoad %10 %7
316          %58 = OpIAdd %10 %57 %18
317                OpStore %7 %58
318                OpBranch %50
319          %51 = OpLabel
320                OpBranch %32
321          %32 = OpLabel
322          %59 = OpLoad %10 %5
323          %60 = OpIAdd %10 %59 %18
324                OpStore %5 %60
325                OpBranch %30
326          %31 = OpLabel
327                OpBranch %25
328          %25 = OpLabel
329          %61 = OpLoad %10 %4
330          %62 = OpIAdd %10 %61 %18
331                OpStore %4 %62
332                OpBranch %23
333          %24 = OpLabel
334                OpReturn
335                OpFunctionEnd
336   )";
337   // clang-format on
338   std::unique_ptr<IRContext> context =
339       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
340                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
341   Module* module = context->module();
342   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
343                              << text << std::endl;
344   const Function* f = spvtest::GetFunction(module, 2);
345   LoopDescriptor& ld = *context->GetLoopDescriptor(f);
346 
347   EXPECT_EQ(ld.NumLoops(), 4u);
348 
349   // Invalid basic block id.
350   EXPECT_EQ(ld[0u], nullptr);
351   // Not in a loop.
352   EXPECT_EQ(ld[22], nullptr);
353 
354   // Check that we can map basic block to the correct loop.
355   // The following block ids do not belong to a loop.
356   for (uint32_t bb_id : {22, 24}) EXPECT_EQ(ld[bb_id], nullptr);
357 
358   {
359     std::unordered_set<uint32_t> basic_block_in_loop = {
360         {23, 26, 29, 30, 33, 36, 40, 41, 44, 47, 43,
361          42, 39, 50, 53, 56, 52, 51, 32, 31, 25}};
362     Loop* loop = ld[23];
363     CheckLoopBlocks(loop, &basic_block_in_loop);
364 
365     EXPECT_TRUE(loop->HasNestedLoops());
366     EXPECT_FALSE(loop->IsNested());
367     EXPECT_EQ(loop->GetDepth(), 1u);
368     EXPECT_EQ(std::distance(loop->begin(), loop->end()), 1u);
369     EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 22));
370     EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 23));
371     EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 25));
372     EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 24));
373     EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
374     EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
375   }
376 
377   {
378     std::unordered_set<uint32_t> basic_block_in_loop = {
379         {30, 33, 36, 40, 41, 44, 47, 43, 42, 39, 50, 53, 56, 52, 51, 32}};
380     Loop* loop = ld[30];
381     CheckLoopBlocks(loop, &basic_block_in_loop);
382 
383     EXPECT_TRUE(loop->HasNestedLoops());
384     EXPECT_TRUE(loop->IsNested());
385     EXPECT_EQ(loop->GetDepth(), 2u);
386     EXPECT_EQ(std::distance(loop->begin(), loop->end()), 2u);
387     EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 29));
388     EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 30));
389     EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 32));
390     EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 31));
391     EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
392     EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
393   }
394 
395   {
396     std::unordered_set<uint32_t> basic_block_in_loop = {{41, 44, 47, 43}};
397     Loop* loop = ld[41];
398     CheckLoopBlocks(loop, &basic_block_in_loop);
399 
400     EXPECT_FALSE(loop->HasNestedLoops());
401     EXPECT_TRUE(loop->IsNested());
402     EXPECT_EQ(loop->GetDepth(), 3u);
403     EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u);
404     EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 40));
405     EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 41));
406     EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 43));
407     EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 42));
408     EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
409     EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
410   }
411 
412   {
413     std::unordered_set<uint32_t> basic_block_in_loop = {{50, 53, 56, 52}};
414     Loop* loop = ld[50];
415     CheckLoopBlocks(loop, &basic_block_in_loop);
416 
417     EXPECT_FALSE(loop->HasNestedLoops());
418     EXPECT_TRUE(loop->IsNested());
419     EXPECT_EQ(loop->GetDepth(), 3u);
420     EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u);
421     EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 39));
422     EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 50));
423     EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 52));
424     EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 51));
425     EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
426     EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
427   }
428 
429   // Make sure LoopDescriptor gives us the inner most loop when we query for
430   // loops.
431   for (const BasicBlock& bb : *f) {
432     if (Loop* loop = ld[&bb]) {
433       for (Loop& sub_loop :
434            make_range(++TreeDFIterator<Loop>(loop), TreeDFIterator<Loop>())) {
435         EXPECT_FALSE(sub_loop.IsInsideLoop(bb.id()));
436       }
437     }
438   }
439 }
440 
441 /*
442 Generated from the following GLSL
443 #version 330 core
444 layout(location = 0) out vec4 c;
445 void main() {
446   for (int i = 0; i < 10; ++i) {
447     for (int j = 0; j < 11; ++j) {
448       for (int k = 0; k < 11; ++k) {}
449     }
450     for (int k = 0; k < 12; ++k) {}
451   }
452 }
453 */
TEST_F(PassClassTest,LoopParentTest)454 TEST_F(PassClassTest, LoopParentTest) {
455   const std::string text = R"(
456                OpCapability Shader
457           %1 = OpExtInstImport "GLSL.std.450"
458                OpMemoryModel Logical GLSL450
459                OpEntryPoint Fragment %2 "main" %3
460                OpExecutionMode %2 OriginUpperLeft
461                OpSource GLSL 330
462                OpName %2 "main"
463                OpName %4 "i"
464                OpName %5 "j"
465                OpName %6 "k"
466                OpName %7 "k"
467                OpName %3 "c"
468                OpDecorate %3 Location 0
469           %8 = OpTypeVoid
470           %9 = OpTypeFunction %8
471          %10 = OpTypeInt 32 1
472          %11 = OpTypePointer Function %10
473          %12 = OpConstant %10 0
474          %13 = OpConstant %10 10
475          %14 = OpTypeBool
476          %15 = OpConstant %10 11
477          %16 = OpConstant %10 1
478          %17 = OpConstant %10 12
479          %18 = OpTypeFloat 32
480          %19 = OpTypeVector %18 4
481          %20 = OpTypePointer Output %19
482           %3 = OpVariable %20 Output
483           %2 = OpFunction %8 None %9
484          %21 = OpLabel
485           %4 = OpVariable %11 Function
486           %5 = OpVariable %11 Function
487           %6 = OpVariable %11 Function
488           %7 = OpVariable %11 Function
489                OpStore %4 %12
490                OpBranch %22
491          %22 = OpLabel
492                OpLoopMerge %23 %24 None
493                OpBranch %25
494          %25 = OpLabel
495          %26 = OpLoad %10 %4
496          %27 = OpSLessThan %14 %26 %13
497                OpBranchConditional %27 %28 %23
498          %28 = OpLabel
499                OpStore %5 %12
500                OpBranch %29
501          %29 = OpLabel
502                OpLoopMerge %30 %31 None
503                OpBranch %32
504          %32 = OpLabel
505          %33 = OpLoad %10 %5
506          %34 = OpSLessThan %14 %33 %15
507                OpBranchConditional %34 %35 %30
508          %35 = OpLabel
509                OpStore %6 %12
510                OpBranch %36
511          %36 = OpLabel
512                OpLoopMerge %37 %38 None
513                OpBranch %39
514          %39 = OpLabel
515          %40 = OpLoad %10 %6
516          %41 = OpSLessThan %14 %40 %15
517                OpBranchConditional %41 %42 %37
518          %42 = OpLabel
519                OpBranch %38
520          %38 = OpLabel
521          %43 = OpLoad %10 %6
522          %44 = OpIAdd %10 %43 %16
523                OpStore %6 %44
524                OpBranch %36
525          %37 = OpLabel
526                OpBranch %31
527          %31 = OpLabel
528          %45 = OpLoad %10 %5
529          %46 = OpIAdd %10 %45 %16
530                OpStore %5 %46
531                OpBranch %29
532          %30 = OpLabel
533                OpStore %7 %12
534                OpBranch %47
535          %47 = OpLabel
536                OpLoopMerge %48 %49 None
537                OpBranch %50
538          %50 = OpLabel
539          %51 = OpLoad %10 %7
540          %52 = OpSLessThan %14 %51 %17
541                OpBranchConditional %52 %53 %48
542          %53 = OpLabel
543                OpBranch %49
544          %49 = OpLabel
545          %54 = OpLoad %10 %7
546          %55 = OpIAdd %10 %54 %16
547                OpStore %7 %55
548                OpBranch %47
549          %48 = OpLabel
550                OpBranch %24
551          %24 = OpLabel
552          %56 = OpLoad %10 %4
553          %57 = OpIAdd %10 %56 %16
554                OpStore %4 %57
555                OpBranch %22
556          %23 = OpLabel
557                OpReturn
558                OpFunctionEnd
559   )";
560   // clang-format on
561   std::unique_ptr<IRContext> context =
562       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
563                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
564   Module* module = context->module();
565   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
566                              << text << std::endl;
567   const Function* f = spvtest::GetFunction(module, 2);
568   LoopDescriptor& ld = *context->GetLoopDescriptor(f);
569 
570   EXPECT_EQ(ld.NumLoops(), 4u);
571 
572   {
573     Loop& loop = *ld[22];
574     EXPECT_TRUE(loop.HasNestedLoops());
575     EXPECT_FALSE(loop.IsNested());
576     EXPECT_EQ(loop.GetDepth(), 1u);
577     EXPECT_EQ(loop.GetParent(), nullptr);
578   }
579 
580   {
581     Loop& loop = *ld[29];
582     EXPECT_TRUE(loop.HasNestedLoops());
583     EXPECT_TRUE(loop.IsNested());
584     EXPECT_EQ(loop.GetDepth(), 2u);
585     EXPECT_EQ(loop.GetParent(), ld[22]);
586   }
587 
588   {
589     Loop& loop = *ld[36];
590     EXPECT_FALSE(loop.HasNestedLoops());
591     EXPECT_TRUE(loop.IsNested());
592     EXPECT_EQ(loop.GetDepth(), 3u);
593     EXPECT_EQ(loop.GetParent(), ld[29]);
594   }
595 
596   {
597     Loop& loop = *ld[47];
598     EXPECT_FALSE(loop.HasNestedLoops());
599     EXPECT_TRUE(loop.IsNested());
600     EXPECT_EQ(loop.GetDepth(), 2u);
601     EXPECT_EQ(loop.GetParent(), ld[22]);
602   }
603 }
604 
605 /*
606 Generated from the following GLSL + --eliminate-local-multi-store
607 The preheader of loop %33 and %41 were removed as well.
608 
609 #version 330 core
610 void main() {
611   int a = 0;
612   for (int i = 0; i < 10; ++i) {
613     if (i == 0) {
614       a = 1;
615     } else {
616       a = 2;
617     }
618     for (int j = 0; j < 11; ++j) {
619       a++;
620     }
621   }
622   for (int k = 0; k < 12; ++k) {}
623 }
624 */
TEST_F(PassClassTest,CreatePreheaderTest)625 TEST_F(PassClassTest, CreatePreheaderTest) {
626   const std::string text = R"(
627                OpCapability Shader
628           %1 = OpExtInstImport "GLSL.std.450"
629                OpMemoryModel Logical GLSL450
630                OpEntryPoint Fragment %2 "main"
631                OpExecutionMode %2 OriginUpperLeft
632                OpSource GLSL 330
633                OpName %2 "main"
634           %3 = OpTypeVoid
635           %4 = OpTypeFunction %3
636           %5 = OpTypeInt 32 1
637           %6 = OpTypePointer Function %5
638           %7 = OpConstant %5 0
639           %8 = OpConstant %5 10
640           %9 = OpTypeBool
641          %10 = OpConstant %5 1
642          %11 = OpConstant %5 2
643          %12 = OpConstant %5 11
644          %13 = OpConstant %5 12
645          %14 = OpUndef %5
646           %2 = OpFunction %3 None %4
647          %15 = OpLabel
648                OpBranch %16
649          %16 = OpLabel
650          %17 = OpPhi %5 %7 %15 %18 %19
651          %20 = OpPhi %5 %7 %15 %21 %19
652          %22 = OpPhi %5 %14 %15 %23 %19
653                OpLoopMerge %41 %19 None
654                OpBranch %25
655          %25 = OpLabel
656          %26 = OpSLessThan %9 %20 %8
657                OpBranchConditional %26 %27 %41
658          %27 = OpLabel
659          %28 = OpIEqual %9 %20 %7
660                OpSelectionMerge %33 None
661                OpBranchConditional %28 %30 %31
662          %30 = OpLabel
663                OpBranch %33
664          %31 = OpLabel
665                OpBranch %33
666          %33 = OpLabel
667          %18 = OpPhi %5 %10 %30 %11 %31 %34 %35
668          %23 = OpPhi %5 %7 %30 %7 %31 %36 %35
669                OpLoopMerge %37 %35 None
670                OpBranch %38
671          %38 = OpLabel
672          %39 = OpSLessThan %9 %23 %12
673                OpBranchConditional %39 %40 %37
674          %40 = OpLabel
675          %34 = OpIAdd %5 %18 %10
676                OpBranch %35
677          %35 = OpLabel
678          %36 = OpIAdd %5 %23 %10
679                OpBranch %33
680          %37 = OpLabel
681                OpBranch %19
682          %19 = OpLabel
683          %21 = OpIAdd %5 %20 %10
684                OpBranch %16
685          %41 = OpLabel
686          %42 = OpPhi %5 %7 %25 %43 %44
687                OpLoopMerge %45 %44 None
688                OpBranch %46
689          %46 = OpLabel
690          %47 = OpSLessThan %9 %42 %13
691                OpBranchConditional %47 %48 %45
692          %48 = OpLabel
693                OpBranch %44
694          %44 = OpLabel
695          %43 = OpIAdd %5 %42 %10
696                OpBranch %41
697          %45 = OpLabel
698                OpReturn
699                OpFunctionEnd
700   )";
701   // clang-format on
702   std::unique_ptr<IRContext> context =
703       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
704                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
705   Module* module = context->module();
706   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
707                              << text << std::endl;
708   const Function* f = spvtest::GetFunction(module, 2);
709   LoopDescriptor& ld = *context->GetLoopDescriptor(f);
710   // No invalidation of the cfg should occur during this test.
711   CFG* cfg = context->cfg();
712 
713   EXPECT_EQ(ld.NumLoops(), 3u);
714 
715   {
716     Loop& loop = *ld[16];
717     EXPECT_TRUE(loop.HasNestedLoops());
718     EXPECT_FALSE(loop.IsNested());
719     EXPECT_EQ(loop.GetDepth(), 1u);
720     EXPECT_EQ(loop.GetParent(), nullptr);
721   }
722 
723   {
724     Loop& loop = *ld[33];
725     EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr);
726     EXPECT_NE(loop.GetOrCreatePreHeaderBlock(), nullptr);
727     // Make sure the loop descriptor was properly updated.
728     EXPECT_EQ(ld[loop.GetPreHeaderBlock()], ld[16]);
729     {
730       const std::vector<uint32_t>& preds =
731           cfg->preds(loop.GetPreHeaderBlock()->id());
732       std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
733       EXPECT_EQ(pred_set.size(), 2u);
734       EXPECT_TRUE(pred_set.count(30));
735       EXPECT_TRUE(pred_set.count(31));
736       // Check the phi instructions.
737       loop.GetPreHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) {
738         for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
739           EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
740         }
741       });
742     }
743     {
744       const std::vector<uint32_t>& preds =
745           cfg->preds(loop.GetHeaderBlock()->id());
746       std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
747       EXPECT_EQ(pred_set.size(), 2u);
748       EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id()));
749       EXPECT_TRUE(pred_set.count(35));
750       // Check the phi instructions.
751       loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) {
752         for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
753           EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
754         }
755       });
756     }
757   }
758 
759   {
760     Loop& loop = *ld[41];
761     EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr);
762     EXPECT_NE(loop.GetOrCreatePreHeaderBlock(), nullptr);
763     EXPECT_EQ(ld[loop.GetPreHeaderBlock()], nullptr);
764     EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id()).size(), 1u);
765     EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id())[0], 25u);
766     // Check the phi instructions.
767     loop.GetPreHeaderBlock()->ForEachPhiInst([](Instruction* phi) {
768       EXPECT_EQ(phi->NumInOperands(), 2u);
769       EXPECT_EQ(phi->GetSingleWordInOperand(1), 25u);
770     });
771     {
772       const std::vector<uint32_t>& preds =
773           cfg->preds(loop.GetHeaderBlock()->id());
774       std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
775       EXPECT_EQ(pred_set.size(), 2u);
776       EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id()));
777       EXPECT_TRUE(pred_set.count(44));
778       // Check the phi instructions.
779       loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) {
780         for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
781           EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
782         }
783       });
784     }
785   }
786 
787   // Make sure pre-header insertion leaves the module valid.
788   std::vector<uint32_t> bin;
789   context->module()->ToBinary(&bin, true);
790   EXPECT_TRUE(Validate(bin));
791 }
792 
793 }  // namespace
794 }  // namespace opt
795 }  // namespace spvtools
796