• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2025 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "optimizer/analysis/bounds_analysis.h"
17 #include "optimizer/analysis/dominators_tree.h"
18 #include "optimizer/optimizations/balance_expressions.h"
19 #include "compiler_logger.h"
20 
21 namespace ark::compiler {
RunImpl()22 bool BalanceExpressions::RunImpl()
23 {
24     processedInstMrk_ = GetGraph()->NewMarker();
25     for (auto bb : GetGraph()->GetBlocksRPO()) {
26         ProcessBB(bb);
27     }
28     GetGraph()->EraseMarker(processedInstMrk_);
29     return isApplied_;
30 }
31 
InvalidateAnalyses()32 void BalanceExpressions::InvalidateAnalyses()
33 {
34     GetGraph()->InvalidateAnalysis<BoundsAnalysis>();
35     GetGraph()->InvalidateAnalysis<DominatorsTree>();
36 }
37 
38 /**
39  *  Iterate over instructions in reverse order, find every expression-chain by detecting
40  *  the final operator of each chain, analyze the chain and optimize it if necessary.
41  */
ProcessBB(BasicBlock * bb)42 void BalanceExpressions::ProcessBB(BasicBlock *bb)
43 {
44     ASSERT(bb != nullptr);
45     SetBB(bb);
46 
47     auto it = bb->InstsReverse().begin();
48     for (; it != it.end(); ++it) {
49         ASSERT(*it != nullptr);
50         if ((*it)->SetMarker(processedInstMrk_)) {
51             // The instruction is already processed;
52             continue;
53         }
54         if (SuitableInst(*it)) {
55             // The final operator of the chain is found, start analyzing:
56             auto instToContinueCycle = ProccesExpressionChain(*it);
57             it.SetCurrent(instToContinueCycle);
58         }
59     }
60 }
61 
SuitableInst(Inst * inst)62 bool BalanceExpressions::SuitableInst(Inst *inst)
63 {
64     // Floating point operations are not associative:
65     if (inst->IsCommutative() && !IsFloatType(inst->GetType())) {
66         SetOpcode(inst->GetOpcode());
67         return true;
68     }
69     return false;
70 }
71 
ProccesExpressionChain(Inst * lastOperator)72 Inst *BalanceExpressions::ProccesExpressionChain(Inst *lastOperator)
73 {
74     ASSERT(lastOperator != nullptr);
75     AnalyzeInputsRec(lastOperator);
76 
77     COMPILER_LOG(DEBUG, BALANCE_EXPR) << "\nConsidering expression:";
78     COMPILER_LOG(DEBUG, BALANCE_EXPR) << *this;
79 
80     auto instToContinue = NeedsOptimization() ? OptimizeExpression(lastOperator->GetNext()) : lastOperator;
81 
82     COMPILER_LOG(DEBUG, BALANCE_EXPR) << "Expression considered.";
83 
84     Reset();
85     return instToContinue;
86 }
87 
88 /**
89  * Optimizes expression.
90  *
91  * By the end of the algorithm, `operators_.front()` points to the first instruction in expression and
92  * `operators_.back()` points to the last (`operators_.front()` dominates `operators_.back()`).
93  */
OptimizeExpression(Inst * instAfterExpr)94 Inst *BalanceExpressions::OptimizeExpression(Inst *instAfterExpr)
95 {
96     COMPILER_LOG(DEBUG, BALANCE_EXPR) << "Optimizing expression:";
97     AllocateSourcesRec<true>(0, sources_.size() - 1);
98 
99     size_t size = operators_.size();
100     operators_.front()->SetNext(operators_[1]);
101     constexpr auto IMM_2 = 2;
102     operators_.back()->SetPrev(operators_[size - IMM_2]);
103     for (size_t i = 1; i < size - 1; i++) {
104         operators_[i]->SetNext(operators_[i + 1]);
105         operators_[i]->SetPrev(operators_[i - 1]);
106     }
107     if (instAfterExpr == nullptr) {
108         GetBB()->AppendRangeInst(operators_.front(), operators_.back());
109     } else {
110         GetBB()->InsertRangeBefore(operators_.front(), operators_.back(), instAfterExpr);
111     }
112 
113     SetIsApplied(true);
114     COMPILER_LOG(DEBUG, BALANCE_EXPR) << "\nOptimized expression:";
115     COMPILER_LOG(DEBUG, BALANCE_EXPR) << *this;
116 
117     // Need to return pointer to the next instruction in order to correctly continue the cycle:
118     Inst *instToContinueCycle = operators_.front();
119     return instToContinueCycle;
120 }
121 
122 /**
123  * Generates expression for sources_ in range from @param first_idx to @param last_idx by splitting them on
124  * two parts, calling itself for each part and binding them to an instruction from operators_.
125  *
126  * By the end of the algorithm, `operators_` are sorted in execution order
127  * (`operators_.front()` is the first, `operators_.back()` is the last).
128  */
129 template <bool IS_FIRST_CALL>
AllocateSourcesRec(size_t firstIdx,size_t lastIdx)130 Inst *BalanceExpressions::AllocateSourcesRec(size_t firstIdx, size_t lastIdx)
131 {
132     ASSERT(lastIdx > firstIdx);
133     COMPILER_LOG(DEBUG, BALANCE_EXPR) << "Allocating operators for sources_[" << firstIdx << " to " << lastIdx << "]";
134     size_t memSize = firstIdx + GetBitFloor(lastIdx - firstIdx + 1);
135     size_t splitIdx = memSize > 0 ? memSize - 1 : 0;
136 
137     Inst *lhs = GetOperand(firstIdx, splitIdx);
138     Inst *rhs = LIKELY((splitIdx + 1) != lastIdx) ? GetOperand(splitIdx + 1, lastIdx) : sources_[splitIdx + 1];
139     // `(split_idx + 1) == last_idx` means an odd number of `sources_` and we are considering
140     // the last (unpaired) source. This situation may occur only with `rhs`.
141     ASSERT(firstIdx != splitIdx);
142 
143     // Operator allocation:
144     ASSERT(operatorsAllocIdx_ < operators_.size());
145     Inst *allocatedOperator = operators_[operatorsAllocIdx_];
146     operatorsAllocIdx_++;
147 
148     // Operator initialization:
149     // NOLINTNEXTLINE(readability-braces-around-statements)
150     if constexpr (IS_FIRST_CALL) {
151         // The first call allocates and generates at the same time the last operator of expression and
152         // its users should be saved.
153         allocatedOperator->RemoveInputs();
154         allocatedOperator->GetBasicBlock()->EraseInst(allocatedOperator);
155     } else {  // NOLINT(readability-misleading-indentation)
156         allocatedOperator->GetBasicBlock()->RemoveInst(allocatedOperator);
157     }
158     allocatedOperator->SetBasicBlock(GetBB());
159     allocatedOperator->SetInput(0, lhs);
160     allocatedOperator->SetInput(1, rhs);
161 
162     COMPILER_LOG(DEBUG, BALANCE_EXPR) << "sources_[" << firstIdx << " to " << lastIdx << "] allocated";
163     return allocatedOperator;
164 }
165 
GetOperand(size_t firstIdx,size_t lastIdx)166 Inst *BalanceExpressions::GetOperand(size_t firstIdx, size_t lastIdx)
167 {
168     ASSERT(lastIdx > firstIdx);
169     return (lastIdx - firstIdx == 1) ? GenerateElementalOperator(sources_[firstIdx], sources_[lastIdx])
170                                      : AllocateSourcesRec<false>(firstIdx, lastIdx);
171 }
172 
173 /**
174  * Create an operator with direct sources
175  * (i.e. `lhs` and `rhs` are from sources_)
176  */
GenerateElementalOperator(Inst * lhs,Inst * rhs)177 Inst *BalanceExpressions::GenerateElementalOperator(Inst *lhs, Inst *rhs)
178 {
179     ASSERT(lhs && rhs);
180     ASSERT(operatorsAllocIdx_ < operators_.size());
181     Inst *allocatedOperator = operators_[operatorsAllocIdx_];
182     operatorsAllocIdx_++;
183     allocatedOperator->GetBasicBlock()->RemoveInst(allocatedOperator);
184 
185     allocatedOperator->SetBasicBlock(GetBB());
186 
187     // There is no need to clean users of lhs and rhs because it is cleaned during RemoveInst()
188     // (as soon as every of operator_insts_ is removed before further usage)
189 
190     allocatedOperator->SetInput(0, lhs);
191     allocatedOperator->SetInput(1, rhs);
192     COMPILER_LOG(DEBUG, BALANCE_EXPR) << "Created an elemental operator:\n" << *allocatedOperator;
193     return allocatedOperator;
194 }
195 
196 /**
197  * Recursively checks inputs.
198  * Fills arrays of source-insts and opreation-insts.
199  */
AnalyzeInputsRec(Inst * inst)200 void BalanceExpressions::AnalyzeInputsRec(Inst *inst)
201 {
202     exprCurDepth_++;
203     ASSERT(inst != nullptr);
204 
205     auto lhsInput = inst->GetInput(0).GetInst();
206     auto rhsInput = inst->GetInput(1).GetInst();
207 
208     TryExtendChainRec(lhsInput);
209     TryExtendChainRec(rhsInput);
210     operators_.push_back(inst);
211 
212     if (exprMaxDepth_ < exprCurDepth_) {
213         exprMaxDepth_ = exprCurDepth_;
214     }
215     exprCurDepth_--;
216 }
217 
218 /**
219  * Recursively checks if the instruction should be added in the current expression chain.
220  * If not, the considered instruction is the expression's term (source) and we save it for a later step.
221  */
TryExtendChainRec(Inst * inst)222 void BalanceExpressions::TryExtendChainRec(Inst *inst)
223 {
224     ASSERT(inst);
225     if (inst->GetOpcode() == GetOpcode()) {
226         if (inst->HasSingleUser()) {
227             inst->SetMarker(processedInstMrk_);
228 
229             AnalyzeInputsRec(inst);
230 
231             return;
232         }
233     }
234     sources_.push_back(inst);
235 }
236 
237 /**
238  * Finds optimal depth and compares to the current.
239  * Both of the numbers are represented as pow(x, 2).
240  */
NeedsOptimization()241 bool BalanceExpressions::NeedsOptimization()
242 {
243     if (sources_.size() <= 3U) {
244         return false;
245     }
246     // Avoid large shift exponent for size_t
247     if (exprMaxDepth_ >= std::numeric_limits<size_t>::digits) {
248         return false;
249     }
250     size_t current = 1UL << (exprMaxDepth_);
251     size_t optimal = GetBitCeil(sources_.size());
252     return current > optimal;
253 }
254 
Reset()255 void BalanceExpressions::Reset()
256 {
257     sources_.clear();
258     operators_.clear();
259     exprCurDepth_ = 0;
260     exprMaxDepth_ = 0;
261     operatorsAllocIdx_ = 0;
262     SetOpcode(Opcode::INVALID);
263 }
264 
Dump(std::ostream * out) const265 void BalanceExpressions::Dump(std::ostream *out) const
266 {
267     (*out) << "Sources:\n";
268     for (auto i : sources_) {
269         (*out) << *i << '\n';
270     }
271 
272     (*out) << "Operators:\n";
273     for (auto i : operators_) {
274         (*out) << *i << '\n';
275     }
276 }
277 
// Explicit instantiations of the GetBitFloor/GetBitCeil helpers for size_t, the argument type used by
// AllocateSourcesRec and NeedsOptimization in this file.
template size_t BalanceExpressions::GetBitFloor<size_t>(size_t val);
template size_t BalanceExpressions::GetBitCeil<size_t>(size_t val);
280 }  // namespace ark::compiler
281