• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #include "compiler_logger.h"
16 #include "optimizer/ir/basicblock.h"
17 #include "optimizer/ir/graph.h"
18 #include "optimizer/ir/graph_cloner.h"
19 #include "optimizer/optimizations/loop_unroll.h"
20 #include "optimizer/analysis/alias_analysis.h"
21 #include "optimizer/analysis/bounds_analysis.h"
22 #include "optimizer/analysis/dominators_tree.h"
23 
24 namespace ark::compiler {
RunImpl()25 bool LoopUnroll::RunImpl()
26 {
27     COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "Run " << GetPassName();
28     RunLoopsVisitor();
29     COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << GetPassName() << " complete";
30     GetGraph()->SetUnrollComplete();
31     return isApplied_;
32 }
33 
InvalidateAnalyses()34 void LoopUnroll::InvalidateAnalyses()
35 {
36     GetGraph()->InvalidateAnalysis<BoundsAnalysis>();
37     GetGraph()->InvalidateAnalysis<AliasAnalysis>();
38     GetGraph()->InvalidateAnalysis<LoopAnalyzer>();
39     InvalidateBlocksOrderAnalyzes(GetGraph());
40 }
41 
42 template <typename T>
ConditionOverFlowImpl(const CountableLoopInfo & loopInfo,uint32_t unrollFactor)43 bool ConditionOverFlowImpl(const CountableLoopInfo &loopInfo, uint32_t unrollFactor)
44 {
45     auto immValue = (static_cast<uint64_t>(unrollFactor) - 1) * loopInfo.constStep;
46     auto testValue = static_cast<T>(loopInfo.test->CastToConstant()->GetIntValue());
47     auto typeMin = std::numeric_limits<T>::min();
48     auto typeMax = std::numeric_limits<T>::max();
49     if (immValue > static_cast<uint64_t>(typeMax)) {
50         return true;
51     }
52     if (loopInfo.isInc) {
53         // condition will be updated: test_value - imm_value
54         // so if (test_value - imm_value) < type_min, it's overflow
55         return (typeMin + static_cast<T>(immValue)) > testValue;
56     }
57     // condition will be updated: test_value + imm_value
58     // so if (test_value + imm_value) > type_max, it's overflow
59     return (typeMax - static_cast<T>(immValue)) < testValue;
60 }
61 
62 /// NOTE(a.popov) Create pre-header compare if it doesn't exist
63 
ConditionOverFlow(const CountableLoopInfo & loopInfo,uint32_t unrollFactor)64 bool ConditionOverFlow(const CountableLoopInfo &loopInfo, uint32_t unrollFactor)
65 {
66     auto type = loopInfo.index->GetType();
67     ASSERT(DataType::GetCommonType(type) == DataType::INT64);
68     auto updateOpcode = loopInfo.update->GetOpcode();
69     if (updateOpcode == Opcode::AddOverflowCheck || updateOpcode == Opcode::SubOverflowCheck) {
70         return true;
71     }
72     if (!loopInfo.test->IsConst()) {
73         return false;
74     }
75 
76     switch (type) {
77         case DataType::INT32:
78             return ConditionOverFlowImpl<int32_t>(loopInfo, unrollFactor);
79         case DataType::UINT32:
80             return ConditionOverFlowImpl<uint32_t>(loopInfo, unrollFactor);
81         case DataType::INT64:
82             return ConditionOverFlowImpl<int64_t>(loopInfo, unrollFactor);
83         case DataType::UINT64:
84             return ConditionOverFlowImpl<uint64_t>(loopInfo, unrollFactor);
85         default:
86             return true;
87     }
88 }
89 
TransformLoopImpl(Loop * loop,std::optional<uint64_t> optIterations,bool noSideExits,uint32_t unrollFactor,std::optional<CountableLoopInfo> loopInfo)90 void LoopUnroll::TransformLoopImpl(Loop *loop, std::optional<uint64_t> optIterations, bool noSideExits,
91                                    uint32_t unrollFactor, std::optional<CountableLoopInfo> loopInfo)
92 {
93     auto graphCloner = GraphCloner(GetGraph(), GetGraph()->GetAllocator(), GetGraph()->GetLocalAllocator());
94     if (optIterations && noSideExits) {
95         // GCC gives false positive here
96 #if !defined(__clang__)
97 #pragma GCC diagnostic push
98 // CC-OFFNXT(warning_suppression) GCC false positive
99 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
100 #endif
101         auto iterations = *optIterations;
102         ASSERT(unrollFactor != 0);
103         auto remainingIters = iterations % unrollFactor;
104 #if !defined(__clang__)
105 #pragma GCC diagnostic pop
106 #endif
107         Loop *cloneLoop = remainingIters == 0 ? nullptr : graphCloner.CloneLoop(loop);
108         // Unroll loop without side-exits and fix compare in the pre-header and back-edge
109         graphCloner.UnrollLoopBody<UnrollType::UNROLL_WITHOUT_SIDE_EXITS>(loop, unrollFactor);
110         FixCompareInst(loopInfo.value(), loop->GetHeader(), unrollFactor);
111         // Unroll loop without side-exits for remaining iterations
112         if (remainingIters != 0) {
113             graphCloner.UnrollLoopBody<UnrollType::UNROLL_CONSTANT_ITERATIONS>(cloneLoop, remainingIters);
114         }
115         COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
116             << "Unrolled without side-exits the loop with constant number of iterations (" << iterations
117             << "). Loop id = " << loop->GetId();
118     } else if (noSideExits) {
119         auto cloneLoop = graphCloner.CloneLoop(loop);
120         // Unroll loop without side-exits and fix compare in the pre-header and back-edge
121         graphCloner.UnrollLoopBody<UnrollType::UNROLL_WITHOUT_SIDE_EXITS>(loop, unrollFactor);
122         FixCompareInst(loopInfo.value(), loop->GetHeader(), unrollFactor);
123         // Unroll loop with side-exits for remaining iterations
124         graphCloner.UnrollLoopBody<UnrollType::UNROLL_POST_INCREMENT>(cloneLoop, unrollFactor - 1);
125         COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
126             << "Unrolled without side-exits the loop with unroll factor = " << unrollFactor
127             << ". Loop id = " << loop->GetId();
128     } else if (g_options.IsCompilerUnrollWithSideExits()) {
129         graphCloner.UnrollLoopBody<UnrollType::UNROLL_WITH_SIDE_EXITS>(loop, unrollFactor);
130         COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "Unrolled with side-exits the loop with unroll factor = " << unrollFactor
131                                             << ". Loop id = " << loop->GetId();
132     }
133 }
134 
TransformLoop(Loop * loop)135 bool LoopUnroll::TransformLoop(Loop *loop)
136 {
137     if (!loop->GetInnerLoops().empty()) {
138         COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
139             << "Loop isn't unrolled since it contains loops. Loop id = " << loop->GetId();
140         return false;
141     }
142     auto unrollParams = GetUnrollParams(loop);
143     if (!g_options.IsCompilerUnrollLoopWithCalls() && unrollParams.hasCall) {
144         COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
145             << "Loop isn't unrolled since it contains calls. Loop id = " << loop->GetId();
146         return false;
147     }
148 
149     auto graphCloner = GraphCloner(GetGraph(), GetGraph()->GetAllocator(), GetGraph()->GetLocalAllocator());
150     uint32_t unrollFactor = std::min(unrollParams.unrollFactor, unrollFactor_);
151     auto loopParser = CountableLoopParser(*loop);
152     auto loopInfo = loopParser.Parse();
153     std::optional<uint64_t> optIterations {};
154     auto noBranching = false;
155     if (loopInfo.has_value()) {
156         optIterations = CountableLoopParser::GetLoopIterations(*loopInfo);
157         if (optIterations == 0) {
158             optIterations.reset();
159         }
160         if (optIterations.has_value()) {
161             // Increase instruction limit for unroll without branching
162             // <= unroll_factor * 2 because unroll without side exits would create unroll_factor * 2 - 1 copies of loop
163             noBranching = unrollParams.cloneableInsts <= instLimit_ &&
164                           (*optIterations <= unrollFactor * 2U || *optIterations <= unrollFactor_) &&
165                           CountableLoopParser::HasPreHeaderCompare(loop, *loopInfo);
166         }
167     }
168 
169     if (noBranching) {
170         auto iterations = *optIterations;
171         graphCloner.UnrollLoopBody<UnrollType::UNROLL_CONSTANT_ITERATIONS>(loop, iterations);
172         COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
173             << "Unrolled without branching the loop with constant number of iterations (" << iterations
174             << "). Loop id = " << loop->GetId();
175         isApplied_ = true;
176         GetGraph()->GetEventWriter().EventLoopUnroll(loop->GetId(), loop->GetHeader()->GetGuestPc(), iterations,
177                                                      unrollParams.cloneableInsts, "without branching");
178         return true;
179     }
180 
181     return UnrollWithBranching(unrollFactor, loop, loopInfo, optIterations);
182 }
183 
UnrollWithBranching(uint32_t unrollFactor,Loop * loop,std::optional<CountableLoopInfo> loopInfo,std::optional<uint64_t> optIterations)184 bool LoopUnroll::UnrollWithBranching(uint32_t unrollFactor, Loop *loop, std::optional<CountableLoopInfo> loopInfo,
185                                      std::optional<uint64_t> optIterations)
186 {
187     auto unrollParams = GetUnrollParams(loop);
188 
189     if (unrollFactor <= 1U) {
190         COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
191             << "Loop isn't unrolled due to unroll factor = " << unrollFactor << ". Loop id = " << loop->GetId();
192         return false;
193     }
194 
195     auto noSideExits = false;
196     if (loopInfo.has_value()) {
197         noSideExits =
198             !ConditionOverFlow(*loopInfo, unrollFactor) && CountableLoopParser::HasPreHeaderCompare(loop, *loopInfo);
199     }
200 
201     TransformLoopImpl(loop, optIterations, noSideExits, unrollFactor, loopInfo);
202     isApplied_ = true;
203     GetGraph()->GetEventWriter().EventLoopUnroll(loop->GetId(), loop->GetHeader()->GetGuestPc(), unrollFactor,
204                                                  unrollParams.cloneableInsts,
205                                                  noSideExits ? "without side exits" : "with side exits");
206     return true;
207 }
208 
209 /**
210  * @return - unroll parameters:
211  * - maximum value of unroll factor, depends on INST_LIMIT
212  * - number of cloneable instructions
213  */
GetUnrollParams(Loop * loop)214 LoopUnroll::UnrollParams LoopUnroll::GetUnrollParams(Loop *loop)
215 {
216     uint32_t baseInstCount = 0;
217     uint32_t notCloneableCount = 0;
218     bool hasCall = false;
219     for (auto block : loop->GetBlocks()) {
220         for (auto inst : block->AllInsts()) {
221             baseInstCount++;
222             if ((block->IsLoopHeader() && inst->IsPhi()) || inst->GetOpcode() == Opcode::SafePoint) {
223                 notCloneableCount++;
224             }
225             hasCall |= inst->IsCall() && !static_cast<CallInst *>(inst)->IsInlined();
226         }
227     }
228 
229     UnrollParams params = {1, (baseInstCount - notCloneableCount), hasCall};
230     if (baseInstCount >= instLimit_) {
231         return params;
232     }
233     uint32_t canBeClonedCount = instLimit_ - baseInstCount;
234     params.unrollFactor = unrollFactor_;
235     if (params.cloneableInsts > 0) {
236         params.unrollFactor = (canBeClonedCount / params.cloneableInsts) + 1;
237     }
238     return params;
239 }
240 
241 /**
242  * @return - `if_imm`'s compare input when `if_imm` its single user,
243  * otherwise create a new one Compare for this `if_imm` and return it
244  */
GetOrCreateIfImmUniqueCompare(Inst * ifImm)245 Inst *GetOrCreateIfImmUniqueCompare(Inst *ifImm)
246 {
247     ASSERT(ifImm->GetOpcode() == Opcode::IfImm);
248     auto compare = ifImm->GetInput(0).GetInst();
249     ASSERT(compare->GetOpcode() == Opcode::Compare);
250     if (compare->HasSingleUser()) {
251         return compare;
252     }
253     auto newCmp = compare->Clone(compare->GetBasicBlock()->GetGraph());
254     newCmp->SetInput(0, compare->GetInput(0).GetInst());
255     newCmp->SetInput(1, compare->GetInput(1).GetInst());
256     ifImm->InsertBefore(newCmp);
257     ifImm->SetInput(0, newCmp);
258     return newCmp;
259 }
260 
261 /// Normalize control-flow to the form: `if condition is true goto loop_header`
NormalizeControlFlow(BasicBlock * edge,const BasicBlock * loopHeader)262 void NormalizeControlFlow(BasicBlock *edge, const BasicBlock *loopHeader)
263 {
264     auto ifImm = edge->GetLastInst()->CastToIfImm();
265     ASSERT(ifImm->GetImm() == 0);
266     if (ifImm->GetCc() == CC_EQ) {
267         ifImm->SetCc(CC_NE);
268         edge->SwapTrueFalseSuccessors<true>();
269     }
270     auto cmp = ifImm->GetInput(0).GetInst()->CastToCompare();
271     if (!cmp->HasSingleUser()) {
272         auto newCmp = cmp->Clone(edge->GetGraph());
273         ifImm->InsertBefore(newCmp);
274         ifImm->SetInput(0, newCmp);
275         cmp = newCmp->CastToCompare();
276     }
277     if (edge->GetFalseSuccessor() == loopHeader) {
278         auto inversedCc = GetInverseConditionCode(cmp->GetCc());
279         cmp->SetCc(inversedCc);
280         edge->SwapTrueFalseSuccessors<true>();
281     }
282 }
283 
CreateNewTestInst(const CountableLoopInfo & loopInfo,Inst * constInst,Inst * preHeaderCmp)284 Inst *LoopUnroll::CreateNewTestInst(const CountableLoopInfo &loopInfo, Inst *constInst, Inst *preHeaderCmp)
285 {
286     Inst *test = nullptr;
287     if (loopInfo.isInc) {
288         test = GetGraph()->CreateInstSub(preHeaderCmp->CastToCompare()->GetOperandsType(), preHeaderCmp->GetPc(),
289                                          loopInfo.test, constInst);
290 #ifdef PANDA_COMPILER_DEBUG_INFO
291         test->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
292 #endif
293     } else {
294         test = GetGraph()->CreateInstAdd(preHeaderCmp->CastToCompare()->GetOperandsType(), preHeaderCmp->GetPc(),
295                                          loopInfo.test, constInst);
296 #ifdef PANDA_COMPILER_DEBUG_INFO
297         test->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
298 #endif
299     }
300     preHeaderCmp->InsertBefore(test);
301     return test;
302 }
303 
304 /**
305  * Replace `Compare(init, test)` with these instructions:
306  *
307  * Constant(unroll_factor)
308  * Sub/Add(test, Constant)
309  * Compare(init, SubI/AddI)
310  *
311  * And replace condition code if it is `CC_NE`.
312  * We use Constant + Sub/Add because low-level instructions (SubI/AddI) may appear only after Lowering pass.
313  */
FixCompareInst(const CountableLoopInfo & loopInfo,BasicBlock * header,uint32_t unrollFactor)314 void LoopUnroll::FixCompareInst(const CountableLoopInfo &loopInfo, BasicBlock *header, uint32_t unrollFactor)
315 {
316     auto preHeader = header->GetLoop()->GetPreHeader();
317     auto backEdge = loopInfo.ifImm->GetBasicBlock();
318     ASSERT(!preHeader->IsEmpty() && preHeader->GetLastInst()->GetOpcode() == Opcode::IfImm);
319     auto preHeaderIf = preHeader->GetLastInst()->CastToIfImm();
320     auto preHeaderCmp = GetOrCreateIfImmUniqueCompare(preHeaderIf);
321     auto backEdgeCmp = GetOrCreateIfImmUniqueCompare(loopInfo.ifImm);
322     NormalizeControlFlow(preHeader, header);
323     NormalizeControlFlow(backEdge, header);
324     // Create Sub/Add + Const instructions and replace Compare's test inst input
325     auto immValue = (static_cast<uint64_t>(unrollFactor) - 1) * loopInfo.constStep;
326     auto newTest = CreateNewTestInst(loopInfo, GetGraph()->FindOrCreateConstant(immValue), preHeaderCmp);
327     auto testInputIdx = 1;
328     if (backEdgeCmp->GetInput(0) == loopInfo.test) {
329         testInputIdx = 0;
330     } else {
331         ASSERT(backEdgeCmp->GetInput(1) == loopInfo.test);
332     }
333     ASSERT(preHeaderCmp->GetInput(testInputIdx).GetInst() == loopInfo.test);
334     preHeaderCmp->SetInput(testInputIdx, newTest);
335     backEdgeCmp->SetInput(testInputIdx, newTest);
336     // Replace CC_NE ConditionCode
337     if (loopInfo.normalizedCc == CC_NE) {
338         auto cc = loopInfo.isInc ? CC_LT : CC_GT;
339         if (testInputIdx == 0) {
340             cc = SwapOperandsConditionCode(cc);
341         }
342         preHeaderCmp->CastToCompare()->SetCc(cc);
343         backEdgeCmp->CastToCompare()->SetCc(cc);
344     }
345     // for not constant test-instruction we need to insert `overflow-check`:
346     // `test - imm_value` should be less than `test` (incerement loop-index case)
347     // `test + imm_value` should be greater than `test` (decrement loop-index case)
348     // If overflow-check is failed goto after-loop
349     if (!loopInfo.test->IsConst()) {
350         auto cc = loopInfo.isInc ? CC_LT : CC_GT;
351         // Create overflow_compare
352         auto overflowCompare = GetGraph()->CreateInstCompare(compiler::DataType::BOOL, preHeaderCmp->GetPc(), newTest,
353                                                              loopInfo.test, loopInfo.test->GetType(), cc);
354 #ifdef PANDA_COMPILER_DEBUG_INFO
355         overflowCompare->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
356 #endif
357         // Create (pre_header_compare AND overflow_compare) inst
358         auto andInst = GetGraph()->CreateInstAnd(DataType::BOOL, preHeaderCmp->GetPc(), preHeaderCmp, overflowCompare);
359 #ifdef PANDA_COMPILER_DEBUG_INFO
360         andInst->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
361 #endif
362         preHeaderIf->SetInput(0, andInst);
363         preHeaderIf->InsertBefore(andInst);
364         andInst->InsertBefore(overflowCompare);
365     }
366 }
367 }  // namespace ark::compiler
368