1 /*
2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15 #include "compiler_logger.h"
16 #include "optimizer/ir/basicblock.h"
17 #include "optimizer/ir/graph.h"
18 #include "optimizer/ir/graph_cloner.h"
19 #include "optimizer/optimizations/loop_unroll.h"
20 #include "optimizer/analysis/alias_analysis.h"
21 #include "optimizer/analysis/bounds_analysis.h"
22 #include "optimizer/analysis/dominators_tree.h"
23
24 namespace ark::compiler {
RunImpl()25 bool LoopUnroll::RunImpl()
26 {
27 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "Run " << GetPassName();
28 RunLoopsVisitor();
29 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << GetPassName() << " complete";
30 GetGraph()->SetUnrollComplete();
31 return isApplied_;
32 }
33
InvalidateAnalyses()34 void LoopUnroll::InvalidateAnalyses()
35 {
36 GetGraph()->InvalidateAnalysis<BoundsAnalysis>();
37 GetGraph()->InvalidateAnalysis<AliasAnalysis>();
38 GetGraph()->InvalidateAnalysis<LoopAnalyzer>();
39 InvalidateBlocksOrderAnalyzes(GetGraph());
40 }
41
42 template <typename T>
ConditionOverFlowImpl(const CountableLoopInfo & loopInfo,uint32_t unrollFactor)43 bool ConditionOverFlowImpl(const CountableLoopInfo &loopInfo, uint32_t unrollFactor)
44 {
45 auto immValue = (static_cast<uint64_t>(unrollFactor) - 1) * loopInfo.constStep;
46 auto testValue = static_cast<T>(loopInfo.test->CastToConstant()->GetIntValue());
47 auto typeMin = std::numeric_limits<T>::min();
48 auto typeMax = std::numeric_limits<T>::max();
49 if (immValue > static_cast<uint64_t>(typeMax)) {
50 return true;
51 }
52 if (loopInfo.isInc) {
53 // condition will be updated: test_value - imm_value
54 // so if (test_value - imm_value) < type_min, it's overflow
55 return (typeMin + static_cast<T>(immValue)) > testValue;
56 }
57 // condition will be updated: test_value + imm_value
58 // so if (test_value + imm_value) > type_max, it's overflow
59 return (typeMax - static_cast<T>(immValue)) < testValue;
60 }
61
62 /// NOTE(a.popov) Create pre-header compare if it doesn't exist
63
ConditionOverFlow(const CountableLoopInfo & loopInfo,uint32_t unrollFactor)64 bool ConditionOverFlow(const CountableLoopInfo &loopInfo, uint32_t unrollFactor)
65 {
66 auto type = loopInfo.index->GetType();
67 ASSERT(DataType::GetCommonType(type) == DataType::INT64);
68 auto updateOpcode = loopInfo.update->GetOpcode();
69 if (updateOpcode == Opcode::AddOverflowCheck || updateOpcode == Opcode::SubOverflowCheck) {
70 return true;
71 }
72 if (!loopInfo.test->IsConst()) {
73 return false;
74 }
75
76 switch (type) {
77 case DataType::INT32:
78 return ConditionOverFlowImpl<int32_t>(loopInfo, unrollFactor);
79 case DataType::UINT32:
80 return ConditionOverFlowImpl<uint32_t>(loopInfo, unrollFactor);
81 case DataType::INT64:
82 return ConditionOverFlowImpl<int64_t>(loopInfo, unrollFactor);
83 case DataType::UINT64:
84 return ConditionOverFlowImpl<uint64_t>(loopInfo, unrollFactor);
85 default:
86 return true;
87 }
88 }
89
TransformLoopImpl(Loop * loop,std::optional<uint64_t> optIterations,bool noSideExits,uint32_t unrollFactor,std::optional<CountableLoopInfo> loopInfo)90 void LoopUnroll::TransformLoopImpl(Loop *loop, std::optional<uint64_t> optIterations, bool noSideExits,
91 uint32_t unrollFactor, std::optional<CountableLoopInfo> loopInfo)
92 {
93 auto graphCloner = GraphCloner(GetGraph(), GetGraph()->GetAllocator(), GetGraph()->GetLocalAllocator());
94 if (optIterations && noSideExits) {
95 // GCC gives false positive here
96 #if !defined(__clang__)
97 #pragma GCC diagnostic push
98 // CC-OFFNXT(warning_suppression) GCC false positive
99 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
100 #endif
101 auto iterations = *optIterations;
102 ASSERT(unrollFactor != 0);
103 auto remainingIters = iterations % unrollFactor;
104 #if !defined(__clang__)
105 #pragma GCC diagnostic pop
106 #endif
107 Loop *cloneLoop = remainingIters == 0 ? nullptr : graphCloner.CloneLoop(loop);
108 // Unroll loop without side-exits and fix compare in the pre-header and back-edge
109 graphCloner.UnrollLoopBody<UnrollType::UNROLL_WITHOUT_SIDE_EXITS>(loop, unrollFactor);
110 FixCompareInst(loopInfo.value(), loop->GetHeader(), unrollFactor);
111 // Unroll loop without side-exits for remaining iterations
112 if (remainingIters != 0) {
113 graphCloner.UnrollLoopBody<UnrollType::UNROLL_CONSTANT_ITERATIONS>(cloneLoop, remainingIters);
114 }
115 COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
116 << "Unrolled without side-exits the loop with constant number of iterations (" << iterations
117 << "). Loop id = " << loop->GetId();
118 } else if (noSideExits) {
119 auto cloneLoop = graphCloner.CloneLoop(loop);
120 // Unroll loop without side-exits and fix compare in the pre-header and back-edge
121 graphCloner.UnrollLoopBody<UnrollType::UNROLL_WITHOUT_SIDE_EXITS>(loop, unrollFactor);
122 FixCompareInst(loopInfo.value(), loop->GetHeader(), unrollFactor);
123 // Unroll loop with side-exits for remaining iterations
124 graphCloner.UnrollLoopBody<UnrollType::UNROLL_POST_INCREMENT>(cloneLoop, unrollFactor - 1);
125 COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
126 << "Unrolled without side-exits the loop with unroll factor = " << unrollFactor
127 << ". Loop id = " << loop->GetId();
128 } else if (g_options.IsCompilerUnrollWithSideExits()) {
129 graphCloner.UnrollLoopBody<UnrollType::UNROLL_WITH_SIDE_EXITS>(loop, unrollFactor);
130 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "Unrolled with side-exits the loop with unroll factor = " << unrollFactor
131 << ". Loop id = " << loop->GetId();
132 }
133 }
134
TransformLoop(Loop * loop)135 bool LoopUnroll::TransformLoop(Loop *loop)
136 {
137 if (!loop->GetInnerLoops().empty()) {
138 COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
139 << "Loop isn't unrolled since it contains loops. Loop id = " << loop->GetId();
140 return false;
141 }
142 auto unrollParams = GetUnrollParams(loop);
143 if (!g_options.IsCompilerUnrollLoopWithCalls() && unrollParams.hasCall) {
144 COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
145 << "Loop isn't unrolled since it contains calls. Loop id = " << loop->GetId();
146 return false;
147 }
148
149 auto graphCloner = GraphCloner(GetGraph(), GetGraph()->GetAllocator(), GetGraph()->GetLocalAllocator());
150 uint32_t unrollFactor = std::min(unrollParams.unrollFactor, unrollFactor_);
151 auto loopParser = CountableLoopParser(*loop);
152 auto loopInfo = loopParser.Parse();
153 std::optional<uint64_t> optIterations {};
154 auto noBranching = false;
155 if (loopInfo.has_value()) {
156 optIterations = CountableLoopParser::GetLoopIterations(*loopInfo);
157 if (optIterations == 0) {
158 optIterations.reset();
159 }
160 if (optIterations.has_value()) {
161 // Increase instruction limit for unroll without branching
162 // <= unroll_factor * 2 because unroll without side exits would create unroll_factor * 2 - 1 copies of loop
163 noBranching = unrollParams.cloneableInsts <= instLimit_ &&
164 (*optIterations <= unrollFactor * 2U || *optIterations <= unrollFactor_) &&
165 CountableLoopParser::HasPreHeaderCompare(loop, *loopInfo);
166 }
167 }
168
169 if (noBranching) {
170 auto iterations = *optIterations;
171 graphCloner.UnrollLoopBody<UnrollType::UNROLL_CONSTANT_ITERATIONS>(loop, iterations);
172 COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
173 << "Unrolled without branching the loop with constant number of iterations (" << iterations
174 << "). Loop id = " << loop->GetId();
175 isApplied_ = true;
176 GetGraph()->GetEventWriter().EventLoopUnroll(loop->GetId(), loop->GetHeader()->GetGuestPc(), iterations,
177 unrollParams.cloneableInsts, "without branching");
178 return true;
179 }
180
181 return UnrollWithBranching(unrollFactor, loop, loopInfo, optIterations);
182 }
183
UnrollWithBranching(uint32_t unrollFactor,Loop * loop,std::optional<CountableLoopInfo> loopInfo,std::optional<uint64_t> optIterations)184 bool LoopUnroll::UnrollWithBranching(uint32_t unrollFactor, Loop *loop, std::optional<CountableLoopInfo> loopInfo,
185 std::optional<uint64_t> optIterations)
186 {
187 auto unrollParams = GetUnrollParams(loop);
188
189 if (unrollFactor <= 1U) {
190 COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
191 << "Loop isn't unrolled due to unroll factor = " << unrollFactor << ". Loop id = " << loop->GetId();
192 return false;
193 }
194
195 auto noSideExits = false;
196 if (loopInfo.has_value()) {
197 noSideExits =
198 !ConditionOverFlow(*loopInfo, unrollFactor) && CountableLoopParser::HasPreHeaderCompare(loop, *loopInfo);
199 }
200
201 TransformLoopImpl(loop, optIterations, noSideExits, unrollFactor, loopInfo);
202 isApplied_ = true;
203 GetGraph()->GetEventWriter().EventLoopUnroll(loop->GetId(), loop->GetHeader()->GetGuestPc(), unrollFactor,
204 unrollParams.cloneableInsts,
205 noSideExits ? "without side exits" : "with side exits");
206 return true;
207 }
208
209 /**
210 * @return - unroll parameters:
211 * - maximum value of unroll factor, depends on INST_LIMIT
212 * - number of cloneable instructions
213 */
GetUnrollParams(Loop * loop)214 LoopUnroll::UnrollParams LoopUnroll::GetUnrollParams(Loop *loop)
215 {
216 uint32_t baseInstCount = 0;
217 uint32_t notCloneableCount = 0;
218 bool hasCall = false;
219 for (auto block : loop->GetBlocks()) {
220 for (auto inst : block->AllInsts()) {
221 baseInstCount++;
222 if ((block->IsLoopHeader() && inst->IsPhi()) || inst->GetOpcode() == Opcode::SafePoint) {
223 notCloneableCount++;
224 }
225 hasCall |= inst->IsCall() && !static_cast<CallInst *>(inst)->IsInlined();
226 }
227 }
228
229 UnrollParams params = {1, (baseInstCount - notCloneableCount), hasCall};
230 if (baseInstCount >= instLimit_) {
231 return params;
232 }
233 uint32_t canBeClonedCount = instLimit_ - baseInstCount;
234 params.unrollFactor = unrollFactor_;
235 if (params.cloneableInsts > 0) {
236 params.unrollFactor = (canBeClonedCount / params.cloneableInsts) + 1;
237 }
238 return params;
239 }
240
241 /**
242 * @return - `if_imm`'s compare input when `if_imm` its single user,
243 * otherwise create a new one Compare for this `if_imm` and return it
244 */
GetOrCreateIfImmUniqueCompare(Inst * ifImm)245 Inst *GetOrCreateIfImmUniqueCompare(Inst *ifImm)
246 {
247 ASSERT(ifImm->GetOpcode() == Opcode::IfImm);
248 auto compare = ifImm->GetInput(0).GetInst();
249 ASSERT(compare->GetOpcode() == Opcode::Compare);
250 if (compare->HasSingleUser()) {
251 return compare;
252 }
253 auto newCmp = compare->Clone(compare->GetBasicBlock()->GetGraph());
254 newCmp->SetInput(0, compare->GetInput(0).GetInst());
255 newCmp->SetInput(1, compare->GetInput(1).GetInst());
256 ifImm->InsertBefore(newCmp);
257 ifImm->SetInput(0, newCmp);
258 return newCmp;
259 }
260
261 /// Normalize control-flow to the form: `if condition is true goto loop_header`
NormalizeControlFlow(BasicBlock * edge,const BasicBlock * loopHeader)262 void NormalizeControlFlow(BasicBlock *edge, const BasicBlock *loopHeader)
263 {
264 auto ifImm = edge->GetLastInst()->CastToIfImm();
265 ASSERT(ifImm->GetImm() == 0);
266 if (ifImm->GetCc() == CC_EQ) {
267 ifImm->SetCc(CC_NE);
268 edge->SwapTrueFalseSuccessors<true>();
269 }
270 auto cmp = ifImm->GetInput(0).GetInst()->CastToCompare();
271 if (!cmp->HasSingleUser()) {
272 auto newCmp = cmp->Clone(edge->GetGraph());
273 ifImm->InsertBefore(newCmp);
274 ifImm->SetInput(0, newCmp);
275 cmp = newCmp->CastToCompare();
276 }
277 if (edge->GetFalseSuccessor() == loopHeader) {
278 auto inversedCc = GetInverseConditionCode(cmp->GetCc());
279 cmp->SetCc(inversedCc);
280 edge->SwapTrueFalseSuccessors<true>();
281 }
282 }
283
CreateNewTestInst(const CountableLoopInfo & loopInfo,Inst * constInst,Inst * preHeaderCmp)284 Inst *LoopUnroll::CreateNewTestInst(const CountableLoopInfo &loopInfo, Inst *constInst, Inst *preHeaderCmp)
285 {
286 Inst *test = nullptr;
287 if (loopInfo.isInc) {
288 test = GetGraph()->CreateInstSub(preHeaderCmp->CastToCompare()->GetOperandsType(), preHeaderCmp->GetPc(),
289 loopInfo.test, constInst);
290 #ifdef PANDA_COMPILER_DEBUG_INFO
291 test->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
292 #endif
293 } else {
294 test = GetGraph()->CreateInstAdd(preHeaderCmp->CastToCompare()->GetOperandsType(), preHeaderCmp->GetPc(),
295 loopInfo.test, constInst);
296 #ifdef PANDA_COMPILER_DEBUG_INFO
297 test->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
298 #endif
299 }
300 preHeaderCmp->InsertBefore(test);
301 return test;
302 }
303
/**
 * Rewrite the exit compares of the pre-header and of the back-edge after the loop body has been unrolled
 * `unrollFactor` times without side exits. The loop bound `test` is shifted by `(unroll_factor - 1) * step`.
 *
 * Replace `Compare(init, test)` with these instructions:
 *
 * Constant((unroll_factor - 1) * step)
 * Sub/Add(test, Constant)   // Sub for an increasing index, Add for a decreasing one
 * Compare(init, Sub/Add)
 *
 * And replace condition code if it is `CC_NE` (with CC_LT / CC_GT, swapped when `test` is the left operand).
 * We use Constant + Sub/Add because low-level instructions (SubI/AddI) may appear only after Lowering pass.
 * For a non-constant `test`, the pre-header condition is additionally ANDed with a run-time overflow check.
 */
