1 /*
2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "optimizer/analysis/loop_analyzer.h"
17 #include "optimizer/ir/analysis.h"
18 #include "compiler_logger.h"
19 #include "loop_unswitcher.h"
20
21 namespace ark::compiler {
LoopUnswitcher(Graph * graph,ArenaAllocator * allocator,ArenaAllocator * localAllocator)22 LoopUnswitcher::LoopUnswitcher(Graph *graph, ArenaAllocator *allocator, ArenaAllocator *localAllocator)
23 : GraphCloner(graph, allocator, localAllocator), conditions_(allocator->Adapter())
24 {
25 }
26
27 /**
28 * Unswitch loop in selected branch instruction.
29 * Return pointer to new loop or nullptr if cannot unswitch loop.
30 */
UnswitchLoop(Loop * loop,Inst * inst)31 Loop *LoopUnswitcher::UnswitchLoop(Loop *loop, Inst *inst)
32 {
33 ASSERT(loop != nullptr && !loop->IsRoot());
34 ASSERT_PRINT(IsLoopSingleBackEdgeExitPoint(loop), "Cloning blocks doesn't have single entry/exit point");
35 ASSERT(loop->GetPreHeader() != nullptr);
36 ASSERT(!loop->IsIrreducible());
37 ASSERT(!loop->IsOsrLoop());
38 if (loop->GetPreHeader()->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
39 return nullptr;
40 }
41 ASSERT(cloneMarker_ == UNDEF_MARKER);
42
43 auto markerHolder = MarkerHolder(GetGraph());
44 cloneMarker_ = markerHolder.GetMarker();
45 auto unswitchData = PrepareLoopToClone(loop);
46
47 conditions_.clear();
48 for (auto bb : loop->GetBlocks()) {
49 if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
50 continue;
51 }
52
53 auto ifImm = bb->GetLastInst();
54 if (IsConditionEqual(ifImm, inst, false) || IsConditionEqual(ifImm, inst, true)) {
55 // will replace all equal or oposite conditions
56 conditions_.push_back(ifImm);
57 }
58 }
59
60 CloneBlocksAndInstructions<InstCloneType::CLONE_ALL, false>(*unswitchData->blocks, GetGraph());
61 BuildLoopUnswitchControlFlow(unswitchData);
62 BuildLoopUnswitchDataFlow(unswitchData, inst);
63 MakeLoopCloneInfo(unswitchData);
64 GetGraph()->RunPass<DominatorsTree>();
65
66 auto cloneLoop = GetClone(loop->GetHeader())->GetLoop();
67 ASSERT(cloneLoop != loop && cloneLoop->GetOuterLoop() == loop->GetOuterLoop());
68 COMPILER_LOG(DEBUG, GRAPH_CLONER) << "Loop " << loop->GetId() << " is copied";
69 COMPILER_LOG(DEBUG, GRAPH_CLONER) << "Created new loop, id = " << cloneLoop->GetId();
70 return cloneLoop;
71 }
72
BuildLoopUnswitchControlFlow(LoopClonerData * unswitchData)73 void LoopUnswitcher::BuildLoopUnswitchControlFlow(LoopClonerData *unswitchData)
74 {
75 ASSERT(unswitchData != nullptr);
76 auto outerClone = GetClone(unswitchData->outer);
77 auto preHeaderClone = GetClone(unswitchData->preHeader);
78
79 auto commonOuter = GetGraph()->CreateEmptyBlock();
80
81 while (!unswitchData->outer->GetSuccsBlocks().empty()) {
82 auto succ = unswitchData->outer->GetSuccsBlocks().front();
83 succ->ReplacePred(unswitchData->outer, commonOuter);
84 unswitchData->outer->RemoveSucc(succ);
85 }
86 unswitchData->outer->AddSucc(commonOuter);
87 outerClone->AddSucc(commonOuter);
88
89 auto commonPredecessor = GetGraph()->CreateEmptyBlock();
90 while (!unswitchData->preHeader->GetPredsBlocks().empty()) {
91 auto pred = unswitchData->preHeader->GetPredsBlocks().front();
92 pred->ReplaceSucc(unswitchData->preHeader, commonPredecessor);
93 unswitchData->preHeader->RemovePred(pred);
94 }
95 commonPredecessor->AddSucc(unswitchData->preHeader);
96 commonPredecessor->AddSucc(preHeaderClone);
97
98 for (auto &block : *unswitchData->blocks) {
99 if (block != unswitchData->preHeader) {
100 CloneEdges<CloneEdgeType::EDGE_PRED>(block);
101 }
102 if (block != unswitchData->outer) {
103 CloneEdges<CloneEdgeType::EDGE_SUCC>(block);
104 }
105 }
106 ASSERT(unswitchData->outer->GetPredBlockIndex(unswitchData->preHeader) ==
107 outerClone->GetPredBlockIndex(preHeaderClone));
108 ASSERT(unswitchData->header->GetPredBlockIndex(unswitchData->preHeader) ==
109 GetClone(unswitchData->header)->GetPredBlockIndex(preHeaderClone));
110 }
111
BuildLoopUnswitchDataFlow(LoopClonerData * unswitchData,Inst * ifInst)112 void LoopUnswitcher::BuildLoopUnswitchDataFlow(LoopClonerData *unswitchData, Inst *ifInst)
113 {
114 ASSERT(unswitchData != nullptr);
115 ProcessMarkedInsts(unswitchData);
116
117 auto commonOuter = unswitchData->outer->GetSuccessor(0);
118 for (auto phi : unswitchData->outer->PhiInsts()) {
119 auto phiClone = GetClone(phi);
120 auto phiJoin = commonOuter->GetGraph()->CreateInstPhi(phi->GetType(), phi->GetPc());
121 phi->ReplaceUsers(phiJoin);
122 phiJoin->AppendInput(phi);
123 phiJoin->AppendInput(phiClone);
124 commonOuter->AppendPhi(phiJoin);
125 }
126
127 auto commonPredecessor = unswitchData->preHeader->GetPredecessor(0);
128 auto ifInstUnswitch = ifInst->Clone(commonPredecessor->GetGraph());
129 for (size_t i = 0; i < ifInst->GetInputsCount(); i++) {
130 auto input = ifInst->GetInput(i);
131 ifInstUnswitch->SetInput(i, input.GetInst());
132 }
133 commonPredecessor->AppendInst(ifInstUnswitch);
134
135 ReplaceWithConstantCondition(ifInstUnswitch);
136 }
137
ReplaceWithConstantCondition(Inst * ifInst)138 void LoopUnswitcher::ReplaceWithConstantCondition(Inst *ifInst)
139 {
140 auto graph = ifInst->GetBasicBlock()->GetGraph();
141 auto i1 = graph->FindOrCreateConstant(1);
142 auto i2 = graph->FindOrCreateConstant(0);
143
144 auto ifImm = ifInst->CastToIfImm();
145 ASSERT(ifImm->GetCc() == ConditionCode::CC_NE || ifImm->GetCc() == ConditionCode::CC_EQ);
146 if ((ifImm->GetImm() == 0) != (ifImm->GetCc() == ConditionCode::CC_NE)) {
147 std::swap(i1, i2);
148 }
149
150 for (auto inst : conditions_) {
151 if (IsConditionEqual(inst, ifInst, true)) {
152 inst->SetInput(0, i2);
153 GetClone(inst)->SetInput(0, i1);
154 } else {
155 inst->SetInput(0, i1);
156 GetClone(inst)->SetInput(0, i2);
157 }
158 }
159 }
160
AllInputsConst(Inst * inst)161 static bool AllInputsConst(Inst *inst)
162 {
163 for (auto input : inst->GetInputs()) {
164 if (!input.GetInst()->IsConst()) {
165 return false;
166 }
167 }
168 return true;
169 }
170
IsHoistable(Inst * inst,Loop * loop)171 static bool IsHoistable(Inst *inst, Loop *loop)
172 {
173 for (auto input : inst->GetInputs()) {
174 if (!input.GetInst()->GetBasicBlock()->IsDominate(loop->GetPreHeader())) {
175 return false;
176 }
177 }
178 return true;
179 }
180
FindUnswitchInst(Loop * loop)181 Inst *LoopUnswitcher::FindUnswitchInst(Loop *loop)
182 {
183 for (auto bb : loop->GetBlocks()) {
184 if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
185 continue;
186 }
187 auto ifInst = bb->GetLastInst();
188 if (AllInputsConst(ifInst)) {
189 continue;
190 }
191 if (IsHoistable(ifInst, loop)) {
192 return ifInst;
193 }
194 }
195 return nullptr;
196 }
197
IsSmallLoop(Loop * loop)198 bool LoopUnswitcher::IsSmallLoop(Loop *loop)
199 {
200 auto loopParser = CountableLoopParser(*loop);
201 auto loopInfo = loopParser.Parse();
202 if (!loopInfo.has_value()) {
203 return false;
204 }
205 auto iterations = CountableLoopParser::GetLoopIterations(*loopInfo);
206 if (!iterations.has_value()) {
207 return false;
208 }
209 return *iterations <= 1;
210 }
211
CountLoopInstructions(const Loop * loop)212 static uint32_t CountLoopInstructions(const Loop *loop)
213 {
214 uint32_t count = 0;
215 for (auto block : loop->GetBlocks()) {
216 count += block->CountInsts();
217 }
218 return count;
219 }
220
EstimateUnswitchInstructionsCount(BasicBlock * bb,const BasicBlock * backEdge,const Inst * unswitchInst,bool trueCond,Marker marker)221 static uint32_t EstimateUnswitchInstructionsCount(BasicBlock *bb, const BasicBlock *backEdge, const Inst *unswitchInst,
222 bool trueCond, Marker marker)
223 {
224 if (bb->IsMarked(marker)) {
225 return 0;
226 }
227 bb->SetMarker(marker);
228
229 uint32_t count = bb->CountInsts();
230 if (bb == backEdge) {
231 return count;
232 }
233
234 if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
235 count += EstimateUnswitchInstructionsCount(bb->GetSuccsBlocks()[0], backEdge, unswitchInst, trueCond, marker);
236 } else if (IsConditionEqual(unswitchInst, bb->GetLastInst(), false)) {
237 auto succ = trueCond ? bb->GetTrueSuccessor() : bb->GetFalseSuccessor();
238 count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
239 } else if (IsConditionEqual(unswitchInst, bb->GetLastInst(), true)) {
240 auto succ = trueCond ? bb->GetFalseSuccessor() : bb->GetTrueSuccessor();
241 count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
242 } else {
243 for (auto succ : bb->GetSuccsBlocks()) {
244 count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
245 }
246 }
247 return count;
248 }
249
EstimateInstructionsCount(const Loop * loop,const Inst * unswitchInst,int64_t * loopSize,int64_t * trueCount,int64_t * falseCount)250 void LoopUnswitcher::EstimateInstructionsCount(const Loop *loop, const Inst *unswitchInst, int64_t *loopSize,
251 int64_t *trueCount, int64_t *falseCount)
252 {
253 ASSERT(loop->GetBackEdges().size() == 1);
254 ASSERT(loop->GetInnerLoops().empty());
255 *loopSize = static_cast<int64_t>(CountLoopInstructions(loop));
256 auto backEdge = loop->GetBackEdges()[0];
257 auto graph = backEdge->GetGraph();
258
259 auto trueMarker = graph->NewMarker();
260 *trueCount = static_cast<int64_t>(
261 EstimateUnswitchInstructionsCount(loop->GetHeader(), backEdge, unswitchInst, true, trueMarker));
262 graph->EraseMarker(trueMarker);
263
264 auto falseMarker = graph->NewMarker();
265 *falseCount = static_cast<int64_t>(
266 EstimateUnswitchInstructionsCount(loop->GetHeader(), backEdge, unswitchInst, false, falseMarker));
267 graph->EraseMarker(falseMarker);
268 }
269 } // namespace ark::compiler
270