1 /*
2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "optimizer/analysis/loop_analyzer.h"
17 #include "optimizer/ir/analysis.h"
18 #include "compiler_logger.h"
19 #include "loop_unswitcher.h"
20
21 namespace ark::compiler {
LoopUnswitcher(Graph * graph,ArenaAllocator * allocator,ArenaAllocator * localAllocator)22 LoopUnswitcher::LoopUnswitcher(Graph *graph, ArenaAllocator *allocator, ArenaAllocator *localAllocator)
23 : GraphCloner(graph, allocator, localAllocator), conditions_(allocator->Adapter())
24 {
25 }
26
27 /**
28 * Unswitch loop in selected branch instruction.
29 * Return pointer to new loop or nullptr if cannot unswitch loop.
30 */
UnswitchLoop(Loop * loop,Inst * inst)31 Loop *LoopUnswitcher::UnswitchLoop(Loop *loop, Inst *inst)
32 {
33 ASSERT(loop != nullptr && !loop->IsRoot());
34 ASSERT_PRINT(IsLoopSingleBackEdgeExitPoint(loop), "Cloning blocks doesn't have single entry/exit point");
35 ASSERT(loop->GetPreHeader() != nullptr);
36 ASSERT(!loop->IsIrreducible());
37 ASSERT(!loop->IsOsrLoop());
38 if (loop->GetPreHeader()->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
39 return nullptr;
40 }
41 ASSERT(cloneMarker_ == UNDEF_MARKER);
42
43 auto markerHolder = MarkerHolder(GetGraph());
44 cloneMarker_ = markerHolder.GetMarker();
45 auto unswitchData = PrepareLoopToUnswitch(loop);
46
47 conditions_.clear();
48 for (auto bb : loop->GetBlocks()) {
49 if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
50 continue;
51 }
52
53 auto ifImm = bb->GetLastInst();
54 if (IsConditionEqual(ifImm, inst, false) || IsConditionEqual(ifImm, inst, true)) {
55 // will replace all equal or oposite conditions
56 conditions_.push_back(ifImm);
57 }
58 }
59
60 CloneBlocksAndInstructions<InstCloneType::CLONE_ALL, false>(*unswitchData->blocks, GetGraph());
61 BuildLoopUnswitchControlFlow(unswitchData);
62 BuildLoopUnswitchDataFlow(unswitchData, inst);
63 MakeLoopCloneInfo(unswitchData);
64 GetGraph()->RunPass<DominatorsTree>();
65
66 auto cloneLoop = GetClone(loop->GetHeader())->GetLoop();
67 ASSERT(cloneLoop != loop && cloneLoop->GetOuterLoop() == loop->GetOuterLoop());
68 COMPILER_LOG(DEBUG, GRAPH_CLONER) << "Loop " << loop->GetId() << " is copied";
69 COMPILER_LOG(DEBUG, GRAPH_CLONER) << "Created new loop, id = " << cloneLoop->GetId();
70 return cloneLoop;
71 }
72
PrepareLoopToUnswitch(Loop * loop)73 GraphCloner::LoopClonerData *LoopUnswitcher::PrepareLoopToUnswitch(Loop *loop)
74 {
75 auto preHeader = loop->GetPreHeader();
76 ASSERT(preHeader->GetSuccsBlocks().size() == MAX_SUCCS_NUM);
77 // If `outside_succ` has more than 2 predecessors, create a new one
78 // with loop header and back-edge predecessors only and insert it before `outside_succ`
79 auto outsideSucc = GetLoopOutsideSuccessor(loop);
80 constexpr auto PREDS_NUM = 2;
81 if (outsideSucc->GetPredsBlocks().size() > PREDS_NUM) {
82 auto backEdge = loop->GetBackEdges()[0];
83 outsideSucc = CreateNewOutsideSucc(outsideSucc, backEdge, preHeader);
84 }
85 // Split outside succ after last phi
86 // create empty block before outside succ if outside succ don't contain phi insts
87 if (outsideSucc->HasPhi() && outsideSucc->GetFirstInst() != nullptr) {
88 auto lastPhi = outsideSucc->GetFirstInst()->GetPrev();
89 auto block = outsideSucc->SplitBlockAfterInstruction(lastPhi, true);
90 // if `outside_succ` is pre-header replace it by `block`
91 for (auto inLoop : loop->GetOuterLoop()->GetInnerLoops()) {
92 if (inLoop->GetPreHeader() == outsideSucc) {
93 inLoop->SetPreHeader(block);
94 }
95 }
96 } else if (outsideSucc->GetFirstInst() != nullptr) {
97 auto block = outsideSucc->InsertEmptyBlockBefore();
98 outsideSucc->GetLoop()->AppendBlock(block);
99 outsideSucc = block;
100 }
101 // Populate `LoopClonerData`
102 auto allocator = GetGraph()->GetLocalAllocator();
103 auto unswitchData = allocator->New<LoopClonerData>();
104 unswitchData->blocks = allocator->New<ArenaVector<BasicBlock *>>(allocator->Adapter());
105 unswitchData->blocks->resize(loop->GetBlocks().size() + 1);
106 unswitchData->blocks->at(0) = preHeader;
107 std::copy(loop->GetBlocks().begin(), loop->GetBlocks().end(), unswitchData->blocks->begin() + 1);
108 unswitchData->blocks->push_back(outsideSucc);
109 unswitchData->outer = outsideSucc;
110 unswitchData->header = loop->GetHeader();
111 unswitchData->preHeader = loop->GetPreHeader();
112 return unswitchData;
113 }
114
BuildLoopUnswitchControlFlow(LoopClonerData * unswitchData)115 void LoopUnswitcher::BuildLoopUnswitchControlFlow(LoopClonerData *unswitchData)
116 {
117 ASSERT(unswitchData != nullptr);
118 auto outerClone = GetClone(unswitchData->outer);
119 auto preHeaderClone = GetClone(unswitchData->preHeader);
120
121 auto commonOuter = GetGraph()->CreateEmptyBlock();
122
123 while (!unswitchData->outer->GetSuccsBlocks().empty()) {
124 auto succ = unswitchData->outer->GetSuccsBlocks().front();
125 succ->ReplacePred(unswitchData->outer, commonOuter);
126 unswitchData->outer->RemoveSucc(succ);
127 }
128 unswitchData->outer->AddSucc(commonOuter);
129 outerClone->AddSucc(commonOuter);
130
131 auto commonPredecessor = GetGraph()->CreateEmptyBlock();
132 while (!unswitchData->preHeader->GetPredsBlocks().empty()) {
133 auto pred = unswitchData->preHeader->GetPredsBlocks().front();
134 pred->ReplaceSucc(unswitchData->preHeader, commonPredecessor);
135 unswitchData->preHeader->RemovePred(pred);
136 }
137 commonPredecessor->AddSucc(unswitchData->preHeader);
138 commonPredecessor->AddSucc(preHeaderClone);
139
140 for (auto &block : *unswitchData->blocks) {
141 if (block != unswitchData->preHeader) {
142 CloneEdges<CloneEdgeType::EDGE_PRED>(block);
143 }
144 if (block != unswitchData->outer) {
145 CloneEdges<CloneEdgeType::EDGE_SUCC>(block);
146 }
147 }
148 ASSERT(unswitchData->outer->GetPredBlockIndex(unswitchData->preHeader) ==
149 outerClone->GetPredBlockIndex(preHeaderClone));
150 ASSERT(unswitchData->header->GetPredBlockIndex(unswitchData->preHeader) ==
151 GetClone(unswitchData->header)->GetPredBlockIndex(preHeaderClone));
152 }
153
BuildLoopUnswitchDataFlow(LoopClonerData * unswitchData,Inst * ifInst)154 void LoopUnswitcher::BuildLoopUnswitchDataFlow(LoopClonerData *unswitchData, Inst *ifInst)
155 {
156 ASSERT(unswitchData != nullptr);
157 for (const auto &block : *unswitchData->blocks) {
158 for (const auto &inst : block->AllInsts()) {
159 if (inst->GetOpcode() == Opcode::NOP) {
160 continue;
161 }
162 if (inst->IsMarked(cloneMarker_)) {
163 SetCloneInputs<false>(inst);
164 UpdateCaller(inst);
165 }
166 }
167 }
168
169 auto commonOuter = unswitchData->outer->GetSuccessor(0);
170 for (auto phi : unswitchData->outer->PhiInsts()) {
171 auto phiClone = GetClone(phi);
172 auto phiJoin = commonOuter->GetGraph()->CreateInstPhi(phi->GetType(), phi->GetPc());
173 phi->ReplaceUsers(phiJoin);
174 phiJoin->AppendInput(phi);
175 phiJoin->AppendInput(phiClone);
176 commonOuter->AppendPhi(phiJoin);
177 }
178
179 auto commonPredecessor = unswitchData->preHeader->GetPredecessor(0);
180 auto ifInstUnswitch = ifInst->Clone(commonPredecessor->GetGraph());
181 for (size_t i = 0; i < ifInst->GetInputsCount(); i++) {
182 auto input = ifInst->GetInput(i);
183 ifInstUnswitch->SetInput(i, input.GetInst());
184 }
185 commonPredecessor->AppendInst(ifInstUnswitch);
186
187 ReplaceWithConstantCondition(ifInstUnswitch);
188 }
189
ReplaceWithConstantCondition(Inst * ifInst)190 void LoopUnswitcher::ReplaceWithConstantCondition(Inst *ifInst)
191 {
192 auto graph = ifInst->GetBasicBlock()->GetGraph();
193 auto i1 = graph->FindOrCreateConstant(1);
194 auto i2 = graph->FindOrCreateConstant(0);
195
196 auto ifImm = ifInst->CastToIfImm();
197 ASSERT(ifImm->GetCc() == ConditionCode::CC_NE || ifImm->GetCc() == ConditionCode::CC_EQ);
198 if ((ifImm->GetImm() == 0) != (ifImm->GetCc() == ConditionCode::CC_NE)) {
199 std::swap(i1, i2);
200 }
201
202 for (auto inst : conditions_) {
203 if (IsConditionEqual(inst, ifInst, true)) {
204 inst->SetInput(0, i2);
205 GetClone(inst)->SetInput(0, i1);
206 } else {
207 inst->SetInput(0, i1);
208 GetClone(inst)->SetInput(0, i2);
209 }
210 }
211 }
212
AllInputsConst(Inst * inst)213 static bool AllInputsConst(Inst *inst)
214 {
215 for (auto input : inst->GetInputs()) {
216 if (!input.GetInst()->IsConst()) {
217 return false;
218 }
219 }
220 return true;
221 }
222
IsHoistable(Inst * inst,Loop * loop)223 static bool IsHoistable(Inst *inst, Loop *loop)
224 {
225 for (auto input : inst->GetInputs()) {
226 if (!input.GetInst()->GetBasicBlock()->IsDominate(loop->GetPreHeader())) {
227 return false;
228 }
229 }
230 return true;
231 }
232
FindUnswitchInst(Loop * loop)233 Inst *LoopUnswitcher::FindUnswitchInst(Loop *loop)
234 {
235 for (auto bb : loop->GetBlocks()) {
236 if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
237 continue;
238 }
239 auto ifInst = bb->GetLastInst();
240 if (AllInputsConst(ifInst)) {
241 continue;
242 }
243 if (IsHoistable(ifInst, loop)) {
244 return ifInst;
245 }
246 }
247 return nullptr;
248 }
249
IsSmallLoop(Loop * loop)250 bool LoopUnswitcher::IsSmallLoop(Loop *loop)
251 {
252 auto loopParser = CountableLoopParser(*loop);
253 auto loopInfo = loopParser.Parse();
254 if (!loopInfo.has_value()) {
255 return false;
256 }
257 auto iterations = CountableLoopParser::GetLoopIterations(*loopInfo);
258 if (!iterations.has_value()) {
259 return false;
260 }
261 return *iterations <= 1;
262 }
263
CountLoopInstructions(const Loop * loop)264 static uint32_t CountLoopInstructions(const Loop *loop)
265 {
266 uint32_t count = 0;
267 for (auto block : loop->GetBlocks()) {
268 count += block->CountInsts();
269 }
270 return count;
271 }
272
EstimateUnswitchInstructionsCount(BasicBlock * bb,const BasicBlock * backEdge,const Inst * unswitchInst,bool trueCond,Marker marker)273 static uint32_t EstimateUnswitchInstructionsCount(BasicBlock *bb, const BasicBlock *backEdge, const Inst *unswitchInst,
274 bool trueCond, Marker marker)
275 {
276 if (bb->IsMarked(marker)) {
277 return 0;
278 }
279 bb->SetMarker(marker);
280
281 uint32_t count = bb->CountInsts();
282 if (bb == backEdge) {
283 return count;
284 }
285
286 if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
287 count += EstimateUnswitchInstructionsCount(bb->GetSuccsBlocks()[0], backEdge, unswitchInst, trueCond, marker);
288 } else if (IsConditionEqual(unswitchInst, bb->GetLastInst(), false)) {
289 auto succ = trueCond ? bb->GetTrueSuccessor() : bb->GetFalseSuccessor();
290 count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
291 } else if (IsConditionEqual(unswitchInst, bb->GetLastInst(), true)) {
292 auto succ = trueCond ? bb->GetFalseSuccessor() : bb->GetTrueSuccessor();
293 count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
294 } else {
295 for (auto succ : bb->GetSuccsBlocks()) {
296 count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
297 }
298 }
299 return count;
300 }
301
EstimateInstructionsCount(const Loop * loop,const Inst * unswitchInst,int64_t * loopSize,int64_t * trueCount,int64_t * falseCount)302 void LoopUnswitcher::EstimateInstructionsCount(const Loop *loop, const Inst *unswitchInst, int64_t *loopSize,
303 int64_t *trueCount, int64_t *falseCount)
304 {
305 ASSERT(loop->GetBackEdges().size() == 1);
306 ASSERT(loop->GetInnerLoops().empty());
307 *loopSize = static_cast<int64_t>(CountLoopInstructions(loop));
308 auto backEdge = loop->GetBackEdges()[0];
309 auto graph = backEdge->GetGraph();
310
311 auto trueMarker = graph->NewMarker();
312 *trueCount = static_cast<int64_t>(
313 EstimateUnswitchInstructionsCount(loop->GetHeader(), backEdge, unswitchInst, true, trueMarker));
314 graph->EraseMarker(trueMarker);
315
316 auto falseMarker = graph->NewMarker();
317 *falseCount = static_cast<int64_t>(
318 EstimateUnswitchInstructionsCount(loop->GetHeader(), backEdge, unswitchInst, false, falseMarker));
319 graph->EraseMarker(falseMarker);
320 }
321 } // namespace ark::compiler
322