• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2025 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "optimizer/analysis/loop_analyzer.h"
17 #include "optimizer/ir/analysis.h"
18 #include "compiler_logger.h"
19 #include "loop_unswitcher.h"
20 
21 namespace ark::compiler {
LoopUnswitcher(Graph * graph,ArenaAllocator * allocator,ArenaAllocator * localAllocator)22 LoopUnswitcher::LoopUnswitcher(Graph *graph, ArenaAllocator *allocator, ArenaAllocator *localAllocator)
23     : GraphCloner(graph, allocator, localAllocator), conditions_(allocator->Adapter())
24 {
25 }
26 
27 /**
28  * Unswitch loop in selected branch instruction.
29  * Return pointer to new loop or nullptr if cannot unswitch loop.
30  */
UnswitchLoop(Loop * loop,Inst * inst)31 Loop *LoopUnswitcher::UnswitchLoop(Loop *loop, Inst *inst)
32 {
33     ASSERT(loop != nullptr && !loop->IsRoot());
34     ASSERT_PRINT(IsLoopSingleBackEdgeExitPoint(loop), "Cloning blocks doesn't have single entry/exit point");
35     ASSERT(loop->GetPreHeader() != nullptr);
36     ASSERT(!loop->IsIrreducible());
37     ASSERT(!loop->IsOsrLoop());
38     if (loop->GetPreHeader()->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
39         return nullptr;
40     }
41     ASSERT(cloneMarker_ == UNDEF_MARKER);
42 
43     auto markerHolder = MarkerHolder(GetGraph());
44     cloneMarker_ = markerHolder.GetMarker();
45     auto unswitchData = PrepareLoopToClone(loop);
46 
47     conditions_.clear();
48     for (auto bb : loop->GetBlocks()) {
49         if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
50             continue;
51         }
52 
53         auto ifImm = bb->GetLastInst();
54         if (IsConditionEqual(ifImm, inst, false) || IsConditionEqual(ifImm, inst, true)) {
55             // will replace all equal or oposite conditions
56             conditions_.push_back(ifImm);
57         }
58     }
59 
60     ASSERT(unswitchData != nullptr);
61     CloneBlocksAndInstructions<InstCloneType::CLONE_ALL, false>(*unswitchData->blocks, GetGraph());
62     BuildLoopUnswitchControlFlow(unswitchData);
63     BuildLoopUnswitchDataFlow(unswitchData, inst);
64     MakeLoopCloneInfo(unswitchData);
65     GetGraph()->RunPass<DominatorsTree>();
66 
67     auto cloneLoop = GetClone(loop->GetHeader())->GetLoop();
68     ASSERT(cloneLoop != loop && cloneLoop->GetOuterLoop() == loop->GetOuterLoop());
69     COMPILER_LOG(DEBUG, GRAPH_CLONER) << "Loop " << loop->GetId() << " is copied";
70     COMPILER_LOG(DEBUG, GRAPH_CLONER) << "Created new loop, id = " << cloneLoop->GetId();
71     return cloneLoop;
72 }
73 
BuildLoopUnswitchControlFlow(LoopClonerData * unswitchData)74 void LoopUnswitcher::BuildLoopUnswitchControlFlow(LoopClonerData *unswitchData)
75 {
76     ASSERT(unswitchData != nullptr);
77     auto outerClone = GetClone(unswitchData->outer);
78     auto preHeaderClone = GetClone(unswitchData->preHeader);
79 
80     auto commonOuter = GetGraph()->CreateEmptyBlock();
81 
82     while (!unswitchData->outer->GetSuccsBlocks().empty()) {
83         auto succ = unswitchData->outer->GetSuccsBlocks().front();
84         succ->ReplacePred(unswitchData->outer, commonOuter);
85         unswitchData->outer->RemoveSucc(succ);
86     }
87     unswitchData->outer->AddSucc(commonOuter);
88     outerClone->AddSucc(commonOuter);
89 
90     auto commonPredecessor = GetGraph()->CreateEmptyBlock();
91     while (!unswitchData->preHeader->GetPredsBlocks().empty()) {
92         auto pred = unswitchData->preHeader->GetPredsBlocks().front();
93         pred->ReplaceSucc(unswitchData->preHeader, commonPredecessor);
94         unswitchData->preHeader->RemovePred(pred);
95     }
96     commonPredecessor->AddSucc(unswitchData->preHeader);
97     commonPredecessor->AddSucc(preHeaderClone);
98 
99     for (auto &block : *unswitchData->blocks) {
100         if (block != unswitchData->preHeader) {
101             CloneEdges<CloneEdgeType::EDGE_PRED>(block);
102         }
103         if (block != unswitchData->outer) {
104             CloneEdges<CloneEdgeType::EDGE_SUCC>(block);
105         }
106     }
107     ASSERT(unswitchData->outer->GetPredBlockIndex(unswitchData->preHeader) ==
108            outerClone->GetPredBlockIndex(preHeaderClone));
109     ASSERT(unswitchData->header->GetPredBlockIndex(unswitchData->preHeader) ==
110            GetClone(unswitchData->header)->GetPredBlockIndex(preHeaderClone));
111 }
112 
BuildLoopUnswitchDataFlow(LoopClonerData * unswitchData,Inst * ifInst)113 void LoopUnswitcher::BuildLoopUnswitchDataFlow(LoopClonerData *unswitchData, Inst *ifInst)
114 {
115     ASSERT(unswitchData != nullptr);
116     ProcessMarkedInsts(unswitchData);
117 
118     auto commonOuter = unswitchData->outer->GetSuccessor(0);
119     for (auto phi : unswitchData->outer->PhiInsts()) {
120         auto phiClone = GetClone(phi);
121         auto phiJoin = commonOuter->GetGraph()->CreateInstPhi(phi->GetType(), phi->GetPc());
122         phi->ReplaceUsers(phiJoin);
123         phiJoin->AppendInput(phi);
124         phiJoin->AppendInput(phiClone);
125         commonOuter->AppendPhi(phiJoin);
126     }
127 
128     auto commonPredecessor = unswitchData->preHeader->GetPredecessor(0);
129     auto ifInstUnswitch = ifInst->Clone(commonPredecessor->GetGraph());
130     for (size_t i = 0; i < ifInst->GetInputsCount(); i++) {
131         auto input = ifInst->GetInput(i);
132         ifInstUnswitch->SetInput(i, input.GetInst());
133     }
134     commonPredecessor->AppendInst(ifInstUnswitch);
135 
136     ReplaceWithConstantCondition(ifInstUnswitch);
137 }
138 
ReplaceWithConstantCondition(Inst * ifInst)139 void LoopUnswitcher::ReplaceWithConstantCondition(Inst *ifInst)
140 {
141     auto graph = ifInst->GetBasicBlock()->GetGraph();
142     auto i1 = graph->FindOrCreateConstant(1);
143     auto i2 = graph->FindOrCreateConstant(0);
144 
145     auto ifImm = ifInst->CastToIfImm();
146     ASSERT(ifImm->GetCc() == ConditionCode::CC_NE || ifImm->GetCc() == ConditionCode::CC_EQ);
147     if ((ifImm->GetImm() == 0) != (ifImm->GetCc() == ConditionCode::CC_NE)) {
148         std::swap(i1, i2);
149     }
150 
151     for (auto inst : conditions_) {
152         if (IsConditionEqual(inst, ifInst, true)) {
153             inst->SetInput(0, i2);
154             GetClone(inst)->SetInput(0, i1);
155         } else {
156             inst->SetInput(0, i1);
157             GetClone(inst)->SetInput(0, i2);
158         }
159     }
160 }
161 
AllInputsConst(Inst * inst)162 static bool AllInputsConst(Inst *inst)
163 {
164     for (auto input : inst->GetInputs()) {
165         if (!input.GetInst()->IsConst()) {
166             return false;
167         }
168     }
169     return true;
170 }
171 
IsHoistable(Inst * inst,Loop * loop)172 static bool IsHoistable(Inst *inst, Loop *loop)
173 {
174     for (auto input : inst->GetInputs()) {
175         if (!input.GetInst()->GetBasicBlock()->IsDominate(loop->GetPreHeader())) {
176             return false;
177         }
178     }
179     return true;
180 }
181 
FindUnswitchInst(Loop * loop)182 Inst *LoopUnswitcher::FindUnswitchInst(Loop *loop)
183 {
184     for (auto bb : loop->GetBlocks()) {
185         if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
186             continue;
187         }
188         auto ifInst = bb->GetLastInst();
189         if (AllInputsConst(ifInst)) {
190             continue;
191         }
192         if (IsHoistable(ifInst, loop)) {
193             return ifInst;
194         }
195     }
196     return nullptr;
197 }
198 
IsSmallLoop(Loop * loop)199 bool LoopUnswitcher::IsSmallLoop(Loop *loop)
200 {
201     auto loopParser = CountableLoopParser(*loop);
202     auto loopInfo = loopParser.Parse();
203     if (!loopInfo.has_value()) {
204         return false;
205     }
206     auto iterations = CountableLoopParser::GetLoopIterations(*loopInfo);
207     if (!iterations.has_value()) {
208         return false;
209     }
210     return *iterations <= 1;
211 }
212 
CountLoopInstructions(const Loop * loop)213 static uint32_t CountLoopInstructions(const Loop *loop)
214 {
215     uint32_t count = 0;
216     for (auto block : loop->GetBlocks()) {
217         count += block->CountInsts();
218     }
219     return count;
220 }
221 
EstimateUnswitchInstructionsCount(BasicBlock * bb,const BasicBlock * backEdge,const Inst * unswitchInst,bool trueCond,Marker marker)222 static uint32_t EstimateUnswitchInstructionsCount(BasicBlock *bb, const BasicBlock *backEdge, const Inst *unswitchInst,
223                                                   bool trueCond, Marker marker)
224 {
225     if (bb->IsMarked(marker)) {
226         return 0;
227     }
228     bb->SetMarker(marker);
229 
230     uint32_t count = bb->CountInsts();
231     if (bb == backEdge) {
232         return count;
233     }
234 
235     if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
236         count += EstimateUnswitchInstructionsCount(bb->GetSuccsBlocks()[0], backEdge, unswitchInst, trueCond, marker);
237     } else if (IsConditionEqual(unswitchInst, bb->GetLastInst(), false)) {
238         auto succ = trueCond ? bb->GetTrueSuccessor() : bb->GetFalseSuccessor();
239         count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
240     } else if (IsConditionEqual(unswitchInst, bb->GetLastInst(), true)) {
241         auto succ = trueCond ? bb->GetFalseSuccessor() : bb->GetTrueSuccessor();
242         count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
243     } else {
244         for (auto succ : bb->GetSuccsBlocks()) {
245             count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
246         }
247     }
248     return count;
249 }
250 
EstimateInstructionsCount(const Loop * loop,const Inst * unswitchInst,int64_t * loopSize,int64_t * trueCount,int64_t * falseCount)251 void LoopUnswitcher::EstimateInstructionsCount(const Loop *loop, const Inst *unswitchInst, int64_t *loopSize,
252                                                int64_t *trueCount, int64_t *falseCount)
253 {
254     ASSERT(loop->GetBackEdges().size() == 1);
255     ASSERT(loop->GetInnerLoops().empty());
256     *loopSize = static_cast<int64_t>(CountLoopInstructions(loop));
257     auto backEdge = loop->GetBackEdges()[0];
258     auto graph = backEdge->GetGraph();
259 
260     auto trueMarker = graph->NewMarker();
261     *trueCount = static_cast<int64_t>(
262         EstimateUnswitchInstructionsCount(loop->GetHeader(), backEdge, unswitchInst, true, trueMarker));
263     graph->EraseMarker(trueMarker);
264 
265     auto falseMarker = graph->NewMarker();
266     *falseCount = static_cast<int64_t>(
267         EstimateUnswitchInstructionsCount(loop->GetHeader(), backEdge, unswitchInst, false, falseMarker));
268     graph->EraseMarker(falseMarker);
269 }
270 }  // namespace ark::compiler
271