void LoopUnroll::FixCompareInst(const CountableLoopInfo &loopInfo, BasicBlock *header, uint32_t unrollFactor)
{
    auto preHeader = header->GetLoop()->GetPreHeader();
    auto backEdge = loopInfo.ifImm->GetBasicBlock();
    ASSERT(!preHeader->IsEmpty() && preHeader->GetLastInst()->GetOpcode() == Opcode::IfImm);
    auto preHeaderIf = preHeader->GetLastInst()->CastToIfImm();
    // Make both compares private to their IfImm, so they can be modified in place below
    auto preHeaderCmp = GetOrCreateIfImmUniqueCompare(preHeaderIf);
    auto backEdgeCmp = GetOrCreateIfImmUniqueCompare(loopInfo.ifImm);
    // Bring both branches to the form `if condition is true goto loop header`
    NormalizeControlFlow(preHeader, header);
    NormalizeControlFlow(backEdge, header);
    // Create Sub/Add + Const instructions and replace Compare's test inst input
    auto immValue = (static_cast<uint64_t>(unrollFactor) - 1) * loopInfo.constStep;
    auto newTest = CreateNewTestInst(loopInfo, GetGraph()->FindOrCreateConstant(immValue), preHeaderCmp);
    // Find which operand of the compares is `test`; both compares are expected to use the same position
    auto testInputIdx = 1;
    if (backEdgeCmp->GetInput(0) == loopInfo.test) {
        testInputIdx = 0;
    } else {
        ASSERT(backEdgeCmp->GetInput(1) == loopInfo.test);
    }
    ASSERT(preHeaderCmp->GetInput(testInputIdx).GetInst() == loopInfo.test);
    preHeaderCmp->SetInput(testInputIdx, newTest);
    backEdgeCmp->SetInput(testInputIdx, newTest);
    // Replace CC_NE ConditionCode with an ordered one: `<` for an increasing index, `>` for a decreasing one,
    // with operands swapped when `test` is the left operand
    if (loopInfo.normalizedCc == CC_NE) {
        auto cc = loopInfo.isInc ? CC_LT : CC_GT;
        if (testInputIdx == 0) {
            cc = SwapOperandsConditionCode(cc);
        }
        preHeaderCmp->CastToCompare()->SetCc(cc);
        backEdgeCmp->CastToCompare()->SetCc(cc);
    }
    // for not constant test-instruction we need to insert `overflow-check`:
    // `test - imm_value` should be less than `test` (increment loop-index case)
    // `test + imm_value` should be greater than `test` (decrement loop-index case)
    // If overflow-check is failed goto after-loop
    if (!loopInfo.test->IsConst()) {
        auto cc = loopInfo.isInc ? CC_LT : CC_GT;
        // Create overflow_compare
        auto overflowCompare = GetGraph()->CreateInstCompare(compiler::DataType::BOOL, preHeaderCmp->GetPc(), newTest,
                                                             loopInfo.test, loopInfo.test->GetType(), cc);
#ifdef PANDA_COMPILER_DEBUG_INFO
        overflowCompare->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
#endif
        // Create (pre_header_compare AND overflow_compare) inst
        auto andInst = GetGraph()->CreateInstAnd(DataType::BOOL, preHeaderCmp->GetPc(), preHeaderCmp, overflowCompare);
#ifdef PANDA_COMPILER_DEBUG_INFO
        andInst->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
#endif
        // The pre-header IfImm now tests the AND; the resulting order right before it is: overflow compare, AND
        preHeaderIf->SetInput(0, andInst);
        preHeaderIf->InsertBefore(andInst);
        andInst->InsertBefore(overflowCompare);
    }
}
367 } // namespace ark::compiler
368