1 /*
2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15 #include "compiler_logger.h"
16 #include "optimizer/ir/basicblock.h"
17 #include "optimizer/ir/graph.h"
18 #include "optimizer/ir/graph_cloner.h"
19 #include "optimizer/optimizations/loop_unroll.h"
20 #include "optimizer/analysis/alias_analysis.h"
21 #include "optimizer/analysis/bounds_analysis.h"
22 #include "optimizer/analysis/dominators_tree.h"
23
24 namespace ark::compiler {
RunImpl()25 bool LoopUnroll::RunImpl()
26 {
27 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "Run " << GetPassName();
28 RunLoopsVisitor();
29 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << GetPassName() << " complete";
30 GetGraph()->SetUnrollComplete();
31 return isApplied_;
32 }
33
InvalidateAnalyses()34 void LoopUnroll::InvalidateAnalyses()
35 {
36 GetGraph()->InvalidateAnalysis<BoundsAnalysis>();
37 GetGraph()->InvalidateAnalysis<AliasAnalysis>();
38 GetGraph()->InvalidateAnalysis<LoopAnalyzer>();
39 InvalidateBlocksOrderAnalyzes(GetGraph());
40 }
41
42 template <typename T>
ConditionOverFlowImpl(const CountableLoopInfo & loopInfo,uint32_t unrollFactor)43 bool ConditionOverFlowImpl(const CountableLoopInfo &loopInfo, uint32_t unrollFactor)
44 {
45 auto immValue = (static_cast<uint64_t>(unrollFactor) - 1) * loopInfo.constStep;
46 auto testValue = static_cast<T>(loopInfo.test->CastToConstant()->GetIntValue());
47 auto typeMin = std::numeric_limits<T>::min();
48 auto typeMax = std::numeric_limits<T>::max();
49 if (immValue > static_cast<uint64_t>(typeMax)) {
50 return true;
51 }
52 if (loopInfo.isInc) {
53 // condition will be updated: test_value - imm_value
54 // so if (test_value - imm_value) < type_min, it's overflow
55 return (typeMin + static_cast<T>(immValue)) > testValue;
56 }
57 // condition will be updated: test_value + imm_value
58 // so if (test_value + imm_value) > type_max, it's overflow
59 return (typeMax - static_cast<T>(immValue)) < testValue;
60 }
61
62 /// NOTE(a.popov) Create pre-header compare if it doesn't exist
63
ConditionOverFlow(const CountableLoopInfo & loopInfo,uint32_t unrollFactor)64 bool ConditionOverFlow(const CountableLoopInfo &loopInfo, uint32_t unrollFactor)
65 {
66 auto type = loopInfo.index->GetType();
67 ASSERT(DataType::GetCommonType(type) == DataType::INT64);
68 auto updateOpcode = loopInfo.update->GetOpcode();
69 if (updateOpcode == Opcode::AddOverflowCheck || updateOpcode == Opcode::SubOverflowCheck) {
70 return true;
71 }
72 if (!loopInfo.test->IsConst()) {
73 return false;
74 }
75
76 switch (type) {
77 case DataType::INT32:
78 return ConditionOverFlowImpl<int32_t>(loopInfo, unrollFactor);
79 case DataType::UINT32:
80 return ConditionOverFlowImpl<uint32_t>(loopInfo, unrollFactor);
81 case DataType::INT64:
82 return ConditionOverFlowImpl<int64_t>(loopInfo, unrollFactor);
83 case DataType::UINT64:
84 return ConditionOverFlowImpl<uint64_t>(loopInfo, unrollFactor);
85 default:
86 return true;
87 }
88 }
89
TransformLoopImpl(Loop * loop,std::optional<uint64_t> optIterations,bool noSideExits,uint32_t unrollFactor,std::optional<CountableLoopInfo> loopInfo)90 void LoopUnroll::TransformLoopImpl(Loop *loop, std::optional<uint64_t> optIterations, bool noSideExits,
91 uint32_t unrollFactor, std::optional<CountableLoopInfo> loopInfo)
92 {
93 auto graphCloner = GraphCloner(GetGraph(), GetGraph()->GetAllocator(), GetGraph()->GetLocalAllocator());
94 if (optIterations && noSideExits) {
95 // GCC gives false positive here
96 #if !defined(__clang__)
97 #pragma GCC diagnostic push
98 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
99 #endif
100 auto iterations = *optIterations;
101 ASSERT(unrollFactor != 0);
102 auto remainingIters = iterations % unrollFactor;
103 #if !defined(__clang__)
104 #pragma GCC diagnostic pop
105 #endif
106 Loop *cloneLoop = remainingIters == 0 ? nullptr : graphCloner.CloneLoop(loop);
107 // Unroll loop without side-exits and fix compare in the pre-header and back-edge
108 graphCloner.UnrollLoopBody<UnrollType::UNROLL_WITHOUT_SIDE_EXITS>(loop, unrollFactor);
109 FixCompareInst(loopInfo.value(), loop->GetHeader(), unrollFactor);
110 // Unroll loop without side-exits for remaining iterations
111 if (remainingIters != 0) {
112 graphCloner.UnrollLoopBody<UnrollType::UNROLL_CONSTANT_ITERATIONS>(cloneLoop, remainingIters);
113 }
114 COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
115 << "Unrolled without side-exits the loop with constant number of iterations (" << iterations
116 << "). Loop id = " << loop->GetId();
117 } else if (noSideExits) {
118 auto cloneLoop = graphCloner.CloneLoop(loop);
119 // Unroll loop without side-exits and fix compare in the pre-header and back-edge
120 graphCloner.UnrollLoopBody<UnrollType::UNROLL_WITHOUT_SIDE_EXITS>(loop, unrollFactor);
121 FixCompareInst(loopInfo.value(), loop->GetHeader(), unrollFactor);
122 // Unroll loop with side-exits for remaining iterations
123 graphCloner.UnrollLoopBody<UnrollType::UNROLL_POST_INCREMENT>(cloneLoop, unrollFactor - 1);
124 COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
125 << "Unrolled without side-exits the loop with unroll factor = " << unrollFactor
126 << ". Loop id = " << loop->GetId();
127 } else if (g_options.IsCompilerUnrollWithSideExits()) {
128 graphCloner.UnrollLoopBody<UnrollType::UNROLL_WITH_SIDE_EXITS>(loop, unrollFactor);
129 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "Unrolled with side-exits the loop with unroll factor = " << unrollFactor
130 << ". Loop id = " << loop->GetId();
131 }
132 }
133
TransformLoop(Loop * loop)134 bool LoopUnroll::TransformLoop(Loop *loop)
135 {
136 auto unrollParams = GetUnrollParams(loop);
137 if (!g_options.IsCompilerUnrollLoopWithCalls() && unrollParams.hasCall) {
138 COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
139 << "Loop isn't unrolled since it contains calls. Loop id = " << loop->GetId();
140 return false;
141 }
142
143 auto graphCloner = GraphCloner(GetGraph(), GetGraph()->GetAllocator(), GetGraph()->GetLocalAllocator());
144 uint32_t unrollFactor = std::min(unrollParams.unrollFactor, unrollFactor_);
145 auto loopParser = CountableLoopParser(*loop);
146 auto loopInfo = loopParser.Parse();
147 std::optional<uint64_t> optIterations {};
148 auto noBranching = false;
149 if (loopInfo.has_value()) {
150 optIterations = CountableLoopParser::GetLoopIterations(*loopInfo);
151 if (optIterations == 0) {
152 optIterations.reset();
153 }
154 if (optIterations.has_value()) {
155 // Increase instruction limit for unroll without branching
156 // <= unroll_factor * 2 because unroll without side exits would create unroll_factor * 2 - 1 copies of loop
157 noBranching = unrollParams.cloneableInsts <= instLimit_ &&
158 (*optIterations <= unrollFactor * 2U || *optIterations <= unrollFactor_) &&
159 CountableLoopParser::HasPreHeaderCompare(loop, *loopInfo);
160 }
161 }
162
163 if (noBranching) {
164 auto iterations = *optIterations;
165 graphCloner.UnrollLoopBody<UnrollType::UNROLL_CONSTANT_ITERATIONS>(loop, iterations);
166 COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
167 << "Unrolled without branching the loop with constant number of iterations (" << iterations
168 << "). Loop id = " << loop->GetId();
169 isApplied_ = true;
170 GetGraph()->GetEventWriter().EventLoopUnroll(loop->GetId(), loop->GetHeader()->GetGuestPc(), iterations,
171 unrollParams.cloneableInsts, "without branching");
172 return true;
173 }
174
175 return UnrollWithBranching(unrollFactor, loop, loopInfo, optIterations);
176 }
177
UnrollWithBranching(uint32_t unrollFactor,Loop * loop,std::optional<CountableLoopInfo> loopInfo,std::optional<uint64_t> optIterations)178 bool LoopUnroll::UnrollWithBranching(uint32_t unrollFactor, Loop *loop, std::optional<CountableLoopInfo> loopInfo,
179 std::optional<uint64_t> optIterations)
180 {
181 auto unrollParams = GetUnrollParams(loop);
182
183 if (unrollFactor <= 1U) {
184 COMPILER_LOG(DEBUG, LOOP_TRANSFORM)
185 << "Loop isn't unrolled due to unroll factor = " << unrollFactor << ". Loop id = " << loop->GetId();
186 return false;
187 }
188
189 auto noSideExits = false;
190 if (loopInfo.has_value()) {
191 noSideExits =
192 !ConditionOverFlow(*loopInfo, unrollFactor) && CountableLoopParser::HasPreHeaderCompare(loop, *loopInfo);
193 }
194
195 TransformLoopImpl(loop, optIterations, noSideExits, unrollFactor, loopInfo);
196 isApplied_ = true;
197 GetGraph()->GetEventWriter().EventLoopUnroll(loop->GetId(), loop->GetHeader()->GetGuestPc(), unrollFactor,
198 unrollParams.cloneableInsts,
199 noSideExits ? "without side exits" : "with side exits");
200 return true;
201 }
202
203 /**
204 * @return - unroll parameters:
205 * - maximum value of unroll factor, depends on INST_LIMIT
206 * - number of cloneable instructions
207 */
GetUnrollParams(Loop * loop)208 LoopUnroll::UnrollParams LoopUnroll::GetUnrollParams(Loop *loop)
209 {
210 uint32_t baseInstCount = 0;
211 uint32_t notCloneableCount = 0;
212 bool hasCall = false;
213 for (auto block : loop->GetBlocks()) {
214 for (auto inst : block->AllInsts()) {
215 baseInstCount++;
216 if ((block->IsLoopHeader() && inst->IsPhi()) || inst->GetOpcode() == Opcode::SafePoint) {
217 notCloneableCount++;
218 }
219 hasCall |= inst->IsCall() && !static_cast<CallInst *>(inst)->IsInlined();
220 }
221 }
222
223 UnrollParams params = {1, (baseInstCount - notCloneableCount), hasCall};
224 if (baseInstCount >= instLimit_) {
225 return params;
226 }
227 uint32_t canBeClonedCount = instLimit_ - baseInstCount;
228 params.unrollFactor = unrollFactor_;
229 if (params.cloneableInsts > 0) {
230 params.unrollFactor = (canBeClonedCount / params.cloneableInsts) + 1;
231 }
232 return params;
233 }
234
235 /**
236 * @return - `if_imm`'s compare input when `if_imm` its single user,
237 * otherwise create a new one Compare for this `if_imm` and return it
238 */
GetOrCreateIfImmUniqueCompare(Inst * ifImm)239 Inst *GetOrCreateIfImmUniqueCompare(Inst *ifImm)
240 {
241 ASSERT(ifImm->GetOpcode() == Opcode::IfImm);
242 auto compare = ifImm->GetInput(0).GetInst();
243 ASSERT(compare->GetOpcode() == Opcode::Compare);
244 if (compare->HasSingleUser()) {
245 return compare;
246 }
247 auto newCmp = compare->Clone(compare->GetBasicBlock()->GetGraph());
248 newCmp->SetInput(0, compare->GetInput(0).GetInst());
249 newCmp->SetInput(1, compare->GetInput(1).GetInst());
250 ifImm->InsertBefore(newCmp);
251 ifImm->SetInput(0, newCmp);
252 return newCmp;
253 }
254
255 /// Normalize control-flow to the form: `if condition is true goto loop_header`
NormalizeControlFlow(BasicBlock * edge,const BasicBlock * loopHeader)256 void NormalizeControlFlow(BasicBlock *edge, const BasicBlock *loopHeader)
257 {
258 auto ifImm = edge->GetLastInst()->CastToIfImm();
259 ASSERT(ifImm->GetImm() == 0);
260 if (ifImm->GetCc() == CC_EQ) {
261 ifImm->SetCc(CC_NE);
262 edge->SwapTrueFalseSuccessors<true>();
263 }
264 auto cmp = ifImm->GetInput(0).GetInst()->CastToCompare();
265 if (!cmp->HasSingleUser()) {
266 auto newCmp = cmp->Clone(edge->GetGraph());
267 ifImm->InsertBefore(newCmp);
268 ifImm->SetInput(0, newCmp);
269 cmp = newCmp->CastToCompare();
270 }
271 if (edge->GetFalseSuccessor() == loopHeader) {
272 auto inversedCc = GetInverseConditionCode(cmp->GetCc());
273 cmp->SetCc(inversedCc);
274 edge->SwapTrueFalseSuccessors<true>();
275 }
276 }
277
CreateNewTestInst(const CountableLoopInfo & loopInfo,Inst * constInst,Inst * preHeaderCmp)278 Inst *LoopUnroll::CreateNewTestInst(const CountableLoopInfo &loopInfo, Inst *constInst, Inst *preHeaderCmp)
279 {
280 Inst *test = nullptr;
281 if (loopInfo.isInc) {
282 test = GetGraph()->CreateInstSub(preHeaderCmp->CastToCompare()->GetOperandsType(), preHeaderCmp->GetPc(),
283 loopInfo.test, constInst);
284 #ifdef PANDA_COMPILER_DEBUG_INFO
285 test->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
286 #endif
287 } else {
288 test = GetGraph()->CreateInstAdd(preHeaderCmp->CastToCompare()->GetOperandsType(), preHeaderCmp->GetPc(),
289 loopInfo.test, constInst);
290 #ifdef PANDA_COMPILER_DEBUG_INFO
291 test->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
292 #endif
293 }
294 preHeaderCmp->InsertBefore(test);
295 return test;
296 }
297
298 /**
299 * Replace `Compare(init, test)` with these instructions:
300 *
301 * Constant(unroll_factor)
302 * Sub/Add(test, Constant)
303 * Compare(init, SubI/AddI)
304 *
305 * And replace condition code if it is `CC_NE`.
306 * We use Constant + Sub/Add because low-level instructions (SubI/AddI) may appear only after Lowering pass.
307 */
FixCompareInst(const CountableLoopInfo & loopInfo,BasicBlock * header,uint32_t unrollFactor)308 void LoopUnroll::FixCompareInst(const CountableLoopInfo &loopInfo, BasicBlock *header, uint32_t unrollFactor)
309 {
310 auto preHeader = header->GetLoop()->GetPreHeader();
311 auto backEdge = loopInfo.ifImm->GetBasicBlock();
312 ASSERT(!preHeader->IsEmpty() && preHeader->GetLastInst()->GetOpcode() == Opcode::IfImm);
313 auto preHeaderIf = preHeader->GetLastInst()->CastToIfImm();
314 auto preHeaderCmp = GetOrCreateIfImmUniqueCompare(preHeaderIf);
315 auto backEdgeCmp = GetOrCreateIfImmUniqueCompare(loopInfo.ifImm);
316 NormalizeControlFlow(preHeader, header);
317 NormalizeControlFlow(backEdge, header);
318 // Create Sub/Add + Const instructions and replace Compare's test inst input
319 auto immValue = (static_cast<uint64_t>(unrollFactor) - 1) * loopInfo.constStep;
320 auto newTest = CreateNewTestInst(loopInfo, GetGraph()->FindOrCreateConstant(immValue), preHeaderCmp);
321 auto testInputIdx = 1;
322 if (backEdgeCmp->GetInput(0) == loopInfo.test) {
323 testInputIdx = 0;
324 } else {
325 ASSERT(backEdgeCmp->GetInput(1) == loopInfo.test);
326 }
327 ASSERT(preHeaderCmp->GetInput(testInputIdx).GetInst() == loopInfo.test);
328 preHeaderCmp->SetInput(testInputIdx, newTest);
329 backEdgeCmp->SetInput(testInputIdx, newTest);
330 // Replace CC_NE ConditionCode
331 if (loopInfo.normalizedCc == CC_NE) {
332 auto cc = loopInfo.isInc ? CC_LT : CC_GT;
333 if (testInputIdx == 0) {
334 cc = SwapOperandsConditionCode(cc);
335 }
336 preHeaderCmp->CastToCompare()->SetCc(cc);
337 backEdgeCmp->CastToCompare()->SetCc(cc);
338 }
339 // for not constant test-instruction we need to insert `overflow-check`:
340 // `test - imm_value` should be less than `test` (incerement loop-index case)
341 // `test + imm_value` should be greater than `test` (decrement loop-index case)
342 // If overflow-check is failed goto after-loop
343 if (!loopInfo.test->IsConst()) {
344 auto cc = loopInfo.isInc ? CC_LT : CC_GT;
345 // Create overflow_compare
346 auto overflowCompare = GetGraph()->CreateInstCompare(compiler::DataType::BOOL, preHeaderCmp->GetPc(), newTest,
347 loopInfo.test, loopInfo.test->GetType(), cc);
348 #ifdef PANDA_COMPILER_DEBUG_INFO
349 overflowCompare->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
350 #endif
351 // Create (pre_header_compare AND overflow_compare) inst
352 auto andInst = GetGraph()->CreateInstAnd(DataType::BOOL, preHeaderCmp->GetPc(), preHeaderCmp, overflowCompare);
353 #ifdef PANDA_COMPILER_DEBUG_INFO
354 andInst->SetCurrentMethod(preHeaderCmp->GetCurrentMethod());
355 #endif
356 preHeaderIf->SetInput(0, andInst);
357 preHeaderIf->InsertBefore(andInst);
358 andInst->InsertBefore(overflowCompare);
359 }
360 }
361 } // namespace ark::compiler
362