• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "optimizer/analysis/countable_loop_parser.h"
17 #include "optimizer/analysis/bounds_analysis.h"
18 #include "optimizer/analysis/loop_analyzer.h"
19 #include "optimizer/ir/basicblock.h"
20 #include "optimizer/ir/graph.h"
21 
22 namespace ark::compiler {
23 /**
24  * Check if loop is countable
25  *
26  * [Loop]
27  * Phi(init, update)
28  * ...
29  * update(phi, 1)
30  * Compare(Add/Sub, test)
31  *
32  * where `update` is Add or Sub instruction
33  */
Parse()34 std::optional<CountableLoopInfo> CountableLoopParser::Parse()
35 {
36     if (loop_.IsIrreducible() || loop_.IsOsrLoop() || loop_.IsTryCatchLoop() || loop_.GetBackEdges().size() != 1 ||
37         loop_.IsRoot() || loop_.IsInfinite()) {
38         return std::nullopt;
39     }
40 
41     if (!ParseLoopExit()) {
42         return std::nullopt;
43     }
44 
45     if (!SetUpdateAndTestInputs()) {
46         return std::nullopt;
47     }
48 
49     if (!IsInstIncOrDec(loopInfo_.update)) {
50         return std::nullopt;
51     }
52     SetIndexAndConstStep();
53     if (loopInfo_.index->GetBasicBlock() != loop_.GetHeader()) {
54         return std::nullopt;
55     }
56 
57     if (!TryProcessBackEdge()) {
58         return std::nullopt;
59     }
60     return loopInfo_;
61 }
62 
ParseLoopExit()63 bool CountableLoopParser::ParseLoopExit()
64 {
65     auto loopExit = FindLoopExitBlock();
66     if (loopExit == nullptr) {
67         return false;
68     }
69     if (loopExit->IsEmpty() || (loopExit != loop_.GetHeader() && loopExit != loop_.GetBackEdges()[0])) {
70         return false;
71     }
72     isHeadLoopExit_ = (loopExit == loop_.GetHeader() && loopExit != loop_.GetBackEdges()[0]);
73     loopInfo_.ifImm = loopExit->GetLastInst();
74     if (loopInfo_.ifImm->GetOpcode() != Opcode::IfImm && loopInfo_.ifImm->GetOpcode() != Opcode::If) {
75         return false;
76     }
77     auto loopExitCmp = loopInfo_.ifImm->GetInput(0).GetInst();
78     if (loopExitCmp->GetOpcode() != Opcode::Compare) {
79         return false;
80     }
81     if (isHeadLoopExit_ && !loopExitCmp->GetInput(0).GetInst()->IsPhi() &&
82         !loopExitCmp->GetInput(1).GetInst()->IsPhi()) {
83         return false;
84     }
85     auto cmpType = loopExitCmp->CastToCompare()->GetOperandsType();
86     return DataType::GetCommonType(cmpType) == DataType::INT64;
87 }
88 
TryProcessBackEdge()89 bool CountableLoopParser::TryProcessBackEdge()
90 {
91     ASSERT(loopInfo_.index->IsPhi());
92     auto backEdge {loop_.GetBackEdges()[0]};
93     auto backEdgeIdx {loopInfo_.index->CastToPhi()->GetPredBlockIndex(backEdge)};
94     if (loopInfo_.index->GetInput(backEdgeIdx).GetInst() != loopInfo_.update) {
95         return false;
96     }
97     ASSERT(loopInfo_.index->GetInputsCount() == MAX_SUCCS_NUM);
98     loopInfo_.init = loopInfo_.index->GetInput(1 - backEdgeIdx).GetInst();
99     SetNormalizedConditionCode();
100     return IsConditionCodeAcceptable();
101 }
102 
HasPreHeaderCompare(Loop * loop,const CountableLoopInfo & loopInfo)103 bool CountableLoopParser::HasPreHeaderCompare(Loop *loop, const CountableLoopInfo &loopInfo)
104 {
105     auto preHeader = loop->GetPreHeader();
106     auto backEdge = loop->GetBackEdges()[0];
107     if (loopInfo.ifImm->GetBasicBlock() != backEdge || preHeader->IsEmpty() ||
108         preHeader->GetLastInst()->GetOpcode() != Opcode::IfImm) {
109         return false;
110     }
111     auto preHeaderIfImm = preHeader->GetLastInst();
112     ASSERT(preHeaderIfImm->GetOpcode() == Opcode::IfImm);
113     auto preHeaderCmp = preHeaderIfImm->GetInput(0).GetInst();
114     if (preHeaderCmp->GetOpcode() != Opcode::Compare) {
115         return false;
116     }
117     auto backEdgeCmp = loopInfo.ifImm->GetInput(0).GetInst();
118     ASSERT(backEdgeCmp->GetOpcode() == Opcode::Compare);
119 
120     // Compare condition codes
121     if (preHeaderCmp->CastToCompare()->GetCc() != backEdgeCmp->CastToCompare()->GetCc()) {
122         return false;
123     }
124 
125     if (loopInfo.ifImm->CastToIfImm()->GetCc() != preHeaderIfImm->CastToIfImm()->GetCc() ||
126         loopInfo.ifImm->CastToIfImm()->GetImm() != preHeaderIfImm->CastToIfImm()->GetImm()) {
127         return false;
128     }
129 
130     // Compare control-flow
131     if (preHeader->GetTrueSuccessor() != backEdge->GetTrueSuccessor() ||
132         preHeader->GetFalseSuccessor() != backEdge->GetFalseSuccessor()) {
133         return false;
134     }
135 
136     // Compare test inputs
137     auto testInputIdx = 1;
138     if (backEdgeCmp->GetInput(0) == loopInfo.test) {
139         testInputIdx = 0;
140     } else {
141         ASSERT(backEdgeCmp->GetInput(1) == loopInfo.test);
142     }
143 
144     return preHeaderCmp->GetInput(testInputIdx).GetInst() == loopInfo.test &&
145            preHeaderCmp->GetInput(1 - testInputIdx).GetInst() == loopInfo.init;
146 }
147 
148 // Returns exact number of iterations for loop with constant boundaries
149 // if its index does not overflow
GetLoopIterations(const CountableLoopInfo & loopInfo)150 std::optional<uint64_t> CountableLoopParser::GetLoopIterations(const CountableLoopInfo &loopInfo)
151 {
152     if (!loopInfo.init->IsConst() || !loopInfo.test->IsConst() || loopInfo.constStep == 0) {
153         return std::nullopt;
154     }
155     uint64_t initValue = loopInfo.init->CastToConstant()->GetInt64Value();
156     uint64_t testValue = loopInfo.test->CastToConstant()->GetInt64Value();
157     auto type = loopInfo.index->GetType();
158 
159     if (loopInfo.isInc) {
160         int64_t maxTest = BoundsRange::GetMax(type) - static_cast<int64_t>(loopInfo.constStep);
161         if (loopInfo.normalizedCc == CC_LE) {
162             maxTest--;
163         }
164         if (static_cast<int64_t>(testValue) > maxTest) {
165             // index may overflow
166             return std::nullopt;
167         }
168     } else {
169         int64_t minTest = BoundsRange::GetMin(type) + static_cast<int64_t>(loopInfo.constStep);
170         if (loopInfo.normalizedCc == CC_GE) {
171             minTest++;
172         }
173         if (static_cast<int64_t>(testValue) < minTest) {
174             // index may overflow
175             return std::nullopt;
176         }
177         std::swap(initValue, testValue);
178     }
179     if (static_cast<int64_t>(initValue) > static_cast<int64_t>(testValue)) {
180         return 0;
181     }
182     uint64_t diff = testValue - initValue;
183     uint64_t count = diff + loopInfo.constStep;
184     if (diff > std::numeric_limits<uint64_t>::max() - loopInfo.constStep) {
185         // count may overflow
186         return std::nullopt;
187     }
188     if (loopInfo.normalizedCc == CC_LT || loopInfo.normalizedCc == CC_GT) {
189         count--;
190     }
191     return count / loopInfo.constStep;
192 }
193 
194 /*
195  * Check if instruction is Add or Sub with constant and phi inputs
196  */
IsInstIncOrDec(Inst * inst)197 bool CountableLoopParser::IsInstIncOrDec(Inst *inst)
198 {
199     if (!inst->IsAddSub()) {
200         return false;
201     }
202     ConstantInst *cnst = nullptr;
203     if (inst->GetInput(0).GetInst()->IsConst() && inst->GetInput(1).GetInst()->IsPhi()) {
204         cnst = inst->GetInput(0).GetInst()->CastToConstant();
205     } else if (inst->GetInput(1).GetInst()->IsConst() && inst->GetInput(0).GetInst()->IsPhi()) {
206         cnst = inst->GetInput(1).GetInst()->CastToConstant();
207     }
208     return cnst != nullptr;
209 }
210 
211 // NOTE(a.popov) Suppot 'GetLoopExit()' method in the 'Loop' class
FindLoopExitBlock()212 BasicBlock *CountableLoopParser::FindLoopExitBlock()
213 {
214     auto outerLoop = loop_.GetOuterLoop();
215     BasicBlock *loopExit = nullptr;
216     for (auto block : loop_.GetBlocks()) {
217         const auto &succs = block->GetSuccsBlocks();
218         auto it = std::find_if(succs.begin(), succs.end(),
219                                [&outerLoop](const BasicBlock *bb) { return bb->GetLoop() == outerLoop; });
220         if (it != succs.end()) {
221             // Countable loop must have a single exit:
222             if (loopExit != nullptr) {
223                 return nullptr;
224             }
225             loopExit = block;
226         }
227     }
228     return loopExit;
229 }
230 
SetUpdateAndTestInputs()231 bool CountableLoopParser::SetUpdateAndTestInputs()
232 {
233     auto loopExitCmp = loopInfo_.ifImm->GetInput(0).GetInst();
234     ASSERT(loopExitCmp->GetOpcode() == Opcode::Compare);
235     loopInfo_.update = loopExitCmp->GetInput(0).GetInst();
236     loopInfo_.test = loopExitCmp->GetInput(1).GetInst();
237     if (isHeadLoopExit_) {
238         if (!loopInfo_.update->IsPhi()) {
239             std::swap(loopInfo_.update, loopInfo_.test);
240         }
241         ASSERT(loopInfo_.update->IsPhi());
242         if (loopInfo_.update->GetBasicBlock() != loop_.GetHeader()) {
243             return false;
244         }
245         auto backEdge {loop_.GetBackEdges()[0]};
246         loopInfo_.update = loopInfo_.update->CastToPhi()->GetPhiInput(backEdge);
247     } else {
248         if (!IsInstIncOrDec(loopInfo_.update)) {
249             std::swap(loopInfo_.update, loopInfo_.test);
250         }
251     }
252 
253     return true;
254 }
255 
SetIndexAndConstStep()256 void CountableLoopParser::SetIndexAndConstStep()
257 {
258     loopInfo_.index = loopInfo_.update->GetInput(0).GetInst();
259     auto constInst = loopInfo_.update->GetInput(1).GetInst();
260     if (loopInfo_.index->IsConst()) {
261         loopInfo_.index = loopInfo_.update->GetInput(1).GetInst();
262         constInst = loopInfo_.update->GetInput(0).GetInst();
263     }
264 
265     ASSERT(constInst->GetType() == DataType::INT64);
266     auto cnst = constInst->CastToConstant()->GetIntValue();
267     const uint64_t mask = (1ULL << 63U);
268     auto isNeg = DataType::IsTypeSigned(loopInfo_.update->GetType()) && (cnst & mask) != 0;
269     loopInfo_.isInc = loopInfo_.update->IsAdd();
270     if (isNeg) {
271         cnst = ~cnst + 1;
272         loopInfo_.isInc = !loopInfo_.isInc;
273     }
274     loopInfo_.constStep = cnst;
275 }
276 
SetNormalizedConditionCode()277 void CountableLoopParser::SetNormalizedConditionCode()
278 {
279     auto loopExit = loopInfo_.ifImm->GetBasicBlock();
280     ASSERT(loopExit != nullptr);
281     auto loopExitCmp = loopInfo_.ifImm->GetInput(0).GetInst();
282     ASSERT(loopExitCmp->GetOpcode() == Opcode::Compare);
283     auto cc = loopExitCmp->CastToCompare()->GetCc();
284     if (loopInfo_.test == loopExitCmp->GetInput(0).GetInst()) {
285         cc = SwapOperandsConditionCode(cc);
286     }
287     ASSERT(loopInfo_.ifImm->CastToIfImm()->GetImm() == 0);
288     if (loopInfo_.ifImm->CastToIfImm()->GetCc() == CC_EQ) {
289         cc = GetInverseConditionCode(cc);
290     } else {
291         ASSERT(loopInfo_.ifImm->CastToIfImm()->GetCc() == CC_NE);
292     }
293     auto loop = loopExit->GetLoop();
294     if (loopExit->GetFalseSuccessor()->GetLoop() == loop ||
295         loopExit->GetFalseSuccessor()->GetLoop()->GetOuterLoop() == loop) {
296         cc = GetInverseConditionCode(cc);
297     } else {
298         ASSERT(loopExit->GetTrueSuccessor()->GetLoop() == loop ||
299                loopExit->GetTrueSuccessor()->GetLoop()->GetOuterLoop() == loop);
300     }
301     loopInfo_.normalizedCc = cc;
302 }
303 
IsConditionCodeAcceptable()304 bool CountableLoopParser::IsConditionCodeAcceptable()
305 {
306     auto cc = loopInfo_.normalizedCc;
307     // Condition should be: inc <= test | inc < test
308     if (loopInfo_.isInc && cc != CC_LE && cc != CC_LT) {
309         return false;
310     }
311     // Condition should be: dec >= test | dec > test
312     if (!loopInfo_.isInc && cc != CC_GE && cc != CC_GT) {
313         return false;
314     }
315     return true;
316 }
317 }  // namespace ark::compiler
318