• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #include "compiler_logger.h"
16 #include "optimizer/ir/basicblock.h"
17 #include "optimizer/ir/graph.h"
18 #include "optimizer/ir/graph_cloner.h"
19 #include "optimizer/optimizations/loop_unroll.h"
20 #include "optimizer/analysis/alias_analysis.h"
21 #include "optimizer/analysis/bounds_analysis.h"
22 #include "optimizer/analysis/dominators_tree.h"
23 
24 namespace ark::compiler {
RunImpl()25 bool LoopUnroll::RunImpl()
26 {
27     COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "Run " << GetPassName();
28     RunLoopsVisitor();
29     COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << GetPassName() << " complete";
30     GetGraph()->SetUnrollComplete();
31     return isApplied_;
32 }
33 
InvalidateAnalyses()34 void LoopUnroll::InvalidateAnalyses()
35 {
36     GetGraph()->InvalidateAnalysis<BoundsAnalysis>();
37     GetGraph()->InvalidateAnalysis<AliasAnalysis>();
38     GetGraph()->InvalidateAnalysis<LoopAnalyzer>();
39     InvalidateBlocksOrderAnalyzes(GetGraph());
40 }
41 
42 template <typename T>
ConditionOverFlowImpl(const CountableLoopInfo & loopInfo,uint32_t unrollFactor)43 bool ConditionOverFlowImpl(const CountableLoopInfo &loopInfo, uint32_t unrollFactor)
44 {
45     auto immValue = (static_cast<uint64_t>(unrollFactor) - 1) * loopInfo.constStep;
46     auto testValue = static_cast<T>(loopInfo.test->CastToConstant()->GetIntValue());
47     auto typeMin = std::numeric_limits<T>::min();
48     auto typeMax = std::numeric_limits<T>::max();
49     if (immValue > static_cast<uint64_t>(typeMax)) {
50         return true;
51     }
52     if (loopInfo.isInc) {
53         // condition will be updated: test_value - imm_value
54         // so if (test_value - imm_value) < type_min, it's overflow
55         return (typeMin + static_cast<T>(immValue)) > testValue;
56     }
57     // condition will be updated: test_value + imm_value
58     // so if (test_value + imm_value) > type_max, it's overflow
59     return (typeMax - static_cast<T>(immValue)) < testValue;
60 }
61 
62 /// NOTE(a.popov) Create pre-header compare if it doesn't exist
63 
ConditionOverFlow(const CountableLoopInfo & loopInfo,uint32_t unrollFactor)64 bool ConditionOverFlow(const CountableLoopInfo &loopInfo, uint32_t unrollFactor)
65 {
66     auto type = loopInfo.index->GetType();
67     ASSERT(DataType::GetCommonType(type) == DataType::INT64);
68     auto updateOpcode = loopInfo.update->GetOpcode();
69     if (updateOpcode == Opcode::AddOverflowCheck || updateOpcode == Opcode::SubOverflowCheck) {
70         return true;
71     }
72     if (!loopInfo.test->IsConst()) {
73         return false;
74     }
75 
76     switch (type) {
77         case DataType::INT32:
78             return ConditionOverFlowImpl<int32_t>(loopInfo, unrollFactor);
79         case DataType::UINT32:
80             return ConditionOverFlowImpl<uint32_t>(loopInfo, unrollFactor);
81         case DataType::INT64:
82             return ConditionOverFlowImpl<int64_t>(loopInfo, unrollFactor);
83         case DataType::UINT64:
84             return ConditionOverFlowImpl<uint64_t>(loopInfo, unrollFactor);
85         default:
86             return true;
87     }
88 }
89 
TransformLoopImpl(Loop * loop,std::optional<uint64_t> optIterations,bool noSideExits,uint32_t unrollFactor,std::optional<CountableLoopInfo> loopInfo)90 void LoopUnroll::TransformLoopImpl(Loop *loop, std::optional<uint64_t> optIterations, bool noSideExits,
91                                    uint32_t unrollFactor, std::optional<CountableLoopInfo> loopInfo)
92 {
93     auto graphCloner = GraphCloner(GetGraph(), GetGraph()->GetAllocator(), GetGraph()->GetLocalAllocator());
94     if (optIterations && noSideExits) {
95         // GCC gives false positive here
96 #if !defined(__clang__)
97 #pragma GCC diagnostic push
98 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
99 #endif
100         auto iterations = *optIterations;
101         ASSERT(unrollFactor != 0);
102         auto remainingIters = iterations % unrollFactor;
103 #if !defined(__clang__)
104 #pragma GCC diagnostic pop
105 #endif
106         Loop *cloneLoop = remainingIters == 0 ? nullptr : graphCloner.CloneLoop(loop);
107         // Unroll loop without side-exits and fix compare in the pre-header and back-edge
108         graphCloner.UnrollLoopBody<UnrollType::UNROLL_WITHOUT_SIDE_EXITS>(loop, unrollFactor);
109         FixCompareInst(loopInfo.value(), loop->GetHeader(), unrollFactor);
110         // Unroll loop without side-exits for remaining iterations
111         if (remainingIters != 0) {
112             graphCloner.UnrollLoopBody<UnrollType::UNROLL_CONSTANT_ITERATIONS>(cloneLoop, remainingIters);
113         }
114         COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
115             << "Unrolled without side-exits the loop with constant number of iterations (" << iterations
116             << "). Loop id = " << loop->GetId();
117     } else if (noSideExits) {
118         auto cloneLoop = graphCloner.CloneLoop(loop);
119         // Unroll loop without side-exits and fix compare in the pre-header and back-edge
120         graphCloner.UnrollLoopBody<UnrollType::UNROLL_WITHOUT_SIDE_EXITS>(loop, unrollFactor);
121         FixCompareInst(loopInfo.value(), loop->GetHeader(), unrollFactor);
122         // Unroll loop with side-exits for remaining iterations
123         graphCloner.UnrollLoopBody<UnrollType::UNROLL_POST_INCREMENT>(cloneLoop, unrollFactor - 1);
124         COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
125             << "Unrolled without side-exits the loop with unroll factor = " << unrollFactor
126             << ". Loop id = " << loop->GetId();
127     } else if (g_options.IsCompilerUnrollWithSideExits()) {
128         graphCloner.UnrollLoopBody<UnrollType::UNROLL_WITH_SIDE_EXITS>(loop, unrollFactor);
129         COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "Unrolled with side-exits the loop with unroll factor = " << unrollFactor
130                                             << ". Loop id = " << loop->GetId();
131     }
132 }
133 
TransformLoop(Loop * loop)134 bool LoopUnroll::TransformLoop(Loop *loop)
135 {
136     auto unrollParams = GetUnrollParams(loop);
137     if (!g_options.IsCompilerUnrollLoopWithCalls() && unrollParams.hasCall) {
138         COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
139             << "Loop isn't unrolled since it contains calls. Loop id = " << loop->GetId();
140         return false;
141     }
142 
143     auto graphCloner = GraphCloner(GetGraph(), GetGraph()->GetAllocator(), GetGraph()->GetLocalAllocator());
144     uint32_t unrollFactor = std::min(unrollParams.unrollFactor, unrollFactor_);
145     auto loopParser = CountableLoopParser(*loop);
146     auto loopInfo = loopParser.Parse();
147     std::optional<uint64_t> optIterations {};
148     auto noBranching = false;
149     if (loopInfo.has_value()) {
150         optIterations = CountableLoopParser::GetLoopIterations(*loopInfo);
151         if (optIterations == 0) {
152             optIterations.reset();
153         }
154         if (optIterations.has_value()) {
155             // Increase instruction limit for unroll without branching
156             // <= unroll_factor * 2 because unroll without side exits would create unroll_factor * 2 - 1 copies of loop
157             noBranching = unrollParams.cloneableInsts <= instLimit_ &&
158                           (*optIterations <= unrollFactor * 2U || *optIterations <= unrollFactor_) &&
159                           CountableLoopParser::HasPreHeaderCompare(loop, *loopInfo);
160         }
161     }
162 
163     if (noBranching) {
164         auto iterations = *optIterations;
165         graphCloner.UnrollLoopBody<UnrollType::UNROLL_CONSTANT_ITERATIONS>(loop, iterations);
166         COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
167             << "Unrolled without branching the loop with constant number of iterations (" << iterations
168             << "). Loop id = " << loop->GetId();
169         isApplied_ = true;
170         GetGraph()->GetEventWriter().EventLoopUnroll(loop->GetId(), loop->GetHeader()->GetGuestPc(), iterations,
171                                                      unrollParams.cloneableInsts, "without branching");
172         return true;
173     }
174 
175     return UnrollWithBranching(unrollFactor, loop, loopInfo, optIterations);
176 }
177 
UnrollWithBranching(uint32_t unrollFactor,Loop * loop,std::optional<CountableLoopInfo> loopInfo,std::optional<uint64_t> optIterations)178 bool LoopUnroll::UnrollWithBranching(uint32_t unrollFactor, Loop *loop, std::optional<CountableLoopInfo> loopInfo,
179                                      std::optional<uint64_t> optIterations)
180 {
181     auto unrollParams = GetUnrollParams(loop);
182 
183     if (unrollFactor <= 1U) {
184         COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
185             << "Loop isn't unrolled due to unroll factor = " << unrollFactor << ". Loop id = " << loop->GetId();
186         return false;
187     }
188 
189     auto noSideExits = false;
190     if (loopInfo.has_value()) {
191         noSideExits =
192             !ConditionOverFlow(*loopInfo, unrollFactor) && CountableLoopParser::HasPreHeaderCompare(loop, *loopInfo);
193     }
194 
195     TransformLoopImpl(loop, optIterations, noSideExits, unrollFactor, loopInfo);
196     isApplied_ = true;
197     GetGraph()->GetEventWriter().EventLoopUnroll(loop->GetId(), loop->GetHeader()->GetGuestPc(), unrollFactor,
198                                                  unrollParams.cloneableInsts,
199                                                  noSideExits ? "without side exits" : "with side exits");
200     return true;
201 }
202 
203 /**
204  * @return - unroll parameters:
205  * - maximum value of unroll factor, depends on INST_LIMIT
206  * - number of cloneable instructions
207  */
GetUnrollParams(Loop * loop)208 LoopUnroll::UnrollParams LoopUnroll::GetUnrollParams(Loop *loop)
209 {
210     uint32_t baseInstCount = 0;
211     uint32_t notCloneableCount = 0;
212     bool hasCall = false;
213     for (auto block : loop->GetBlocks()) {
214         for (auto inst : block->AllInsts()) {
215             baseInstCount++;
216             if ((block->IsLoopHeader() && inst->IsPhi()) || inst->GetOpcode() == Opcode::SafePoint) {
217                 notCloneableCount++;
218             }
219             hasCall |= inst->IsCall() && !static_cast<CallInst *>(inst)->IsInlined();
220         }
221     }
222 
223     UnrollParams params = {1, (baseInstCount - notCloneableCount), hasCall};
224     if (baseInstCount >= instLimit_) {
225         return params;
226     }
227     uint32_t canBeClonedCount = instLimit_ - baseInstCount;
228     params.unrollFactor = unrollFactor_;
229     if (params.cloneableInsts > 0) {
230         params.unrollFactor = (canBeClonedCount / params.cloneableInsts) + 1;
231     }
232     return params;
233 }
234 
235 /**
236  * @return - `if_imm`'s compare input when `if_imm` its single user,
237  * otherwise create a new one Compare for this `if_imm` and return it
238  */
GetOrCreateIfImmUniqueCompare(Inst * ifImm)239 Inst *GetOrCreateIfImmUniqueCompare(Inst *ifImm)
240 {
241     ASSERT(ifImm->GetOpcode() == Opcode::IfImm);
242     auto compare = ifImm->GetInput(0).GetInst();
243     ASSERT(compare->GetOpcode() == Opcode::Compare);
244     if (compare->HasSingleUser()) {
245         return compare;
246     }
247     auto newCmp = compare->Clone(compare->GetBasicBlock()->GetGraph());
248     newCmp->SetInput(0, compare->GetInput(0).GetInst());
249     newCmp->SetInput(1, compare->GetInput(1).GetInst());
250     ifImm->InsertBefore(newCmp);
251     ifImm->SetInput(0, newCmp);
252     return newCmp;
253 }
254 
255 /// Normalize control-flow to the form: `if condition is true goto loop_header`
NormalizeControlFlow(BasicBlock * edge,const BasicBlock * loopHeader)256 void NormalizeControlFlow(BasicBlock *edge, const BasicBlock *loopHeader)
257 {
258     auto ifImm = edge->GetLastInst()->CastToIfImm();
259     ASSERT(ifImm->GetImm() == 0);
260     if (ifImm->GetCc() == CC_EQ) {
261         ifImm->SetCc(CC_NE);
262         edge->SwapTrueFalseSuccessors<true>();
263     }
264     auto cmp = ifImm->GetInput(0).GetInst()->CastToCompare();
265     if (!cmp->HasSingleUser()) {
266         auto newCmp = cmp->Clone(edge->GetGraph());
267         ifImm->InsertBefore(newCmp);
268         ifImm->SetInput(0, newCmp);
269         cmp = newCmp->CastToCompare();
270     }
271     if (edge->GetFalseSuccessor() == loopHeader) {
272         auto inversedCc = GetInverseConditionCode(cmp->GetCc());
273         cmp->SetCc(inversedCc);
274         edge->SwapTrueFalseSuccessors<true>();
275     }
276 }
277 
CreateNewTestInst(const CountableLoopInfo & loopInfo,Inst * constInst,Inst * preHeaderCmp)278 Inst *LoopUnroll::CreateNewTestInst(const CountableLoopInfo &loopInfo, Inst *constInst, Inst *preHeaderCmp)
279 {
280     Inst *test = nullptr;
281     if (loopInfo.isInc) {
282         test = GetGraph()->CreateInstSub(preHeaderCmp->CastToCompare()->GetOperandsType(), preHeaderCmp->GetPc(),
283                                          loopInfo.test, constInst);
284 #ifdef PANDA_COMPILER_DEBUG_INFO
285         test->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
286 #endif
287     } else {
288         test = GetGraph()->CreateInstAdd(preHeaderCmp->CastToCompare()->GetOperandsType(), preHeaderCmp->GetPc(),
289                                          loopInfo.test, constInst);
290 #ifdef PANDA_COMPILER_DEBUG_INFO
291         test->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
292 #endif
293     }
294     preHeaderCmp->InsertBefore(test);
295     return test;
296 }
297 
298 /**
299  * Replace `Compare(init, test)` with these instructions:
300  *
301  * Constant(unroll_factor)
302  * Sub/Add(test, Constant)
303  * Compare(init, SubI/AddI)
304  *
305  * And replace condition code if it is `CC_NE`.
306  * We use Constant + Sub/Add because low-level instructions (SubI/AddI) may appear only after Lowering pass.
307  */
FixCompareInst(const CountableLoopInfo & loopInfo,BasicBlock * header,uint32_t unrollFactor)308 void LoopUnroll::FixCompareInst(const CountableLoopInfo &loopInfo, BasicBlock *header, uint32_t unrollFactor)
309 {
310     auto preHeader = header->GetLoop()->GetPreHeader();
311     auto backEdge = loopInfo.ifImm->GetBasicBlock();
312     ASSERT(!preHeader->IsEmpty() && preHeader->GetLastInst()->GetOpcode() == Opcode::IfImm);
313     auto preHeaderIf = preHeader->GetLastInst()->CastToIfImm();
314     auto preHeaderCmp = GetOrCreateIfImmUniqueCompare(preHeaderIf);
315     auto backEdgeCmp = GetOrCreateIfImmUniqueCompare(loopInfo.ifImm);
316     NormalizeControlFlow(preHeader, header);
317     NormalizeControlFlow(backEdge, header);
318     // Create Sub/Add + Const instructions and replace Compare's test inst input
319     auto immValue = (static_cast<uint64_t>(unrollFactor) - 1) * loopInfo.constStep;
320     auto newTest = CreateNewTestInst(loopInfo, GetGraph()->FindOrCreateConstant(immValue), preHeaderCmp);
321     auto testInputIdx = 1;
322     if (backEdgeCmp->GetInput(0) == loopInfo.test) {
323         testInputIdx = 0;
324     } else {
325         ASSERT(backEdgeCmp->GetInput(1) == loopInfo.test);
326     }
327     ASSERT(preHeaderCmp->GetInput(testInputIdx).GetInst() == loopInfo.test);
328     preHeaderCmp->SetInput(testInputIdx, newTest);
329     backEdgeCmp->SetInput(testInputIdx, newTest);
330     // Replace CC_NE ConditionCode
331     if (loopInfo.normalizedCc == CC_NE) {
332         auto cc = loopInfo.isInc ? CC_LT : CC_GT;
333         if (testInputIdx == 0) {
334             cc = SwapOperandsConditionCode(cc);
335         }
336         preHeaderCmp->CastToCompare()->SetCc(cc);
337         backEdgeCmp->CastToCompare()->SetCc(cc);
338     }
339     // for not constant test-instruction we need to insert `overflow-check`:
340     // `test - imm_value` should be less than `test` (incerement loop-index case)
341     // `test + imm_value` should be greater than `test` (decrement loop-index case)
342     // If overflow-check is failed goto after-loop
343     if (!loopInfo.test->IsConst()) {
344         auto cc = loopInfo.isInc ? CC_LT : CC_GT;
345         // Create overflow_compare
346         auto overflowCompare = GetGraph()->CreateInstCompare(compiler::DataType::BOOL, preHeaderCmp->GetPc(), newTest,
347                                                              loopInfo.test, loopInfo.test->GetType(), cc);
348 #ifdef PANDA_COMPILER_DEBUG_INFO
349         overflowCompare->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
350 #endif
351         // Create (pre_header_compare AND overflow_compare) inst
352         auto andInst = GetGraph()->CreateInstAnd(DataType::BOOL, preHeaderCmp->GetPc(), preHeaderCmp, overflowCompare);
353 #ifdef PANDA_COMPILER_DEBUG_INFO
354         andInst->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
355 #endif
356         preHeaderIf->SetInput(0, andInst);
357         preHeaderIf->InsertBefore(andInst);
358         andInst->InsertBefore(overflowCompare);
359     }
360 }
361 }  // namespace ark::compiler
362