1 /*
2 * Copyright (c) 2021-2025 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "optimizer/analysis/loop_analyzer.h"
17 #include "optimizer/ir/analysis.h"
18 #include "compiler_logger.h"
19 #include "loop_unswitcher.h"
20
21 namespace ark::compiler {
LoopUnswitcher(Graph * graph,ArenaAllocator * allocator,ArenaAllocator * localAllocator)22 LoopUnswitcher::LoopUnswitcher(Graph *graph, ArenaAllocator *allocator, ArenaAllocator *localAllocator)
23 : GraphCloner(graph, allocator, localAllocator), conditions_(allocator->Adapter())
24 {
25 }
26
27 /**
28 * Unswitch loop in selected branch instruction.
29 * Return pointer to new loop or nullptr if cannot unswitch loop.
30 */
UnswitchLoop(Loop * loop,Inst * inst)31 Loop *LoopUnswitcher::UnswitchLoop(Loop *loop, Inst *inst)
32 {
33 ASSERT(loop != nullptr && !loop->IsRoot());
34 ASSERT_PRINT(IsLoopSingleBackEdgeExitPoint(loop), "Cloning blocks doesn't have single entry/exit point");
35 ASSERT(loop->GetPreHeader() != nullptr);
36 ASSERT(!loop->IsIrreducible());
37 ASSERT(!loop->IsOsrLoop());
38 if (loop->GetPreHeader()->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
39 return nullptr;
40 }
41 ASSERT(cloneMarker_ == UNDEF_MARKER);
42
43 auto markerHolder = MarkerHolder(GetGraph());
44 cloneMarker_ = markerHolder.GetMarker();
45 auto unswitchData = PrepareLoopToClone(loop);
46
47 conditions_.clear();
48 for (auto bb : loop->GetBlocks()) {
49 if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
50 continue;
51 }
52
53 auto ifImm = bb->GetLastInst();
54 if (IsConditionEqual(ifImm, inst, false) || IsConditionEqual(ifImm, inst, true)) {
55 // will replace all equal or oposite conditions
56 conditions_.push_back(ifImm);
57 }
58 }
59
60 ASSERT(unswitchData != nullptr);
61 CloneBlocksAndInstructions<InstCloneType::CLONE_ALL, false>(*unswitchData->blocks, GetGraph());
62 BuildLoopUnswitchControlFlow(unswitchData);
63 BuildLoopUnswitchDataFlow(unswitchData, inst);
64 MakeLoopCloneInfo(unswitchData);
65 GetGraph()->RunPass<DominatorsTree>();
66
67 auto cloneLoop = GetClone(loop->GetHeader())->GetLoop();
68 ASSERT(cloneLoop != loop && cloneLoop->GetOuterLoop() == loop->GetOuterLoop());
69 COMPILER_LOG(DEBUG, GRAPH_CLONER) << "Loop " << loop->GetId() << " is copied";
70 COMPILER_LOG(DEBUG, GRAPH_CLONER) << "Created new loop, id = " << cloneLoop->GetId();
71 return cloneLoop;
72 }
73
BuildLoopUnswitchControlFlow(LoopClonerData * unswitchData)74 void LoopUnswitcher::BuildLoopUnswitchControlFlow(LoopClonerData *unswitchData)
75 {
76 ASSERT(unswitchData != nullptr);
77 auto outerClone = GetClone(unswitchData->outer);
78 auto preHeaderClone = GetClone(unswitchData->preHeader);
79
80 auto commonOuter = GetGraph()->CreateEmptyBlock();
81
82 while (!unswitchData->outer->GetSuccsBlocks().empty()) {
83 auto succ = unswitchData->outer->GetSuccsBlocks().front();
84 succ->ReplacePred(unswitchData->outer, commonOuter);
85 unswitchData->outer->RemoveSucc(succ);
86 }
87 unswitchData->outer->AddSucc(commonOuter);
88 outerClone->AddSucc(commonOuter);
89
90 auto commonPredecessor = GetGraph()->CreateEmptyBlock();
91 while (!unswitchData->preHeader->GetPredsBlocks().empty()) {
92 auto pred = unswitchData->preHeader->GetPredsBlocks().front();
93 pred->ReplaceSucc(unswitchData->preHeader, commonPredecessor);
94 unswitchData->preHeader->RemovePred(pred);
95 }
96 commonPredecessor->AddSucc(unswitchData->preHeader);
97 commonPredecessor->AddSucc(preHeaderClone);
98
99 for (auto &block : *unswitchData->blocks) {
100 if (block != unswitchData->preHeader) {
101 CloneEdges<CloneEdgeType::EDGE_PRED>(block);
102 }
103 if (block != unswitchData->outer) {
104 CloneEdges<CloneEdgeType::EDGE_SUCC>(block);
105 }
106 }
107 ASSERT(unswitchData->outer->GetPredBlockIndex(unswitchData->preHeader) ==
108 outerClone->GetPredBlockIndex(preHeaderClone));
109 ASSERT(unswitchData->header->GetPredBlockIndex(unswitchData->preHeader) ==
110 GetClone(unswitchData->header)->GetPredBlockIndex(preHeaderClone));
111 }
112
BuildLoopUnswitchDataFlow(LoopClonerData * unswitchData,Inst * ifInst)113 void LoopUnswitcher::BuildLoopUnswitchDataFlow(LoopClonerData *unswitchData, Inst *ifInst)
114 {
115 ASSERT(unswitchData != nullptr);
116 ProcessMarkedInsts(unswitchData);
117
118 auto commonOuter = unswitchData->outer->GetSuccessor(0);
119 for (auto phi : unswitchData->outer->PhiInsts()) {
120 auto phiClone = GetClone(phi);
121 auto phiJoin = commonOuter->GetGraph()->CreateInstPhi(phi->GetType(), phi->GetPc());
122 phi->ReplaceUsers(phiJoin);
123 phiJoin->AppendInput(phi);
124 phiJoin->AppendInput(phiClone);
125 commonOuter->AppendPhi(phiJoin);
126 }
127
128 auto commonPredecessor = unswitchData->preHeader->GetPredecessor(0);
129 auto ifInstUnswitch = ifInst->Clone(commonPredecessor->GetGraph());
130 for (size_t i = 0; i < ifInst->GetInputsCount(); i++) {
131 auto input = ifInst->GetInput(i);
132 ifInstUnswitch->SetInput(i, input.GetInst());
133 }
134 commonPredecessor->AppendInst(ifInstUnswitch);
135
136 ReplaceWithConstantCondition(ifInstUnswitch);
137 }
138
ReplaceWithConstantCondition(Inst * ifInst)139 void LoopUnswitcher::ReplaceWithConstantCondition(Inst *ifInst)
140 {
141 auto graph = ifInst->GetBasicBlock()->GetGraph();
142 auto i1 = graph->FindOrCreateConstant(1);
143 auto i2 = graph->FindOrCreateConstant(0);
144
145 auto ifImm = ifInst->CastToIfImm();
146 ASSERT(ifImm->GetCc() == ConditionCode::CC_NE || ifImm->GetCc() == ConditionCode::CC_EQ);
147 if ((ifImm->GetImm() == 0) != (ifImm->GetCc() == ConditionCode::CC_NE)) {
148 std::swap(i1, i2);
149 }
150
151 for (auto inst : conditions_) {
152 if (IsConditionEqual(inst, ifInst, true)) {
153 inst->SetInput(0, i2);
154 GetClone(inst)->SetInput(0, i1);
155 } else {
156 inst->SetInput(0, i1);
157 GetClone(inst)->SetInput(0, i2);
158 }
159 }
160 }
161
AllInputsConst(Inst * inst)162 static bool AllInputsConst(Inst *inst)
163 {
164 for (auto input : inst->GetInputs()) {
165 if (!input.GetInst()->IsConst()) {
166 return false;
167 }
168 }
169 return true;
170 }
171
IsHoistable(Inst * inst,Loop * loop)172 static bool IsHoistable(Inst *inst, Loop *loop)
173 {
174 for (auto input : inst->GetInputs()) {
175 if (!input.GetInst()->GetBasicBlock()->IsDominate(loop->GetPreHeader())) {
176 return false;
177 }
178 }
179 return true;
180 }
181
FindUnswitchInst(Loop * loop)182 Inst *LoopUnswitcher::FindUnswitchInst(Loop *loop)
183 {
184 for (auto bb : loop->GetBlocks()) {
185 if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
186 continue;
187 }
188 auto ifInst = bb->GetLastInst();
189 if (AllInputsConst(ifInst)) {
190 continue;
191 }
192 if (IsHoistable(ifInst, loop)) {
193 return ifInst;
194 }
195 }
196 return nullptr;
197 }
198
IsSmallLoop(Loop * loop)199 bool LoopUnswitcher::IsSmallLoop(Loop *loop)
200 {
201 auto loopParser = CountableLoopParser(*loop);
202 auto loopInfo = loopParser.Parse();
203 if (!loopInfo.has_value()) {
204 return false;
205 }
206 auto iterations = CountableLoopParser::GetLoopIterations(*loopInfo);
207 if (!iterations.has_value()) {
208 return false;
209 }
210 return *iterations <= 1;
211 }
212
CountLoopInstructions(const Loop * loop)213 static uint32_t CountLoopInstructions(const Loop *loop)
214 {
215 uint32_t count = 0;
216 for (auto block : loop->GetBlocks()) {
217 count += block->CountInsts();
218 }
219 return count;
220 }
221
EstimateUnswitchInstructionsCount(BasicBlock * bb,const BasicBlock * backEdge,const Inst * unswitchInst,bool trueCond,Marker marker)222 static uint32_t EstimateUnswitchInstructionsCount(BasicBlock *bb, const BasicBlock *backEdge, const Inst *unswitchInst,
223 bool trueCond, Marker marker)
224 {
225 if (bb->IsMarked(marker)) {
226 return 0;
227 }
228 bb->SetMarker(marker);
229
230 uint32_t count = bb->CountInsts();
231 if (bb == backEdge) {
232 return count;
233 }
234
235 if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
236 count += EstimateUnswitchInstructionsCount(bb->GetSuccsBlocks()[0], backEdge, unswitchInst, trueCond, marker);
237 } else if (IsConditionEqual(unswitchInst, bb->GetLastInst(), false)) {
238 auto succ = trueCond ? bb->GetTrueSuccessor() : bb->GetFalseSuccessor();
239 count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
240 } else if (IsConditionEqual(unswitchInst, bb->GetLastInst(), true)) {
241 auto succ = trueCond ? bb->GetFalseSuccessor() : bb->GetTrueSuccessor();
242 count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
243 } else {
244 for (auto succ : bb->GetSuccsBlocks()) {
245 count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
246 }
247 }
248 return count;
249 }
250
EstimateInstructionsCount(const Loop * loop,const Inst * unswitchInst,int64_t * loopSize,int64_t * trueCount,int64_t * falseCount)251 void LoopUnswitcher::EstimateInstructionsCount(const Loop *loop, const Inst *unswitchInst, int64_t *loopSize,
252 int64_t *trueCount, int64_t *falseCount)
253 {
254 ASSERT(loop->GetBackEdges().size() == 1);
255 ASSERT(loop->GetInnerLoops().empty());
256 *loopSize = static_cast<int64_t>(CountLoopInstructions(loop));
257 auto backEdge = loop->GetBackEdges()[0];
258 auto graph = backEdge->GetGraph();
259
260 auto trueMarker = graph->NewMarker();
261 *trueCount = static_cast<int64_t>(
262 EstimateUnswitchInstructionsCount(loop->GetHeader(), backEdge, unswitchInst, true, trueMarker));
263 graph->EraseMarker(trueMarker);
264
265 auto falseMarker = graph->NewMarker();
266 *falseCount = static_cast<int64_t>(
267 EstimateUnswitchInstructionsCount(loop->GetHeader(), backEdge, unswitchInst, false, falseMarker));
268 graph->EraseMarker(falseMarker);
269 }
270 } // namespace ark::compiler
271