1 /*
2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "optimizer/analysis/countable_loop_parser.h"
17 #include "optimizer/analysis/bounds_analysis.h"
18 #include "optimizer/analysis/loop_analyzer.h"
19 #include "optimizer/ir/basicblock.h"
20 #include "optimizer/ir/graph.h"
21
22 namespace ark::compiler {
23 /**
24 * Check if loop is countable
25 *
26 * [Loop]
27 * Phi(init, update)
28 * ...
29 * update(phi, 1)
30 * Compare(Add/Sub, test)
31 *
32 * where `update` is Add or Sub instruction
33 */
Parse()34 std::optional<CountableLoopInfo> CountableLoopParser::Parse()
35 {
36 if (loop_.IsIrreducible() || loop_.IsOsrLoop() || loop_.IsTryCatchLoop() || loop_.GetBackEdges().size() != 1 ||
37 loop_.IsRoot() || loop_.IsInfinite()) {
38 return std::nullopt;
39 }
40
41 if (!ParseLoopExit()) {
42 return std::nullopt;
43 }
44
45 if (!SetUpdateAndTestInputs()) {
46 return std::nullopt;
47 }
48
49 if (!IsInstIncOrDec(loopInfo_.update)) {
50 return std::nullopt;
51 }
52 SetIndexAndConstStep();
53 if (loopInfo_.index->GetBasicBlock() != loop_.GetHeader()) {
54 return std::nullopt;
55 }
56
57 if (!TryProcessBackEdge()) {
58 return std::nullopt;
59 }
60 return loopInfo_;
61 }
62
ParseLoopExit()63 bool CountableLoopParser::ParseLoopExit()
64 {
65 auto loopExit = FindLoopExitBlock();
66 if (loopExit == nullptr) {
67 return false;
68 }
69 if (loopExit->IsEmpty() || (loopExit != loop_.GetHeader() && loopExit != loop_.GetBackEdges()[0])) {
70 return false;
71 }
72 isHeadLoopExit_ = (loopExit == loop_.GetHeader() && loopExit != loop_.GetBackEdges()[0]);
73 loopInfo_.ifImm = loopExit->GetLastInst();
74 if (loopInfo_.ifImm->GetOpcode() != Opcode::IfImm && loopInfo_.ifImm->GetOpcode() != Opcode::If) {
75 return false;
76 }
77 auto loopExitCmp = loopInfo_.ifImm->GetInput(0).GetInst();
78 if (loopExitCmp->GetOpcode() != Opcode::Compare) {
79 return false;
80 }
81 if (isHeadLoopExit_ && !loopExitCmp->GetInput(0).GetInst()->IsPhi() &&
82 !loopExitCmp->GetInput(1).GetInst()->IsPhi()) {
83 return false;
84 }
85 auto cmpType = loopExitCmp->CastToCompare()->GetOperandsType();
86 return DataType::GetCommonType(cmpType) == DataType::INT64;
87 }
88
TryProcessBackEdge()89 bool CountableLoopParser::TryProcessBackEdge()
90 {
91 ASSERT(loopInfo_.index->IsPhi());
92 auto backEdge {loop_.GetBackEdges()[0]};
93 auto backEdgeIdx {loopInfo_.index->CastToPhi()->GetPredBlockIndex(backEdge)};
94 if (loopInfo_.index->GetInput(backEdgeIdx).GetInst() != loopInfo_.update) {
95 return false;
96 }
97 ASSERT(loopInfo_.index->GetInputsCount() == MAX_SUCCS_NUM);
98 loopInfo_.init = loopInfo_.index->GetInput(1 - backEdgeIdx).GetInst();
99 SetNormalizedConditionCode();
100 return IsConditionCodeAcceptable();
101 }
102
HasPreHeaderCompare(Loop * loop,const CountableLoopInfo & loopInfo)103 bool CountableLoopParser::HasPreHeaderCompare(Loop *loop, const CountableLoopInfo &loopInfo)
104 {
105 auto preHeader = loop->GetPreHeader();
106 auto backEdge = loop->GetBackEdges()[0];
107 if (loopInfo.ifImm->GetBasicBlock() != backEdge || preHeader->IsEmpty() ||
108 preHeader->GetLastInst()->GetOpcode() != Opcode::IfImm) {
109 return false;
110 }
111 auto preHeaderIfImm = preHeader->GetLastInst();
112 ASSERT(preHeaderIfImm->GetOpcode() == Opcode::IfImm);
113 auto preHeaderCmp = preHeaderIfImm->GetInput(0).GetInst();
114 if (preHeaderCmp->GetOpcode() != Opcode::Compare) {
115 return false;
116 }
117 auto backEdgeCmp = loopInfo.ifImm->GetInput(0).GetInst();
118 ASSERT(backEdgeCmp->GetOpcode() == Opcode::Compare);
119
120 // Compare condition codes
121 if (preHeaderCmp->CastToCompare()->GetCc() != backEdgeCmp->CastToCompare()->GetCc()) {
122 return false;
123 }
124
125 if (loopInfo.ifImm->CastToIfImm()->GetCc() != preHeaderIfImm->CastToIfImm()->GetCc() ||
126 loopInfo.ifImm->CastToIfImm()->GetImm() != preHeaderIfImm->CastToIfImm()->GetImm()) {
127 return false;
128 }
129
130 // Compare control-flow
131 if (preHeader->GetTrueSuccessor() != backEdge->GetTrueSuccessor() ||
132 preHeader->GetFalseSuccessor() != backEdge->GetFalseSuccessor()) {
133 return false;
134 }
135
136 // Compare test inputs
137 auto testInputIdx = 1;
138 if (backEdgeCmp->GetInput(0) == loopInfo.test) {
139 testInputIdx = 0;
140 } else {
141 ASSERT(backEdgeCmp->GetInput(1) == loopInfo.test);
142 }
143
144 return preHeaderCmp->GetInput(testInputIdx).GetInst() == loopInfo.test &&
145 preHeaderCmp->GetInput(1 - testInputIdx).GetInst() == loopInfo.init;
146 }
147
148 // Returns exact number of iterations for loop with constant boundaries
149 // if its index does not overflow
GetLoopIterations(const CountableLoopInfo & loopInfo)150 std::optional<uint64_t> CountableLoopParser::GetLoopIterations(const CountableLoopInfo &loopInfo)
151 {
152 if (!loopInfo.init->IsConst() || !loopInfo.test->IsConst() || loopInfo.constStep == 0) {
153 return std::nullopt;
154 }
155 uint64_t initValue = loopInfo.init->CastToConstant()->GetInt64Value();
156 uint64_t testValue = loopInfo.test->CastToConstant()->GetInt64Value();
157 auto type = loopInfo.index->GetType();
158
159 if (loopInfo.isInc) {
160 int64_t maxTest = BoundsRange::GetMax(type) - static_cast<int64_t>(loopInfo.constStep);
161 if (loopInfo.normalizedCc == CC_LE) {
162 maxTest--;
163 }
164 if (static_cast<int64_t>(testValue) > maxTest) {
165 // index may overflow
166 return std::nullopt;
167 }
168 } else {
169 int64_t minTest = BoundsRange::GetMin(type) + static_cast<int64_t>(loopInfo.constStep);
170 if (loopInfo.normalizedCc == CC_GE) {
171 minTest++;
172 }
173 if (static_cast<int64_t>(testValue) < minTest) {
174 // index may overflow
175 return std::nullopt;
176 }
177 std::swap(initValue, testValue);
178 }
179 if (static_cast<int64_t>(initValue) > static_cast<int64_t>(testValue)) {
180 return 0;
181 }
182 uint64_t diff = testValue - initValue;
183 uint64_t count = diff + loopInfo.constStep;
184 if (diff > std::numeric_limits<uint64_t>::max() - loopInfo.constStep) {
185 // count may overflow
186 return std::nullopt;
187 }
188 if (loopInfo.normalizedCc == CC_LT || loopInfo.normalizedCc == CC_GT) {
189 count--;
190 }
191 return count / loopInfo.constStep;
192 }
193
194 /*
195 * Check if instruction is Add or Sub with constant and phi inputs
196 */
IsInstIncOrDec(Inst * inst)197 bool CountableLoopParser::IsInstIncOrDec(Inst *inst)
198 {
199 if (!inst->IsAddSub()) {
200 return false;
201 }
202 ConstantInst *cnst = nullptr;
203 if (inst->GetInput(0).GetInst()->IsConst() && inst->GetInput(1).GetInst()->IsPhi()) {
204 cnst = inst->GetInput(0).GetInst()->CastToConstant();
205 } else if (inst->GetInput(1).GetInst()->IsConst() && inst->GetInput(0).GetInst()->IsPhi()) {
206 cnst = inst->GetInput(1).GetInst()->CastToConstant();
207 }
208 return cnst != nullptr;
209 }
210
211 // NOTE(a.popov) Suppot 'GetLoopExit()' method in the 'Loop' class
FindLoopExitBlock()212 BasicBlock *CountableLoopParser::FindLoopExitBlock()
213 {
214 auto outerLoop = loop_.GetOuterLoop();
215 BasicBlock *loopExit = nullptr;
216 for (auto block : loop_.GetBlocks()) {
217 const auto &succs = block->GetSuccsBlocks();
218 auto it = std::find_if(succs.begin(), succs.end(),
219 [&outerLoop](const BasicBlock *bb) { return bb->GetLoop() == outerLoop; });
220 if (it != succs.end()) {
221 // Countable loop must have a single exit:
222 if (loopExit != nullptr) {
223 return nullptr;
224 }
225 loopExit = block;
226 }
227 }
228 return loopExit;
229 }
230
SetUpdateAndTestInputs()231 bool CountableLoopParser::SetUpdateAndTestInputs()
232 {
233 auto loopExitCmp = loopInfo_.ifImm->GetInput(0).GetInst();
234 ASSERT(loopExitCmp->GetOpcode() == Opcode::Compare);
235 loopInfo_.update = loopExitCmp->GetInput(0).GetInst();
236 loopInfo_.test = loopExitCmp->GetInput(1).GetInst();
237 if (isHeadLoopExit_) {
238 if (!loopInfo_.update->IsPhi()) {
239 std::swap(loopInfo_.update, loopInfo_.test);
240 }
241 ASSERT(loopInfo_.update->IsPhi());
242 if (loopInfo_.update->GetBasicBlock() != loop_.GetHeader()) {
243 return false;
244 }
245 auto backEdge {loop_.GetBackEdges()[0]};
246 loopInfo_.update = loopInfo_.update->CastToPhi()->GetPhiInput(backEdge);
247 } else {
248 if (!IsInstIncOrDec(loopInfo_.update)) {
249 std::swap(loopInfo_.update, loopInfo_.test);
250 }
251 }
252
253 return true;
254 }
255
SetIndexAndConstStep()256 void CountableLoopParser::SetIndexAndConstStep()
257 {
258 loopInfo_.index = loopInfo_.update->GetInput(0).GetInst();
259 auto constInst = loopInfo_.update->GetInput(1).GetInst();
260 if (loopInfo_.index->IsConst()) {
261 loopInfo_.index = loopInfo_.update->GetInput(1).GetInst();
262 constInst = loopInfo_.update->GetInput(0).GetInst();
263 }
264
265 ASSERT(constInst->GetType() == DataType::INT64);
266 auto cnst = constInst->CastToConstant()->GetIntValue();
267 const uint64_t mask = (1ULL << 63U);
268 auto isNeg = DataType::IsTypeSigned(loopInfo_.update->GetType()) && (cnst & mask) != 0;
269 loopInfo_.isInc = loopInfo_.update->IsAdd();
270 if (isNeg) {
271 cnst = ~cnst + 1;
272 loopInfo_.isInc = !loopInfo_.isInc;
273 }
274 loopInfo_.constStep = cnst;
275 }
276
SetNormalizedConditionCode()277 void CountableLoopParser::SetNormalizedConditionCode()
278 {
279 auto loopExit = loopInfo_.ifImm->GetBasicBlock();
280 ASSERT(loopExit != nullptr);
281 auto loopExitCmp = loopInfo_.ifImm->GetInput(0).GetInst();
282 ASSERT(loopExitCmp->GetOpcode() == Opcode::Compare);
283 auto cc = loopExitCmp->CastToCompare()->GetCc();
284 if (loopInfo_.test == loopExitCmp->GetInput(0).GetInst()) {
285 cc = SwapOperandsConditionCode(cc);
286 }
287 ASSERT(loopInfo_.ifImm->CastToIfImm()->GetImm() == 0);
288 if (loopInfo_.ifImm->CastToIfImm()->GetCc() == CC_EQ) {
289 cc = GetInverseConditionCode(cc);
290 } else {
291 ASSERT(loopInfo_.ifImm->CastToIfImm()->GetCc() == CC_NE);
292 }
293 auto loop = loopExit->GetLoop();
294 if (loopExit->GetFalseSuccessor()->GetLoop() == loop ||
295 loopExit->GetFalseSuccessor()->GetLoop()->GetOuterLoop() == loop) {
296 cc = GetInverseConditionCode(cc);
297 } else {
298 ASSERT(loopExit->GetTrueSuccessor()->GetLoop() == loop ||
299 loopExit->GetTrueSuccessor()->GetLoop()->GetOuterLoop() == loop);
300 }
301 loopInfo_.normalizedCc = cc;
302 }
303
IsConditionCodeAcceptable()304 bool CountableLoopParser::IsConditionCodeAcceptable()
305 {
306 auto cc = loopInfo_.normalizedCc;
307 // Condition should be: inc <= test | inc < test
308 if (loopInfo_.isInc && cc != CC_LE && cc != CC_LT) {
309 return false;
310 }
311 // Condition should be: dec >= test | dec > test
312 if (!loopInfo_.isInc && cc != CC_GE && cc != CC_GT) {
313 return false;
314 }
315 return true;
316 }
317 } // namespace ark::compiler
318