• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "optimizer/analysis/loop_analyzer.h"
17 #include "optimizer/ir/analysis.h"
18 #include "compiler_logger.h"
19 #include "loop_unswitcher.h"
20 
21 namespace ark::compiler {
LoopUnswitcher(Graph * graph,ArenaAllocator * allocator,ArenaAllocator * localAllocator)22 LoopUnswitcher::LoopUnswitcher(Graph *graph, ArenaAllocator *allocator, ArenaAllocator *localAllocator)
23     : GraphCloner(graph, allocator, localAllocator), conditions_(allocator->Adapter())
24 {
25 }
26 
27 /**
28  * Unswitch loop in selected branch instruction.
29  * Return pointer to new loop or nullptr if cannot unswitch loop.
30  */
UnswitchLoop(Loop * loop,Inst * inst)31 Loop *LoopUnswitcher::UnswitchLoop(Loop *loop, Inst *inst)
32 {
33     ASSERT(loop != nullptr && !loop->IsRoot());
34     ASSERT_PRINT(IsLoopSingleBackEdgeExitPoint(loop), "Cloning blocks doesn't have single entry/exit point");
35     ASSERT(loop->GetPreHeader() != nullptr);
36     ASSERT(!loop->IsIrreducible());
37     ASSERT(!loop->IsOsrLoop());
38     if (loop->GetPreHeader()->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
39         return nullptr;
40     }
41     ASSERT(cloneMarker_ == UNDEF_MARKER);
42 
43     auto markerHolder = MarkerHolder(GetGraph());
44     cloneMarker_ = markerHolder.GetMarker();
45     auto unswitchData = PrepareLoopToClone(loop);
46 
47     conditions_.clear();
48     for (auto bb : loop->GetBlocks()) {
49         if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
50             continue;
51         }
52 
53         auto ifImm = bb->GetLastInst();
54         if (IsConditionEqual(ifImm, inst, false) || IsConditionEqual(ifImm, inst, true)) {
55             // will replace all equal or oposite conditions
56             conditions_.push_back(ifImm);
57         }
58     }
59 
60     CloneBlocksAndInstructions<InstCloneType::CLONE_ALL, false>(*unswitchData->blocks, GetGraph());
61     BuildLoopUnswitchControlFlow(unswitchData);
62     BuildLoopUnswitchDataFlow(unswitchData, inst);
63     MakeLoopCloneInfo(unswitchData);
64     GetGraph()->RunPass<DominatorsTree>();
65 
66     auto cloneLoop = GetClone(loop->GetHeader())->GetLoop();
67     ASSERT(cloneLoop != loop && cloneLoop->GetOuterLoop() == loop->GetOuterLoop());
68     COMPILER_LOG(DEBUG, GRAPH_CLONER) << "Loop " << loop->GetId() << " is copied";
69     COMPILER_LOG(DEBUG, GRAPH_CLONER) << "Created new loop, id = " << cloneLoop->GetId();
70     return cloneLoop;
71 }
72 
BuildLoopUnswitchControlFlow(LoopClonerData * unswitchData)73 void LoopUnswitcher::BuildLoopUnswitchControlFlow(LoopClonerData *unswitchData)
74 {
75     ASSERT(unswitchData != nullptr);
76     auto outerClone = GetClone(unswitchData->outer);
77     auto preHeaderClone = GetClone(unswitchData->preHeader);
78 
79     auto commonOuter = GetGraph()->CreateEmptyBlock();
80 
81     while (!unswitchData->outer->GetSuccsBlocks().empty()) {
82         auto succ = unswitchData->outer->GetSuccsBlocks().front();
83         succ->ReplacePred(unswitchData->outer, commonOuter);
84         unswitchData->outer->RemoveSucc(succ);
85     }
86     unswitchData->outer->AddSucc(commonOuter);
87     outerClone->AddSucc(commonOuter);
88 
89     auto commonPredecessor = GetGraph()->CreateEmptyBlock();
90     while (!unswitchData->preHeader->GetPredsBlocks().empty()) {
91         auto pred = unswitchData->preHeader->GetPredsBlocks().front();
92         pred->ReplaceSucc(unswitchData->preHeader, commonPredecessor);
93         unswitchData->preHeader->RemovePred(pred);
94     }
95     commonPredecessor->AddSucc(unswitchData->preHeader);
96     commonPredecessor->AddSucc(preHeaderClone);
97 
98     for (auto &block : *unswitchData->blocks) {
99         if (block != unswitchData->preHeader) {
100             CloneEdges<CloneEdgeType::EDGE_PRED>(block);
101         }
102         if (block != unswitchData->outer) {
103             CloneEdges<CloneEdgeType::EDGE_SUCC>(block);
104         }
105     }
106     ASSERT(unswitchData->outer->GetPredBlockIndex(unswitchData->preHeader) ==
107            outerClone->GetPredBlockIndex(preHeaderClone));
108     ASSERT(unswitchData->header->GetPredBlockIndex(unswitchData->preHeader) ==
109            GetClone(unswitchData->header)->GetPredBlockIndex(preHeaderClone));
110 }
111 
BuildLoopUnswitchDataFlow(LoopClonerData * unswitchData,Inst * ifInst)112 void LoopUnswitcher::BuildLoopUnswitchDataFlow(LoopClonerData *unswitchData, Inst *ifInst)
113 {
114     ASSERT(unswitchData != nullptr);
115     ProcessMarkedInsts(unswitchData);
116 
117     auto commonOuter = unswitchData->outer->GetSuccessor(0);
118     for (auto phi : unswitchData->outer->PhiInsts()) {
119         auto phiClone = GetClone(phi);
120         auto phiJoin = commonOuter->GetGraph()->CreateInstPhi(phi->GetType(), phi->GetPc());
121         phi->ReplaceUsers(phiJoin);
122         phiJoin->AppendInput(phi);
123         phiJoin->AppendInput(phiClone);
124         commonOuter->AppendPhi(phiJoin);
125     }
126 
127     auto commonPredecessor = unswitchData->preHeader->GetPredecessor(0);
128     auto ifInstUnswitch = ifInst->Clone(commonPredecessor->GetGraph());
129     for (size_t i = 0; i < ifInst->GetInputsCount(); i++) {
130         auto input = ifInst->GetInput(i);
131         ifInstUnswitch->SetInput(i, input.GetInst());
132     }
133     commonPredecessor->AppendInst(ifInstUnswitch);
134 
135     ReplaceWithConstantCondition(ifInstUnswitch);
136 }
137 
ReplaceWithConstantCondition(Inst * ifInst)138 void LoopUnswitcher::ReplaceWithConstantCondition(Inst *ifInst)
139 {
140     auto graph = ifInst->GetBasicBlock()->GetGraph();
141     auto i1 = graph->FindOrCreateConstant(1);
142     auto i2 = graph->FindOrCreateConstant(0);
143 
144     auto ifImm = ifInst->CastToIfImm();
145     ASSERT(ifImm->GetCc() == ConditionCode::CC_NE || ifImm->GetCc() == ConditionCode::CC_EQ);
146     if ((ifImm->GetImm() == 0) != (ifImm->GetCc() == ConditionCode::CC_NE)) {
147         std::swap(i1, i2);
148     }
149 
150     for (auto inst : conditions_) {
151         if (IsConditionEqual(inst, ifInst, true)) {
152             inst->SetInput(0, i2);
153             GetClone(inst)->SetInput(0, i1);
154         } else {
155             inst->SetInput(0, i1);
156             GetClone(inst)->SetInput(0, i2);
157         }
158     }
159 }
160 
AllInputsConst(Inst * inst)161 static bool AllInputsConst(Inst *inst)
162 {
163     for (auto input : inst->GetInputs()) {
164         if (!input.GetInst()->IsConst()) {
165             return false;
166         }
167     }
168     return true;
169 }
170 
IsHoistable(Inst * inst,Loop * loop)171 static bool IsHoistable(Inst *inst, Loop *loop)
172 {
173     for (auto input : inst->GetInputs()) {
174         if (!input.GetInst()->GetBasicBlock()->IsDominate(loop->GetPreHeader())) {
175             return false;
176         }
177     }
178     return true;
179 }
180 
FindUnswitchInst(Loop * loop)181 Inst *LoopUnswitcher::FindUnswitchInst(Loop *loop)
182 {
183     for (auto bb : loop->GetBlocks()) {
184         if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
185             continue;
186         }
187         auto ifInst = bb->GetLastInst();
188         if (AllInputsConst(ifInst)) {
189             continue;
190         }
191         if (IsHoistable(ifInst, loop)) {
192             return ifInst;
193         }
194     }
195     return nullptr;
196 }
197 
IsSmallLoop(Loop * loop)198 bool LoopUnswitcher::IsSmallLoop(Loop *loop)
199 {
200     auto loopParser = CountableLoopParser(*loop);
201     auto loopInfo = loopParser.Parse();
202     if (!loopInfo.has_value()) {
203         return false;
204     }
205     auto iterations = CountableLoopParser::GetLoopIterations(*loopInfo);
206     if (!iterations.has_value()) {
207         return false;
208     }
209     return *iterations <= 1;
210 }
211 
CountLoopInstructions(const Loop * loop)212 static uint32_t CountLoopInstructions(const Loop *loop)
213 {
214     uint32_t count = 0;
215     for (auto block : loop->GetBlocks()) {
216         count += block->CountInsts();
217     }
218     return count;
219 }
220 
EstimateUnswitchInstructionsCount(BasicBlock * bb,const BasicBlock * backEdge,const Inst * unswitchInst,bool trueCond,Marker marker)221 static uint32_t EstimateUnswitchInstructionsCount(BasicBlock *bb, const BasicBlock *backEdge, const Inst *unswitchInst,
222                                                   bool trueCond, Marker marker)
223 {
224     if (bb->IsMarked(marker)) {
225         return 0;
226     }
227     bb->SetMarker(marker);
228 
229     uint32_t count = bb->CountInsts();
230     if (bb == backEdge) {
231         return count;
232     }
233 
234     if (bb->GetSuccsBlocks().size() != MAX_SUCCS_NUM) {
235         count += EstimateUnswitchInstructionsCount(bb->GetSuccsBlocks()[0], backEdge, unswitchInst, trueCond, marker);
236     } else if (IsConditionEqual(unswitchInst, bb->GetLastInst(), false)) {
237         auto succ = trueCond ? bb->GetTrueSuccessor() : bb->GetFalseSuccessor();
238         count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
239     } else if (IsConditionEqual(unswitchInst, bb->GetLastInst(), true)) {
240         auto succ = trueCond ? bb->GetFalseSuccessor() : bb->GetTrueSuccessor();
241         count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
242     } else {
243         for (auto succ : bb->GetSuccsBlocks()) {
244             count += EstimateUnswitchInstructionsCount(succ, backEdge, unswitchInst, trueCond, marker);
245         }
246     }
247     return count;
248 }
249 
EstimateInstructionsCount(const Loop * loop,const Inst * unswitchInst,int64_t * loopSize,int64_t * trueCount,int64_t * falseCount)250 void LoopUnswitcher::EstimateInstructionsCount(const Loop *loop, const Inst *unswitchInst, int64_t *loopSize,
251                                                int64_t *trueCount, int64_t *falseCount)
252 {
253     ASSERT(loop->GetBackEdges().size() == 1);
254     ASSERT(loop->GetInnerLoops().empty());
255     *loopSize = static_cast<int64_t>(CountLoopInstructions(loop));
256     auto backEdge = loop->GetBackEdges()[0];
257     auto graph = backEdge->GetGraph();
258 
259     auto trueMarker = graph->NewMarker();
260     *trueCount = static_cast<int64_t>(
261         EstimateUnswitchInstructionsCount(loop->GetHeader(), backEdge, unswitchInst, true, trueMarker));
262     graph->EraseMarker(trueMarker);
263 
264     auto falseMarker = graph->NewMarker();
265     *falseCount = static_cast<int64_t>(
266         EstimateUnswitchInstructionsCount(loop->GetHeader(), backEdge, unswitchInst, false, falseMarker));
267     graph->EraseMarker(falseMarker);
268 }
269 }  // namespace ark::compiler
270