• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright (c) 2021-2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "unit_test.h"
17 #include "optimizer/optimizations/regalloc/reg_alloc.h"
18 #include "optimizer/optimizations/loop_unroll.h"
19 #include "optimizer/optimizations/cleanup.h"
20 #include "optimizer/code_generator/codegen.h"
21 #include "optimizer/ir/graph_cloner.h"
22 
23 #if defined(PANDA_TARGET_ARM64) || defined(PANDA_TARGET_AMD64)
24 #include "vixl_exec_module.h"
25 #endif
26 
27 namespace panda::compiler {
28 class LoopUnrollTest : public GraphTest {
29 public:
30 #if defined(PANDA_TARGET_ARM64) || defined(PANDA_TARGET_AMD64)
LoopUnrollTest()31     LoopUnrollTest() : opcodes_count_(GetAllocator()->Adapter()), exec_module_(GetAllocator(), GetGraph()->GetRuntime())
32 #else
33     LoopUnrollTest() : opcodes_count_(GetAllocator()->Adapter())
34 #endif
35     {
36     }
37 
38     template <typename T>
CheckRetOnVixlSimulator(Graph * graph,T return_value)39     bool CheckRetOnVixlSimulator([[maybe_unused]] Graph *graph, [[maybe_unused]] T return_value)
40     {
41 #if defined(PANDA_TARGET_ARM64) || defined(PANDA_TARGET_AMD64)
42 #ifndef NDEBUG
43         // GraphChecker hack: LowLevel instructions may appear only after Lowering pass:
44         graph->SetLowLevelInstructionsEnabled();
45 #endif
46         EXPECT_TRUE(RegAlloc(graph));
47         EXPECT_TRUE(graph->RunPass<Codegen>());
48         auto entry = reinterpret_cast<char *>(graph->GetData().Data());
49         auto exit = entry + graph->GetData().Size();
50         ASSERT(entry != nullptr && exit != nullptr);
51         exec_module_.SetInstructions(entry, exit);
52         exec_module_.Execute();
53         return exec_module_.GetRetValue<T>() == return_value;
54 #else
55         return true;
56 #endif
57     }
58 
CountOpcodes(const ArenaVector<BasicBlock * > & blocks)59     size_t CountOpcodes(const ArenaVector<BasicBlock *> &blocks)
60     {
61         size_t count_inst = 0;
62         opcodes_count_.clear();
63         for (auto block : blocks) {
64             for (auto inst : block->AllInsts()) {
65                 opcodes_count_[inst->GetOpcode()]++;
66                 count_inst++;
67             }
68         }
69         return count_inst;
70     }
71 
GetOpcodeCount(Opcode opcode)72     size_t GetOpcodeCount(Opcode opcode)
73     {
74         return opcodes_count_.at(opcode);
75     }
76 
CheckSimpleLoop(uint32_t inst_limit,uint32_t unroll_factor,uint32_t expected_factor)77     void CheckSimpleLoop(uint32_t inst_limit, uint32_t unroll_factor, uint32_t expected_factor)
78     {
79         auto graph = CreateEmptyGraph();
80         GRAPH(graph)
81         {
82             PARAMETER(0, 0).u64();    // a = 0
83             PARAMETER(1, 1).u64();    // b = 1
84             PARAMETER(2, 100).u64();  // c = 100
85             PARAMETER(3, 101).u64();
86 
87             BASIC_BLOCK(2, 2, 3)
88             {
89                 INST(4, Opcode::Phi).u64().Inputs(0, 6);
90                 INST(5, Opcode::Phi).u64().Inputs(1, 7);
91                 INST(6, Opcode::Mul).u64().Inputs(4, 4);  // a = a * a
92                 INST(7, Opcode::Add).u64().Inputs(5, 3);  // b = b + 1
93 
94                 INST(8, Opcode::Compare).CC(CC_LT).b().Inputs(7, 2);  // while b < c
95                 INST(9, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(8);
96             }
97             BASIC_BLOCK(3, -1)
98             {
99                 INST(10, Opcode::Sub).u64().Inputs(6, 7);
100                 INST(11, Opcode::Return).u64().Inputs(10);  // return (a - b)
101             }
102         }
103         graph->RunPass<LoopUnroll>(inst_limit, unroll_factor);
104         graph->RunPass<Cleanup>();
105 
106         // Check number of instructions
107         CountOpcodes(graph->GetBlocksRPO());
108         EXPECT_EQ(GetOpcodeCount(Opcode::Add), expected_factor);
109         EXPECT_EQ(GetOpcodeCount(Opcode::Mul), expected_factor);
110         EXPECT_EQ(GetOpcodeCount(Opcode::Compare), expected_factor);
111         EXPECT_EQ(GetOpcodeCount(Opcode::IfImm), expected_factor);
112         EXPECT_EQ(GetOpcodeCount(Opcode::Sub), 1U);
113         EXPECT_EQ(GetOpcodeCount(Opcode::Parameter), 4U);
114 
115         if (expected_factor > 1) {
116             // Check control-flow
117             EXPECT_EQ(BB(3).GetSuccsBlocks().size(), 1U);
118             EXPECT_EQ(BB(3).GetSuccessor(0), graph->GetEndBlock());
119             EXPECT_EQ(BB(3).GetPredsBlocks().size(), expected_factor);
120 
121             // phi1 [INST(6, Mul), INST(6', Mul), INST(6'', Mul)]
122             auto phi1 = INS(10).GetInput(0).GetInst();
123             EXPECT_TRUE(phi1->IsPhi() && phi1->GetInputsCount() == expected_factor);
124             EXPECT_TRUE(phi1->GetInput(0).GetInst() == &INS(6));
125             for (auto input : phi1->GetInputs()) {
126                 EXPECT_TRUE(input.GetInst()->GetOpcode() == Opcode::Mul);
127             }
128 
129             // phi2 [INST(7, Add), INST(7', Add), INST(7'', Add)]
130             auto phi2 = INS(10).GetInput(1).GetInst();
131             EXPECT_TRUE(phi2->IsPhi() && phi2->GetInputsCount() == expected_factor);
132             EXPECT_TRUE(phi2->GetInput(0).GetInst() == &INS(7));
133             for (auto input : phi2->GetInputs()) {
134                 EXPECT_TRUE(input.GetInst()->GetOpcode() == Opcode::Add);
135             }
136 
137             // Check cloned `Mul` instruction inputs
138             for (size_t i = 1; i < phi1->GetInputsCount(); i++) {
139                 auto cloned_mul = phi1->GetInput(i).GetInst();
140                 auto prev_mul = phi1->GetInput(i - 1).GetInst();
141                 EXPECT_TRUE(cloned_mul->GetInput(0).GetInst() == prev_mul);
142                 EXPECT_TRUE(cloned_mul->GetInput(1).GetInst() == prev_mul);
143             }
144 
145             // Check cloned `Add` instruction inputs
146             for (size_t i = 1; i < phi2->GetInputsCount(); i++) {
147                 auto cloned_add = phi2->GetInput(i).GetInst();
148                 auto prev_add = phi2->GetInput(i - 1).GetInst();
149                 EXPECT_TRUE(cloned_add->GetInput(0).GetInst() == prev_add);
150                 EXPECT_TRUE(cloned_add->GetInput(1).GetInst() == &INS(3));
151             }
152         } else {
153             EXPECT_EQ(INS(10).GetInput(0).GetInst(), &INS(6));
154             EXPECT_EQ(INS(10).GetInput(1).GetInst(), &INS(7));
155             EXPECT_EQ(BB(3).GetPredsBlocks().size(), 1U);
156             EXPECT_EQ(BB(3).GetPredsBlocks()[0], &BB(2));
157         }
158     }
159 
CheckLoopWithPhiAndSafePoint(uint32_t inst_limit,uint32_t unroll_factor,uint32_t expected_factor)160     void CheckLoopWithPhiAndSafePoint(uint32_t inst_limit, uint32_t unroll_factor, uint32_t expected_factor)
161     {
162         auto graph = CreateEmptyGraph();
163         GRAPH(graph)
164         {
165             PARAMETER(0, 0).u64();  // a = 26
166             PARAMETER(1, 1).u64();  // b = 0
167             CONSTANT(2, 0);         // const 0
168             CONSTANT(3, 1UL);       // const 1
169             CONSTANT(4, 2UL);       // const 2
170             CONSTANT(5, 10UL);      // const 10
171 
172             BASIC_BLOCK(2, 3, 4)
173             {
174                 INST(6, Opcode::Phi).u64().Inputs(0, 15);
175                 INST(7, Opcode::Phi).u64().Inputs(1, 14);
176                 INST(20, Opcode::SafePoint).Inputs(0, 1).SrcVregs({0, 1});
177                 INST(8, Opcode::Mod).u64().Inputs(6, 4);              // mod = a % 2
178                 INST(9, Opcode::Compare).CC(CC_EQ).b().Inputs(8, 3);  // if mod == 1
179                 INST(10, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(9);
180             }
181             BASIC_BLOCK(3, 5)
182             {
183                 INST(11, Opcode::Add).u64().Inputs(7, 2);  // b = b + 0
184             }
185             BASIC_BLOCK(4, 5)
186             {
187                 INST(12, Opcode::Sub).u64().Inputs(7, 3);  // b = b + 1
188             }
189             BASIC_BLOCK(5, 6, 2)
190             {
191                 INST(13, Opcode::Phi).u64().Inputs(11, 12);
192                 INST(14, Opcode::Mul).u64().Inputs(13, 5);             // b = b * 10
193                 INST(15, Opcode::Div).u64().Inputs(6, 4);              // a = a / 2
194                 INST(16, Opcode::Compare).CC(CC_EQ).b().Inputs(6, 2);  // if a = 0
195                 INST(17, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(16);
196             }
197             BASIC_BLOCK(6, -1)
198             {
199                 INST(18, Opcode::Div).u64().Inputs(14, 5);  // b = b / 10
200                 INST(19, Opcode::Return).u64().Inputs(18);  // return b
201             }
202         }
203         graph->RunPass<LoopUnroll>(inst_limit, unroll_factor);
204         GraphChecker(graph).Check();
205 
206         // Check number of instructions
207         CountOpcodes(graph->GetBlocksRPO());
208         EXPECT_EQ(GetOpcodeCount(Opcode::Add), expected_factor);
209         EXPECT_EQ(GetOpcodeCount(Opcode::Sub), expected_factor);
210         EXPECT_EQ(GetOpcodeCount(Opcode::Mul), expected_factor);
211         EXPECT_EQ(GetOpcodeCount(Opcode::Mod), expected_factor);
212         EXPECT_EQ(GetOpcodeCount(Opcode::Div), expected_factor + 1);
213         EXPECT_EQ(GetOpcodeCount(Opcode::IfImm), 2 * expected_factor);
214         EXPECT_EQ(GetOpcodeCount(Opcode::Compare), 2 * expected_factor);
215         size_t extra_phi = (expected_factor > 1) ? 1 : 0;
216         EXPECT_EQ(GetOpcodeCount(Opcode::Phi),
217                   2 + expected_factor + extra_phi);        // 2 in the front-block + N unrolled + 1 in the outer-block
218         EXPECT_EQ(GetOpcodeCount(Opcode::SafePoint), 1U);  // SafePoint isn't unrolled
219 
220         if (expected_factor > 1) {
221             // Check control-flow
222             auto outer_block = BB(5).GetTrueSuccessor();
223             EXPECT_EQ(outer_block->GetSuccsBlocks().size(), 1U);
224             EXPECT_EQ(outer_block->GetSuccessor(0), &BB(6));
225             EXPECT_EQ(outer_block->GetPredsBlocks().size(), expected_factor);
226 
227             // phi [INST(14, Mul), INST(14', Mul)]
228             auto phi = INS(18).GetInput(0).GetInst();
229             EXPECT_TRUE(phi->IsPhi() && phi->GetInputsCount() == expected_factor);
230             EXPECT_TRUE(phi->GetInput(0).GetInst() == &INS(14));
231             for (auto input : phi->GetInputs()) {
232                 EXPECT_TRUE(input.GetInst()->GetOpcode() == Opcode::Mul);
233             }
234 
235             // Check cloned `Mul` instruction inputs
236             for (size_t i = 1; i < phi->GetInputsCount(); i++) {
237                 auto cloned_mul = phi->GetInput(i).GetInst();
238                 auto prev_mul = phi->GetInput(i - 1).GetInst();
239                 EXPECT_TRUE(cloned_mul->GetInput(0).GetInst()->IsPhi());
240                 EXPECT_TRUE(prev_mul->GetInput(0).GetInst()->IsPhi());
241                 EXPECT_NE(cloned_mul->GetInput(0).GetInst(), prev_mul->GetInput(0).GetInst());
242                 EXPECT_TRUE(cloned_mul->GetInput(1).GetInst() == &INS(5));
243             }
244         } else {
245             EXPECT_EQ(INS(18).GetInput(0).GetInst(), &INS(14));
246             EXPECT_EQ(INS(18).GetInput(1).GetInst(), &INS(5));
247             EXPECT_EQ(BB(6).GetPredsBlocks().size(), 1U);
248             EXPECT_EQ(BB(6).GetPredsBlocks()[0], &BB(5));
249         }
250     }
251 
252     Graph *BuildGraphPhiInputOfAnotherPhi();
253     template <ConditionCode cc, size_t stop>
254     Graph *BuildLoopWithIncrement(size_t step);
255     template <ConditionCode cc, size_t stop>
256     Graph *BuildLoopWithDecrement(size_t step);
257 
258 protected:
259     static constexpr uint32_t INST_LIMIT = 1000;
260 
261 private:
262     ArenaUnorderedMap<Opcode, size_t> opcodes_count_;
263 #if defined(PANDA_TARGET_ARM64) || defined(PANDA_TARGET_AMD64)
264     VixlExecModule exec_module_;
265 #endif
266 };
267 
268 /*
269  * Test Graph:
270  *              [0]
271  *               |
272  *               v
273  *              [2]<----\
274  *               |      |
275  *               v      |
276  *              [3]-----/
277  *               |
278  *               v
279  *             [exit]
280  *
281  *
282  * After unroll with FACTOR = 3
283  *
284  *              [0]
285  *               |
286  *               v
287  *              [2]<----\
288  *               |      |
289  *               v      |
290  *         /----[3]     |
291  *         |     |      |
292  *         |     v      |
293  *         |    [2']    |
294  *         |     |      |
295  *         |     v      |
296  *         |<---[3']    |
297  *         |     |      |
298  *         |     v      |
299  *         |    [2'']   |
300  *         |     |      |
301  *         |     v      |
302  *         |<---[3'']---/
303  *         |
304  *         |
305  *         \-->[outer]
306  *                |
307  *                v
308  *              [exit]
309  *
310  */
311 
312 /**
313  * There are 6 instructions in the loop [bb2, bb3], 4 of them are cloneable
314  * So we have the following mapping form unroll factor to number on unrolled instructions:
315  *
316  * factor | unrolled inst count
317  * 1        6
318  * 2        10
319  * 3        14
320  * 4        18
321  * ...
322  * 100      402
323  *
324  * unrolled_inst_count = (factor * cloneable_inst) + (not_cloneable_inst)
325  */
TEST_F(LoopUnrollTest,SimpleLoop)326 TEST_F(LoopUnrollTest, SimpleLoop)
327 {
328     CheckSimpleLoop(0, 4, 1);
329     CheckSimpleLoop(6, 4, 1);
330     CheckSimpleLoop(9, 4, 1);
331     CheckSimpleLoop(10, 4, 2);
332     CheckSimpleLoop(14, 4, 3);
333     CheckSimpleLoop(100, 4, 4);
334     CheckSimpleLoop(100, 10, 10);
335     CheckSimpleLoop(400, 100, 99);
336     CheckSimpleLoop(1000, 100, 100);
337 }
338 
339 /*
340  * Test Graph:
341  *              [0]
342  *               |
343  *               v
344  *              [2]<--------\
345  *             /   \        |
346  *            v     v       |
347  *           [3]    [4]     |
348  *            \      /      |
349  *             v    v       |
350  *              [5]---------/
351  *               |
352  *               v
353  *             [exit]
354  *
355  * After unroll with FACTOR = 2
356  *
357  *              [0]
358  *               |
359  *               v
360  *              [2]<--------\
361  *             /   \        |
362  *            v     v       |
363  *           [3]    [4]     |
364  *            \      /      |
365  *             v    v       |
366  *  /-----------[5]         |
367  *  |            |          |
368  *  |            v          |
369  *  |           [2']        |
370  *  |          /   \        |
371  *  |         v     v       |
372  *  |       [3']    [4']    |
373  *  |         \      /      |
374  *  |          v    v       |
375  *  |           [5']--------/
376  *  |            |
377  *  |            v
378  *  \--------->[outer]
379  *               |
380  *               v
381  *             [exit]
382  */
383 
384 /**
385  * There are 13 instructions in the loop [bb2, bb3, bb4, bb5], 10 of them are cloneable
386  * So we have the following mapping form unroll factor to number on unrolled instructions:
387  *
388  * factor | unrolled inst count
389  * 1        13
390  * 2        23
391  * 3        33
392  * 4        43
393  * ...
394  * 100      1003
395  *
396  * unrolled_inst_count = (factor * cloneable_inst) + (not_cloneable_inst)
397  */
TEST_F(LoopUnrollTest,LoopWithPhisAndSafePoint)398 TEST_F(LoopUnrollTest, LoopWithPhisAndSafePoint)
399 {
400     CheckLoopWithPhiAndSafePoint(0, 4, 1);
401     CheckLoopWithPhiAndSafePoint(13, 4, 1);
402     CheckLoopWithPhiAndSafePoint(22, 4, 1);
403     CheckLoopWithPhiAndSafePoint(23, 4, 2);
404     CheckLoopWithPhiAndSafePoint(33, 4, 3);
405     CheckLoopWithPhiAndSafePoint(100, 4, 4);
406     CheckLoopWithPhiAndSafePoint(1000, 10, 10);
407     CheckLoopWithPhiAndSafePoint(1003, 100, 100);
408 }
409 
410 /*
411  * Test Graph:
412  *              [0]
413  *               |
414  *               v
415  *         /----[2]<----\
416  *         |     |      |
417  *         |     v      |
418  *         |    [3]-----/
419  *         |
420  *         |
421  *         \--->[4]
422  *               |
423  *               v
424  *       /----->[5]<-----\
425  *       |       |       |
426  *       |       v       |
427  *       \------[6]      |
428  *               |       |
429  *               v       |
430  *              [7]------/
431  *               |
432  *               v
433  *             [exit]
434  */
/**
 * Runs LoopUnroll on a graph with two loops that are expected to stay untouched:
 *  - loop [2, 3] contains a CallStatic;
 *  - loop [5, 6, 7] has two back edges (both block 6 and block 7 branch to block 5).
 * The test verifies that instruction counts and the predecessor of the return block are the same
 * before and after the pass.
 */
TEST_F(LoopUnrollTest, UnrollNotApplied)
{
    constexpr uint32_t UNROLL_FACTOR = 2;
    GRAPH(GetGraph())
    {
        PARAMETER(0, 0).u64();
        PARAMETER(1, 1).u64();
        PARAMETER(2, 2).u64();
        BASIC_BLOCK(2, 3, 4)
        {
            INST(3, Opcode::Compare).b().Inputs(0, 1);
            INST(4, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(3);
        }
        BASIC_BLOCK(3, 2)
        {
            INST(20, Opcode::SaveState).NoVregs();
            INST(14, Opcode::CallStatic).v0id().InputsAutoType(20);
        }
        BASIC_BLOCK(4, 5) {}
        BASIC_BLOCK(5, 6) {}
        BASIC_BLOCK(6, 7, 5)
        {
            INST(13, Opcode::Add).u64().Inputs(1, 2);
            INST(8, Opcode::Compare).b().Inputs(13, 0);
            INST(9, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(8);
        }
        BASIC_BLOCK(7, 5, 8)
        {
            INST(10, Opcode::Compare).b().Inputs(1, 2);
            INST(11, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(10);
        }
        BASIC_BLOCK(8, -1)
        {
            INST(12, Opcode::ReturnVoid);
        }
    }

    // Snapshot of the instruction statistics before the pass
    auto inst_count = CountOpcodes(GetGraph()->GetBlocksRPO());
    auto cmp_count = GetOpcodeCount(Opcode::Compare);
    auto if_count = GetOpcodeCount(Opcode::IfImm);
    auto add_count = GetOpcodeCount(Opcode::Add);

    GetGraph()->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR);
    GraphChecker(GetGraph()).Check();

    // Nothing may have been cloned
    auto unrolled_count = CountOpcodes(GetGraph()->GetBlocksRPO());
    EXPECT_EQ(GetOpcodeCount(Opcode::Compare), cmp_count);
    EXPECT_EQ(GetOpcodeCount(Opcode::IfImm), if_count);
    EXPECT_EQ(GetOpcodeCount(Opcode::Add), add_count);
    EXPECT_EQ(unrolled_count, inst_count);

    // The return block must still be reached only from block 7
    EXPECT_EQ(BB(8).GetPredsBlocks().size(), 1U);
    EXPECT_EQ(BB(8).GetPredsBlocks()[0], &BB(7));
}
487 
488 /**
489  *  a, b, c = 0, 1, 2
490  *  while c < 100:
491  *      a, b, c = b, c, c + a
492  *  return c
493  */
BuildGraphPhiInputOfAnotherPhi()494 Graph *LoopUnrollTest::BuildGraphPhiInputOfAnotherPhi()
495 {
496     auto graph = CreateEmptyGraph();
497     GRAPH(graph)
498     {
499         CONSTANT(0, 0);
500         CONSTANT(1, 1);
501         CONSTANT(2, 2);
502 
503         BASIC_BLOCK(2, 2, 3)
504         {
505             INST(4, Opcode::Phi).u64().Inputs(0, 5);
506             INST(5, Opcode::Phi).u64().Inputs(1, 6);
507             INST(6, Opcode::Phi).u64().Inputs(2, 7);
508             INST(7, Opcode::Add).u64().Inputs(4, 6);
509             INST(8, Opcode::IfImm).SrcType(DataType::UINT64).CC(CC_LT).Imm(100).Inputs(7);
510         }
511 
512         BASIC_BLOCK(3, -1)
513         {
514             INST(10, Opcode::Return).u64().Inputs(7);
515         }
516     }
517     return graph;
518 }
519 
TEST_F(LoopUnrollTest,PhiInputOfAnotherPhi)520 TEST_F(LoopUnrollTest, PhiInputOfAnotherPhi)
521 {
522     // Test with UNROLL_FACTOR = 2
523 
524     auto graph = BuildGraphPhiInputOfAnotherPhi();
525 
526     auto graph_unroll_factor_2 = CreateEmptyGraph();
527     GRAPH(graph_unroll_factor_2)
528     {
529         CONSTANT(0, 0);
530         CONSTANT(1, 1);
531         CONSTANT(2, 2);
532 
533         BASIC_BLOCK(2, 4, 3)
534         {
535             INST(4, Opcode::Phi).u64().Inputs(0, 6);
536             INST(5, Opcode::Phi).u64().Inputs(1, 7);
537             INST(6, Opcode::Phi).u64().Inputs(2, 11);
538             INST(7, Opcode::Add).u64().Inputs(4, 6);
539             INST(8, Opcode::IfImm).SrcType(DataType::UINT64).CC(CC_LT).Imm(100).Inputs(7);
540         }
541 
542         BASIC_BLOCK(4, 2, 3)
543         {
544             INST(11, Opcode::Add).u64().Inputs(5, 7);
545             INST(12, Opcode::IfImm).SrcType(DataType::UINT64).CC(CC_LT).Imm(100).Inputs(11);
546         }
547 
548         BASIC_BLOCK(3, -1)
549         {
550             INST(9, Opcode::Phi).u64().Inputs(7, 11);
551             INST(10, Opcode::Return).u64().Inputs(9);
552         }
553     }
554 
555     static constexpr uint64_t PROGRAM_RESULT = 101;
556     graph->RunPass<LoopUnroll>(INST_LIMIT, 2);
557     graph->RunPass<Cleanup>();
558     EXPECT_TRUE(GraphComparator().Compare(graph, graph_unroll_factor_2));
559     EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph_unroll_factor_2, PROGRAM_RESULT));
560 
561     // Test with UNROLL_FACTOR = 4
562 
563     graph = BuildGraphPhiInputOfAnotherPhi();
564 
565     auto graph_unroll_factor_4 = CreateEmptyGraph();
566     GRAPH(graph_unroll_factor_4)
567     {
568         CONSTANT(0, 0);
569         CONSTANT(1, 1);
570         CONSTANT(2, 2);
571 
572         BASIC_BLOCK(2, 4, 3)
573         {
574             INST(4, Opcode::Phi).u64().Inputs(0, 11);
575             INST(5, Opcode::Phi).u64().Inputs(1, 13);
576             INST(6, Opcode::Phi).u64().Inputs(2, 15);
577             INST(7, Opcode::Add).u64().Inputs(4, 6);
578             INST(8, Opcode::IfImm).SrcType(DataType::UINT64).CC(CC_LT).Imm(100).Inputs(7);
579         }
580 
581         BASIC_BLOCK(4, 5, 3)
582         {
583             INST(11, Opcode::Add).u64().Inputs(5, 7);
584             INST(12, Opcode::IfImm).SrcType(DataType::UINT64).CC(CC_LT).Imm(100).Inputs(11);
585         }
586 
587         BASIC_BLOCK(5, 6, 3)
588         {
589             INST(13, Opcode::Add).u64().Inputs(6, 11);
590             INST(14, Opcode::IfImm).SrcType(DataType::UINT64).CC(CC_LT).Imm(100).Inputs(13);
591         }
592 
593         BASIC_BLOCK(6, 2, 3)
594         {
595             INST(15, Opcode::Add).u64().Inputs(7, 13);
596             INST(16, Opcode::IfImm).SrcType(DataType::UINT64).CC(CC_LT).Imm(100).Inputs(15);
597         }
598 
599         BASIC_BLOCK(3, -1)
600         {
601             INST(9, Opcode::Phi).u64().Inputs(7, 11, 13, 15);
602             INST(10, Opcode::Return).u64().Inputs(9);
603         }
604     }
605 
606     graph->RunPass<LoopUnroll>(INST_LIMIT, 4);
607     graph->RunPass<Cleanup>();
608     EXPECT_TRUE(GraphComparator().Compare(graph, graph_unroll_factor_4));
609     EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph_unroll_factor_4, PROGRAM_RESULT));
610 }
611 
/**
 * Unrolls (factor 2) a loop whose header phi INST(4) takes both of its inputs from constants
 * defined outside the loop, then compares the result with a hand-written expected graph.
 */
TEST_F(LoopUnrollTest, PhiInputsOutsideLoop)
{
    auto graph = CreateEmptyGraph();
    GRAPH(graph)
    {
        PARAMETER(0, 0).u64();
        CONSTANT(1, 1);
        CONSTANT(2, 2);

        BASIC_BLOCK(2, 2, 3)
        {
            INST(4, Opcode::Phi).u64().Inputs(1, 2);  // both inputs are defined outside the loop
            INST(5, Opcode::Phi).u64().Inputs(0, 6);
            INST(6, Opcode::Add).u64().Inputs(4, 5);
            INST(7, Opcode::IfImm).SrcType(DataType::UINT64).CC(CC_LT).Imm(100).Inputs(6);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(10, Opcode::Return).u64().Inputs(5);
        }
    }

    graph->RunPass<LoopUnroll>(INST_LIMIT, 2);
    graph->RunPass<Cleanup>();

    auto expected_graph = CreateEmptyGraph();
    GRAPH(expected_graph)
    {
        PARAMETER(0, 0).u64();
        CONSTANT(1, 1);
        CONSTANT(2, 2);
        BASIC_BLOCK(2, 3)
        {
            // preheader
        }
        BASIC_BLOCK(3, 4, 5)
        {
            INST(4, Opcode::Phi).u64().Inputs(1, 2);
            INST(5, Opcode::Phi).u64().Inputs(0, 8);
            INST(6, Opcode::Add).u64().Inputs(4, 5);
            INST(7, Opcode::IfImm).SrcType(DataType::UINT64).CC(CC_LT).Imm(100).Inputs(6);
        }

        BASIC_BLOCK(4, 3, 5)
        {
            // Second copy of the body: the phi INST(4) is replaced by its back-edge input, const 2
            INST(8, Opcode::Add).u64().Inputs(2, 6);
            INST(9, Opcode::IfImm).SrcType(DataType::UINT64).CC(CC_LT).Imm(100).Inputs(8);
        }
        BASIC_BLOCK(5, -1)
        {
            INST(10, Opcode::Phi).u64().Inputs(5, 6);
            INST(11, Opcode::Return).u64().Inputs(10);
        }
    }
    EXPECT_TRUE(GraphComparator().Compare(graph, expected_graph));
}
669 
670 template <ConditionCode cc, size_t stop>
BuildLoopWithIncrement(size_t step)671 Graph *LoopUnrollTest::BuildLoopWithIncrement(size_t step)
672 {
673     auto graph = CreateEmptyGraph();
674     GRAPH(graph)
675     {
676         CONSTANT(0, stop);
677         CONSTANT(1, 0);  // a = 0, b = 0
678         CONSTANT(2, step);
679         BASIC_BLOCK(2, 3, 4)
680         {
681             INST(3, Opcode::Compare).b().SrcType(DataType::INT32).CC(cc).Inputs(1, 0);
682             INST(4, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(3);  // if a [cc] stop
683         }
684         BASIC_BLOCK(3, 3, 4)
685         {
686             INST(5, Opcode::Phi).s32().Inputs(1, 7);  // a
687             INST(6, Opcode::Phi).s32().Inputs(1, 8);  // b
688             INST(7, Opcode::Add).s32().Inputs(5, 2);  // a += step
689             INST(8, Opcode::Add).s32().Inputs(6, 7);  // b += a
690             INST(9, Opcode::Compare).b().SrcType(DataType::INT32).CC(cc).Inputs(7, 0);
691             INST(10, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(9);  // if a [cc] stop
692         }
693         BASIC_BLOCK(4, -1)
694         {
695             INST(11, Opcode::Phi).s32().Inputs(1, 6);
696             INST(12, Opcode::Return).s32().Inputs(11);  // return b
697         }
698     }
699     return graph;
700 }
701 
702 template <ConditionCode cc, size_t start>
BuildLoopWithDecrement(size_t step)703 Graph *LoopUnrollTest::BuildLoopWithDecrement(size_t step)
704 {
705     auto graph = CreateEmptyGraph();
706     GRAPH(graph)
707     {
708         CONSTANT(0, start);  // a = 10
709         CONSTANT(1, 0);      // b = 0
710         CONSTANT(2, step);
711 
712         BASIC_BLOCK(2, 3, 4)
713         {
714             INST(3, Opcode::Compare).b().SrcType(DataType::INT32).CC(cc).Inputs(0, 1);  // if a [cc] 0
715             INST(4, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(3);
716         }
717         BASIC_BLOCK(3, 3, 4)
718         {
719             INST(5, Opcode::Phi).s32().Inputs(0, 8);                                    // a
720             INST(6, Opcode::Phi).s32().Inputs(1, 7);                                    // b
721             INST(7, Opcode::Add).s32().Inputs(6, 5);                                    // b += a
722             INST(8, Opcode::Sub).s32().Inputs(5, 2);                                    // a -= 1
723             INST(9, Opcode::Compare).b().SrcType(DataType::INT32).CC(cc).Inputs(8, 1);  // if a [cc] 0
724             INST(10, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(9);
725         }
726         BASIC_BLOCK(4, -1)
727         {
728             INST(11, Opcode::Phi).s32().Inputs(1, 7);
729             INST(12, Opcode::Return).s32().Inputs(11);  // return b
730         }
731     }
732     return graph;
733 }
734 
TEST_F(LoopUnrollTest,CountableLoopWithIncrement)735 TEST_F(LoopUnrollTest, CountableLoopWithIncrement)
736 {
737     static constexpr uint32_t INC_STEP = 1;
738     static constexpr uint32_t INC_STOP = 10;
739     for (size_t unroll_factor = 1; unroll_factor <= 10; unroll_factor++) {
740         auto graph = BuildLoopWithIncrement<CC_LT, INC_STOP>(INC_STEP);
741         graph->RunPass<LoopUnroll>(INST_LIMIT, unroll_factor);
742         EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph, 45));
743 
744         graph = BuildLoopWithIncrement<CC_LE, INC_STOP>(INC_STEP);
745         graph->RunPass<LoopUnroll>(INST_LIMIT, unroll_factor);
746         EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph, 55));
747     }
748 
749     static constexpr uint32_t UNROLL_FACTOR = 2;
750     auto graph = BuildLoopWithIncrement<CC_LT, INC_STOP>(INC_STEP);
751     graph->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR);
752     graph->RunPass<Cleanup>();
753 
754     auto graph_unroll = CreateEmptyGraph();
755     GRAPH(graph_unroll)
756     {
757         CONSTANT(0, 10);
758         CONSTANT(1, 0);  // a = 0, b = 0
759         CONSTANT(2, 1);  // UNROLL_FACTOR - 1 = 1
760         // NB: add a new constant if UNROLL_FACTOR is changed and fix INST(20, Opcode::Sub).
761 
762         BASIC_BLOCK(2, 3, 5)
763         {
764             // NB: replace the second input if UNROLL_FACTOR is changed:
765             INST(20, Opcode::Sub).s32().Inputs(0, 2);
766             INST(3, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(1, 20);  // if (a < 10 -
767                                                                                             // (UNROLL_FACTOR - 1))
768             INST(4, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(3);
769         }
770         BASIC_BLOCK(3, 3, 5)
771         {
772             INST(5, Opcode::Phi).s32().Inputs(1, 21);   // a
773             INST(6, Opcode::Phi).s32().Inputs(1, 22);   // b
774             INST(7, Opcode::Add).s32().Inputs(5, 2);    // a + 1
775             INST(8, Opcode::Add).s32().Inputs(6, 7);    // b + 1
776             INST(21, Opcode::Add).s32().Inputs(7, 2);   // a + 1
777             INST(22, Opcode::Add).s32().Inputs(8, 21);  // b + 1
778             INST(9, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(21, 20);
779             INST(10, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(9);  // if a < 10 - (UNROLL_FACTOR -
780                                                                                          // 1)
781         }
782         BASIC_BLOCK(5, 6, 4)
783         {
784             INST(11, Opcode::Phi).s32().Inputs(1, 8);
785             INST(25, Opcode::Phi).s32().Inputs(1, 21);                                       // a
786             INST(26, Opcode::Phi).s32().Inputs(1, 22);                                       // b
787             INST(27, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(25, 0);  // if (a < 10)
788             INST(28, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(27);
789         }
790         BASIC_BLOCK(6, 4) {}
791         BASIC_BLOCK(4, -1)
792         {
793             INST(31, Opcode::Phi).s32().Inputs(11, 26);
794             INST(12, Opcode::Return).s32().Inputs(31);  // return b
795         }
796     }
797     EXPECT_TRUE(GraphComparator().Compare(graph, graph_unroll));
798 }
799 
TEST_F(LoopUnrollTest,CountableLoopWithDecrement)800 TEST_F(LoopUnrollTest, CountableLoopWithDecrement)
801 {
802     static constexpr uint32_t DEC_STEP = 1;
803     static constexpr uint32_t DEC_START = 10;
804     for (size_t unroll_factor = 1; unroll_factor <= 10; unroll_factor++) {
805         auto graph = BuildLoopWithDecrement<CC_GT, DEC_START>(DEC_STEP);
806         graph->RunPass<LoopUnroll>(INST_LIMIT, unroll_factor);
807         EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph, 55));
808 
809         graph = BuildLoopWithDecrement<CC_GE, DEC_START>(DEC_STEP);
810         graph->RunPass<LoopUnroll>(INST_LIMIT, unroll_factor);
811         EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph, 55));
812     }
813 
814     static constexpr uint32_t UNROLL_FACTOR = 2;
815     auto graph = BuildLoopWithDecrement<CC_GT, DEC_START>(DEC_STEP);
816     graph->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR);
817     graph->RunPass<Cleanup>();
818 
819     auto graph_unroll = CreateEmptyGraph();
820     GRAPH(graph_unroll)
821     {
822         CONSTANT(0, 10);  // a = 10
823         CONSTANT(1, 0);   // b = 0
824         CONSTANT(2, 1);   // UNROLL_FACTOR - 1 = 1
825         // NB: add a new constant if UNROLL_FACTOR is changed and fix INST(20, Opcode::Add).
826 
827         BASIC_BLOCK(2, 3, 5)
828         {
829             // NB: replace the second input if UNROLL_FACTOR is changed:
830             INST(20, Opcode::Add).s32().Inputs(1, 2);
831             INST(3, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_GT).Inputs(0, 20);  // if (a > UNROLL_FACTOR -
832                                                                                             // 1)
833             INST(4, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(3);
834         }
835         BASIC_BLOCK(3, 3, 5)
836         {
837             INST(5, Opcode::Phi).s32().Inputs(0, 22);                                        // a
838             INST(6, Opcode::Phi).s32().Inputs(1, 21);                                        // b
839             INST(7, Opcode::Add).s32().Inputs(6, 5);                                         // b += a
840             INST(8, Opcode::Sub).s32().Inputs(5, 2);                                         // a -= 1
841             INST(21, Opcode::Add).s32().Inputs(7, 8);                                        // b += a
842             INST(22, Opcode::Sub).s32().Inputs(8, 2);                                        // a -= 1
843             INST(9, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_GT).Inputs(22, 20);  // if (a > UNROLL_FACTOR -
844                                                                                              // 1)
845             INST(10, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(9);
846         }
847         BASIC_BLOCK(5, 6, 4)
848         {
849             INST(25, Opcode::Phi).s32().Inputs(1, 21);                                       // b
850             INST(26, Opcode::Phi).s32().Inputs(0, 22);                                       // a
851             INST(27, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_GT).Inputs(26, 1);  // if (a > 0)
852             INST(28, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(27);
853         }
854         BASIC_BLOCK(6, 4)
855         {
856             INST(29, Opcode::Add).s32().Inputs(25, 26);  // b += a
857         }
858         BASIC_BLOCK(4, -1)
859         {
860             INST(31, Opcode::Phi).s32().Inputs(25, 29);
861             INST(12, Opcode::Return).s32().Inputs(31);  // return b
862         }
863     }
864     EXPECT_TRUE(GraphComparator().Compare(graph, graph_unroll));
865 }
866 
TEST_F(LoopUnrollTest,InversedCompares)867 TEST_F(LoopUnrollTest, InversedCompares)
868 {
869     // Case 1: if (a < 10 is false) goto exit
870     auto graph = CreateEmptyGraph();
871     GRAPH(graph)
872     {
873         CONSTANT(0, 10);
874         CONSTANT(1, 0);  // a = 0, b = 0
875         CONSTANT(2, 1);
876         BASIC_BLOCK(2, 4, 3)
877         {
878             INST(3, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(1, 0);
879             INST(4, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(3);  // if a < 10 goto loop
880         }
881         BASIC_BLOCK(3, 4, 3)
882         {
883             INST(5, Opcode::Phi).s32().Inputs(1, 7);  // a
884             INST(6, Opcode::Phi).s32().Inputs(1, 8);  // b
885             INST(7, Opcode::Add).s32().Inputs(5, 2);  // a += 1
886             INST(8, Opcode::Add).s32().Inputs(6, 7);  // b += a
887             INST(9, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(7, 0);
888             INST(10, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(9);  // if a < 10 goto loop
889         }
890         BASIC_BLOCK(4, -1)
891         {
892             INST(11, Opcode::Phi).s32().Inputs(1, 6);
893             INST(12, Opcode::Return).s32().Inputs(11);  // return b
894         }
895     }
896 
897     static constexpr uint32_t UNROLL_FACTOR = 2;
898     graph->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR);
899     graph->RunPass<Cleanup>();
900 
901     auto graph_unroll = CreateEmptyGraph();
902     GRAPH(graph_unroll)
903     {
904         CONSTANT(0, 10);
905         CONSTANT(1, 0);  // a = 0, b = 0
906         CONSTANT(2, 1);  // UNROLL_FACTOR - 1 = 1
907         // NB: add a new constant if UNROLL_FACTOR is changed and fix INST(20, Opcode::Sub).
908 
909         BASIC_BLOCK(2, 3, 5)
910         {
911             // NB: replace the second input if UNROLL_FACTOR is changed:
912             INST(20, Opcode::Sub).s32().Inputs(0, 2);
913             INST(3, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(1, 20);
914             INST(4, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(3);  // if (a <= 10 - UNROLL_FACTOR)
915         }
916         BASIC_BLOCK(3, 3, 5)
917         {
918             INST(5, Opcode::Phi).s32().Inputs(1, 21);   // a
919             INST(6, Opcode::Phi).s32().Inputs(1, 22);   // b
920             INST(7, Opcode::Add).s32().Inputs(5, 2);    // a + 1
921             INST(8, Opcode::Add).s32().Inputs(6, 7);    // b + 1
922             INST(21, Opcode::Add).s32().Inputs(7, 2);   // a + 1
923             INST(22, Opcode::Add).s32().Inputs(8, 21);  // b + 1
924             INST(9, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(21, 20);
925             INST(10, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(9);  // if (a <= 10 - UNROLL_FACTOR)
926         }
927         BASIC_BLOCK(5, 4, 6)
928         {
929             INST(11, Opcode::Phi).s32().Inputs(1, 8);
930             INST(25, Opcode::Phi).s32().Inputs(1, 21);                                       // a
931             INST(26, Opcode::Phi).s32().Inputs(1, 22);                                       // b
932             INST(27, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(25, 0);  // if (a < 10)
933             INST(28, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(27);
934         }
935         BASIC_BLOCK(6, 4) {}
936         BASIC_BLOCK(4, -1)
937         {
938             INST(31, Opcode::Phi).s32().Inputs(11, 26);
939             INST(12, Opcode::Return).s32().Inputs(31);  // return b
940         }
941     }
942     EXPECT_TRUE(GraphComparator().Compare(graph, graph_unroll));
943     EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph, 45));
944 
945     // Case 2: if (a >= 10 is false) goto loop
946     auto graph2 = CreateEmptyGraph();
947     GRAPH(graph2)
948     {
949         CONSTANT(0, 10);
950         CONSTANT(1, 0);  // a = 0, b = 0
951         CONSTANT(2, 1);
952         BASIC_BLOCK(2, 3, 4)
953         {
954             INST(3, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_GE).Inputs(1, 0);
955             INST(4, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(3);  // if a < 10 goto loop
956         }
957         BASIC_BLOCK(3, 3, 4)
958         {
959             INST(5, Opcode::Phi).s32().Inputs(1, 7);  // a
960             INST(6, Opcode::Phi).s32().Inputs(1, 8);  // b
961             INST(7, Opcode::Add).s32().Inputs(5, 2);  // a += 1
962             INST(8, Opcode::Add).s32().Inputs(6, 7);  // b += a
963             INST(9, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_GE).Inputs(7, 0);
964             INST(10, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(9);  // if a < 10 goto loop
965         }
966         BASIC_BLOCK(4, -1)
967         {
968             INST(11, Opcode::Phi).s32().Inputs(1, 6);
969             INST(12, Opcode::Return).s32().Inputs(11);  // return b
970         }
971     }
972 
973     graph2->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR);
974     graph2->RunPass<Cleanup>();
975 
976     auto graph2_unroll = CreateEmptyGraph();
977     GRAPH(graph2_unroll)
978     {
979         CONSTANT(0, 10);
980         CONSTANT(1, 0);  // a = 0, b = 0
981         CONSTANT(2, 1);  // UNROLL_FACTOR - 1 = 1
982         // NB: add a new constant if UNROLL_FACTOR is changed and fix INST(20, Opcode::Sub).
983 
984         BASIC_BLOCK(2, 3, 5)
985         {
986             // NB: replace the second input if UNROLL_FACTOR is changed:
987             INST(20, Opcode::Sub).s32().Inputs(0, 2);
988             INST(3, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(1, 20);
989             INST(4, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(3);  // if (a <= 10 - UNROLL_FACTOR)
990         }
991         BASIC_BLOCK(3, 3, 5)
992         {
993             INST(5, Opcode::Phi).s32().Inputs(1, 21);   // a
994             INST(6, Opcode::Phi).s32().Inputs(1, 22);   // b
995             INST(7, Opcode::Add).s32().Inputs(5, 2);    // a + 1
996             INST(8, Opcode::Add).s32().Inputs(6, 7);    // b + 1
997             INST(21, Opcode::Add).s32().Inputs(7, 2);   // a + 1
998             INST(22, Opcode::Add).s32().Inputs(8, 21);  // b + 1
999             INST(9, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(21, 20);
1000             INST(10, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(9);  // if (a <= 10 - UNROLL_FACTOR)
1001         }
1002         BASIC_BLOCK(5, 6, 4)
1003         {
1004             INST(11, Opcode::Phi).s32().Inputs(1, 8);
1005             INST(25, Opcode::Phi).s32().Inputs(1, 21);                                       // a
1006             INST(26, Opcode::Phi).s32().Inputs(1, 22);                                       // b
1007             INST(27, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_GE).Inputs(25, 0);  // if (a < 10)
1008             INST(28, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(27);
1009         }
1010         BASIC_BLOCK(6, 4) {}
1011         BASIC_BLOCK(4, -1)
1012         {
1013             INST(31, Opcode::Phi).s32().Inputs(11, 26);
1014             INST(12, Opcode::Return).s32().Inputs(31);  // return b
1015         }
1016     }
1017     EXPECT_TRUE(GraphComparator().Compare(graph2, graph2_unroll));
1018     EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph2, 45));
1019 
1020     // Case 3 - if (10 != a) goto loop
1021     auto graph3 = CreateEmptyGraph();
1022     GRAPH(graph3)
1023     {
1024         CONSTANT(0, 10);
1025         CONSTANT(1, 0);  // a = 0, b = 0
1026         CONSTANT(2, 1);
1027         BASIC_BLOCK(2, 4, 3)
1028         {
1029             INST(3, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_NE).Inputs(0, 1);
1030             INST(4, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(3);  // if 10 != a goto loop
1031         }
1032         BASIC_BLOCK(3, 4, 3)
1033         {
1034             INST(5, Opcode::Phi).s32().Inputs(1, 7);  // a
1035             INST(6, Opcode::Phi).s32().Inputs(1, 8);  // b
1036             INST(7, Opcode::Add).s32().Inputs(5, 2);  // a += 1
1037             INST(8, Opcode::Add).s32().Inputs(6, 7);  // b += a
1038             INST(9, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_NE).Inputs(0, 7);
1039             INST(10, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_EQ).Imm(0).Inputs(9);  // if 10 != a goto loop
1040         }
1041         BASIC_BLOCK(4, -1)
1042         {
1043             INST(11, Opcode::Phi).s32().Inputs(1, 6);
1044             INST(12, Opcode::Return).s32().Inputs(11);  // return b
1045         }
1046     }
1047     graph3->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR);
1048     EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph3, 45));
1049 
1050     // Case 4 (decrement): if (0 == a) goto out_loop
1051     auto graph4 = CreateEmptyGraph();
1052     GRAPH(graph4)
1053     {
1054         CONSTANT(0, 9);  // a = 9
1055         CONSTANT(1, 0);  // b = 0
1056         CONSTANT(2, 1);
1057 
1058         BASIC_BLOCK(2, 4, 3)
1059         {
1060             INST(3, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_EQ).Inputs(1, 0);  // if 0 == a goto out_loop
1061             INST(4, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(3);
1062         }
1063         BASIC_BLOCK(3, 4, 3)
1064         {
1065             INST(5, Opcode::Phi).s32().Inputs(0, 8);                                       // a
1066             INST(6, Opcode::Phi).s32().Inputs(1, 7);                                       // b
1067             INST(7, Opcode::Add).s32().Inputs(6, 5);                                       // b += a
1068             INST(8, Opcode::Sub).s32().Inputs(5, 2);                                       // a -= 1
1069             INST(9, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_EQ).Inputs(1, 8);  // if 0 == a goto out_loop
1070             INST(10, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(9);
1071         }
1072         BASIC_BLOCK(4, -1)
1073         {
1074             INST(11, Opcode::Phi).s32().Inputs(1, 7);
1075             INST(12, Opcode::Return).s32().Inputs(11);  // return b
1076         }
1077     }
1078     graph4->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR);
1079     EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph4, 45));
1080 }
1081 
TEST_F(LoopUnrollTest,LoopWithDifferentConstants)1082 TEST_F(LoopUnrollTest, LoopWithDifferentConstants)
1083 {
1084     static constexpr uint32_t UNROLL_FACTOR = 2;
1085 
1086     // Chech increment
1087     static constexpr uint32_t INC_STOP = 100;
1088     for (size_t inc_step = 1; inc_step <= 10; inc_step++) {
1089         // CC_LT
1090         size_t result = 0;
1091         for (size_t i = 0; i < INC_STOP; i += inc_step) {
1092             result += i;
1093         }
1094         auto graph = BuildLoopWithIncrement<CC_LT, INC_STOP>(inc_step);
1095         EXPECT_TRUE(graph->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR));
1096         EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph, result));
1097 
1098         // CC_LE
1099         result = 0;
1100         for (size_t i = 0; i <= INC_STOP; i += inc_step) {
1101             result += i;
1102         }
1103         graph = BuildLoopWithIncrement<CC_LE, INC_STOP>(inc_step);
1104         EXPECT_TRUE(graph->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR));
1105         EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph, result));
1106 
1107         // CC_NE
1108         if (INC_STOP % inc_step != 0) {
1109             // Otherwise test loop with CC_NE will be infinite
1110             continue;
1111         }
1112         result = 0;
1113         for (size_t i = 0; i != INC_STOP; i += inc_step) {
1114             result += i;
1115         }
1116         graph = BuildLoopWithIncrement<CC_NE, INC_STOP>(inc_step);
1117         EXPECT_TRUE(graph->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR));
1118         EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph, result));
1119     }
1120 
1121     // Chech decrement
1122     static constexpr uint32_t DEC_START = 100;
1123     for (size_t dec_step = 1; dec_step <= 10; dec_step++) {
1124         // CC_GT
1125         int result = 0;
1126         for (int i = DEC_START; i > 0; i -= dec_step) {
1127             result += i;
1128         }
1129         auto graph = BuildLoopWithDecrement<CC_GT, DEC_START>(dec_step);
1130         EXPECT_TRUE(graph->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR));
1131         EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph, result));
1132 
1133         // CC_GE
1134         result = 0;
1135         for (int i = DEC_START; i >= 0; i -= dec_step) {
1136             result += i;
1137         }
1138         graph = BuildLoopWithDecrement<CC_GE, DEC_START>(dec_step);
1139         EXPECT_TRUE(graph->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR));
1140         EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph, result));
1141 
1142         // CC_NE
1143         if (INC_STOP % dec_step != 0) {
1144             // Otherwise test loop with CC_NE will be infinite
1145             continue;
1146         }
1147         result = 0;
1148         for (int i = DEC_START; i != 0; i -= dec_step) {
1149             result += i;
1150         }
1151         graph = BuildLoopWithDecrement<CC_NE, DEC_START>(dec_step);
1152         EXPECT_TRUE(graph->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR));
1153         EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(graph, result));
1154     }
1155 }
1156 
TEST_F(LoopUnrollTest,PredsInversedOrder)1157 TEST_F(LoopUnrollTest, PredsInversedOrder)
1158 {
1159     auto graph = CreateEmptyGraph();
1160     GRAPH(graph)
1161     {
1162         PARAMETER(0, 0).s64();  // a
1163         PARAMETER(1, 1).s64();  // b
1164         CONSTANT(2, 1);
1165         CONSTANT(3, 2);
1166         BASIC_BLOCK(2, 3, 4)
1167         {
1168             INST(6, Opcode::Phi).s64().Inputs(1, 12);                             // b
1169             INST(7, Opcode::Mod).s64().Inputs(6, 3);                              // b % 2
1170             INST(8, Opcode::If).SrcType(DataType::INT64).CC(CC_EQ).Inputs(7, 2);  // if b % 2 == 1
1171         }
1172         BASIC_BLOCK(3, 4)
1173         {
1174             INST(10, Opcode::Mul).s64().Inputs(6, 6);  // b = b * b
1175         }
1176         BASIC_BLOCK(4, 2, 5)
1177         {
1178             INST(12, Opcode::Phi).s64().Inputs({{3, 10}, {2, 6}});
1179             INST(13, Opcode::Compare).CC(CC_LT).b().Inputs(12, 0);  // if b < a
1180             INST(14, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(13);
1181         }
1182         BASIC_BLOCK(5, -1)
1183         {
1184             INST(15, Opcode::Return).s64().Inputs(12);  // return b
1185         }
1186     }
1187     // Swap BB4 preds
1188     std::swap(BB(4).GetPredsBlocks()[0], BB(4).GetPredsBlocks()[1]);
1189     INS(12).SetInput(0, &INS(6));
1190     INS(12).SetInput(1, &INS(10));
1191     graph->RunPass<LoopUnroll>(INST_LIMIT, 2);
1192     graph->RunPass<Cleanup>();
1193 
1194     auto expected_graph = CreateEmptyGraph();
1195     GRAPH(expected_graph)
1196     {
1197         PARAMETER(0, 0).s64();  // a
1198         PARAMETER(1, 1).s64();  // b
1199         CONSTANT(2, 1);
1200         CONSTANT(3, 2);
1201         BASIC_BLOCK(6, 2) {}
1202         BASIC_BLOCK(2, 3, 4)
1203         {
1204             INST(6, Opcode::Phi).s64().Inputs(1, 19);                             // b
1205             INST(7, Opcode::Mod).s64().Inputs(6, 3);                              // b % 2
1206             INST(8, Opcode::If).SrcType(DataType::INT64).CC(CC_EQ).Inputs(7, 2);  // if b % 2 == 1
1207         }
1208         BASIC_BLOCK(3, 4)
1209         {
1210             INST(10, Opcode::Mul).s64().Inputs(6, 6);  // b = b * b
1211         }
1212         BASIC_BLOCK(4, 9, 8)
1213         {
1214             INST(12, Opcode::Phi).s64().Inputs({{3, 10}, {2, 6}});
1215             INST(13, Opcode::Compare).CC(CC_LT).b().Inputs(12, 0);  // if b < a
1216             INST(14, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(13);
1217         }
1218         BASIC_BLOCK(9, 10, 11)
1219         {
1220             INST(16, Opcode::Mod).s64().Inputs(12, 3);                              // b % 2
1221             INST(17, Opcode::If).SrcType(DataType::INT64).CC(CC_EQ).Inputs(16, 2);  // if b % 2 == 1
1222         }
1223         BASIC_BLOCK(10, 11)
1224         {
1225             INST(18, Opcode::Mul).s64().Inputs(12, 12);  // b = b * b
1226         }
1227         BASIC_BLOCK(11, 2, 8)
1228         {
1229             INST(19, Opcode::Phi).s64().Inputs({{9, 12}, {10, 18}});
1230             INST(20, Opcode::Compare).CC(CC_LT).b().Inputs(19, 0);  // if b < a
1231             INST(21, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(20);
1232         }
1233         BASIC_BLOCK(8, -1)
1234         {
1235             INST(22, Opcode::Phi).s64().Inputs({{4, 12}, {11, 19}});
1236             INST(15, Opcode::Return).s64().Inputs(22);  // return b
1237         }
1238     }
1239     EXPECT_TRUE(GraphComparator().Compare(graph, expected_graph));
1240 }
1241 
1242 // TODO (a.popov) Fix after supporting infinite loops unrolling
TEST_F(LoopUnrollTest,InfiniteLoop)1243 TEST_F(LoopUnrollTest, InfiniteLoop)
1244 {
1245     auto graph = CreateEmptyGraph();
1246     GRAPH(graph)
1247     {
1248         PARAMETER(0, 0).s32();
1249         CONSTANT(1, 1);
1250 
1251         BASIC_BLOCK(2, 2)
1252         {
1253             INST(2, Opcode::Phi).s32().Inputs(0, 3);
1254             INST(3, Opcode::Add).s32().Inputs(2, 1);
1255         }
1256     }
1257     EXPECT_FALSE(graph->RunPass<LoopUnroll>(1000, 2));
1258 }
1259 
// Unrolling of a loop in which one header phi is an input of another header phi (instruction 6 takes
// instruction 5 on the back edge). The loop computes the Fibonacci sequence until the value reaches 100;
// every unroll factor in [2, 9] must leave the returned value unchanged.
TEST_F(LoopUnrollTest, PhiDominatesItsPhiInput)
{
    // First Fibonacci number that is not less than 100
    static constexpr uint64_t PROGRAM_RESULT = 144;

    auto fibonacci = CreateEmptyGraph();
    GRAPH(fibonacci)
    {
        CONSTANT(0, 0);
        CONSTANT(1, 1);
        CONSTANT(2, 100);

        BASIC_BLOCK(2, 2, 3)
        {
            INST(5, Opcode::Phi).s32().Inputs(0, 7);
            INST(6, Opcode::Phi).s32().Inputs(1, 5);  // previous value of phi 5
            INST(7, Opcode::Add).s32().Inputs(5, 6);  // Fibonacci Sequence
            INST(8, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(7, 2);
            INST(9, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(8);
        }
        BASIC_BLOCK(3, -1)
        {
            INST(10, Opcode::Return).s32().Inputs(7);
        }
    }

    // Each factor is applied to a fresh clone so that the source graph stays untouched
    for (int factor = 2; factor < 10; ++factor) {
        auto unrolled =
            GraphCloner(fibonacci, fibonacci->GetAllocator(), fibonacci->GetLocalAllocator()).CloneGraph();
        unrolled->RunPass<LoopUnroll>(INST_LIMIT, factor);
        EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(unrolled, PROGRAM_RESULT));
    }

    // Reference run of the graph that was never unrolled
    EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(fibonacci, PROGRAM_RESULT));
}
1291 
// Unrolling of a loop whose back-edge block has no Compare of its own: the IfImm in block 5 reuses the
// Compare defined in the loop header (instruction 8). The pass must succeed and both the source graph and
// the unrolled clone must return the same value.
TEST_F(LoopUnrollTest, BackEdgeWithoutCompare)
{
    static constexpr uint32_t UNROLL_FACTOR = 2;
    static constexpr uint64_t PROGRAM_RESULT = 64;

    auto source = CreateEmptyGraph();
    GRAPH(source)
    {
        CONSTANT(0, 10);
        CONSTANT(1, 0);  // a = 0, b = 0
        CONSTANT(2, 1);

        BASIC_BLOCK(2, 3, 6)
        {
            INST(3, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(1, 0);
            INST(4, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(3);  // if a < 10
        }
        // Loop header
        BASIC_BLOCK(3, 4, 5)
        {
            INST(5, Opcode::Phi).s32().Inputs(1, 7);   // a
            INST(6, Opcode::Phi).s32().Inputs(1, 12);  // b
            INST(7, Opcode::Add).s32().Inputs(5, 2);   // a++
            INST(8, Opcode::Compare).b().SrcType(DataType::INT32).CC(CC_LT).Inputs(7, 0);
            INST(9, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(8);  // if a < 10
        }
        BASIC_BLOCK(4, 5)
        {
            INST(10, Opcode::Add).s32().Inputs(6, 2);  // b++
        }
        // Back-edge block: branches on the compare computed in the header
        BASIC_BLOCK(5, 3, 6)
        {
            INST(11, Opcode::Phi).s32().Inputs(6, 10);                                   // b
            INST(12, Opcode::Add).s32().Inputs(11, 7);                                   // b += a
            INST(13, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(8);  // if a < 10
        }
        BASIC_BLOCK(6, -1)
        {
            INST(14, Opcode::Phi).s32().Inputs(1, 12);
            INST(15, Opcode::Return).s32().Inputs(14);  // return b
        }
    }

    // The clone is taken before the source graph is handed to the result check
    auto unrolled = GraphCloner(source, source->GetAllocator(), source->GetLocalAllocator()).CloneGraph();
    EXPECT_TRUE(unrolled->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR));

    EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(source, PROGRAM_RESULT));
    EXPECT_TRUE(CheckRetOnVixlSimulator<uint64_t>(unrolled, PROGRAM_RESULT));
}
1337 
// Checks that unrolling of a loop containing a call is controlled by the CompilerUnrollLoopWithCalls option:
//  - enabled:  the pass succeeds and the number of CallStatic instructions equals the unroll factor;
//  - disabled: the pass reports no changes and the single CallStatic is left in place.
// The option value observed at the start of the test is written back at the end.
TEST_F(LoopUnrollTest, UnrollWithCalls)
{
    static constexpr auto UNROLL_FACTOR = 5U;

    auto source = CreateEmptyGraph();
    GRAPH(source)
    {
        PARAMETER(0, 0).u64();  // a
        PARAMETER(1, 1).u64();  // b

        BASIC_BLOCK(2, 2, 3)
        {
            INST(4, Opcode::Phi).u64().Inputs(0, 6);
            INST(20, Opcode::SaveState).NoVregs();
            INST(5, Opcode::CallStatic).u64().InputsAutoType(20);
            INST(6, Opcode::Add).u64().Inputs(4, 5);              // a += call()
            INST(7, Opcode::Compare).CC(CC_LT).b().Inputs(6, 1);  // while a < b
            INST(8, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(7);
        }

        BASIC_BLOCK(3, -1)
        {
            INST(11, Opcode::Return).u64().Inputs(6);  // return a
        }
    }

    // Remember the current setting so that it can be written back afterwards
    const auto saved_unroll_with_calls = options.IsCompilerUnrollLoopWithCalls();

    // Enable loop unroll with calls: work on a clone, the source graph is needed for the second check
    options.SetCompilerUnrollLoopWithCalls(true);
    auto unrolled = GraphCloner(source, source->GetAllocator(), source->GetLocalAllocator()).CloneGraph();
    EXPECT_TRUE(unrolled->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR));
    CountOpcodes(unrolled->GetBlocksRPO());
    EXPECT_EQ(GetOpcodeCount(Opcode::CallStatic), UNROLL_FACTOR);

    // Disable loop unroll with calls
    options.SetCompilerUnrollLoopWithCalls(false);
    EXPECT_FALSE(source->RunPass<LoopUnroll>(INST_LIMIT, UNROLL_FACTOR));
    CountOpcodes(source->GetBlocksRPO());
    EXPECT_EQ(GetOpcodeCount(Opcode::CallStatic), 1U);

    // Return default option
    options.SetCompilerUnrollLoopWithCalls(saved_unroll_with_calls);
}
1380 }  // namespace panda::compiler
1381