1 /*
2 * Copyright (c) 2021-2025 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "optimizer/analysis/bounds_analysis.h"
17 #include "optimizer/analysis/dominators_tree.h"
18 #include "optimizer/optimizations/balance_expressions.h"
19 #include "compiler_logger.h"
20
21 namespace ark::compiler {
RunImpl()22 bool BalanceExpressions::RunImpl()
23 {
24 processedInstMrk_ = GetGraph()->NewMarker();
25 for (auto bb : GetGraph()->GetBlocksRPO()) {
26 ProcessBB(bb);
27 }
28 GetGraph()->EraseMarker(processedInstMrk_);
29 return isApplied_;
30 }
31
InvalidateAnalyses()32 void BalanceExpressions::InvalidateAnalyses()
33 {
34 GetGraph()->InvalidateAnalysis<BoundsAnalysis>();
35 GetGraph()->InvalidateAnalysis<DominatorsTree>();
36 }
37
38 /**
39 * Iterate over instructions in reverse order, find every expression-chain by detecting
40 * the final operator of each chain, analyze the chain and optimize it if necessary.
41 */
ProcessBB(BasicBlock * bb)42 void BalanceExpressions::ProcessBB(BasicBlock *bb)
43 {
44 ASSERT(bb != nullptr);
45 SetBB(bb);
46
47 auto it = bb->InstsReverse().begin();
48 for (; it != it.end(); ++it) {
49 ASSERT(*it != nullptr);
50 if ((*it)->SetMarker(processedInstMrk_)) {
51 // The instruction is already processed;
52 continue;
53 }
54 if (SuitableInst(*it)) {
55 // The final operator of the chain is found, start analyzing:
56 auto instToContinueCycle = ProccesExpressionChain(*it);
57 it.SetCurrent(instToContinueCycle);
58 }
59 }
60 }
61
SuitableInst(Inst * inst)62 bool BalanceExpressions::SuitableInst(Inst *inst)
63 {
64 // Floating point operations are not associative:
65 if (inst->IsCommutative() && !IsFloatType(inst->GetType())) {
66 SetOpcode(inst->GetOpcode());
67 return true;
68 }
69 return false;
70 }
71
ProccesExpressionChain(Inst * lastOperator)72 Inst *BalanceExpressions::ProccesExpressionChain(Inst *lastOperator)
73 {
74 ASSERT(lastOperator != nullptr);
75 AnalyzeInputsRec(lastOperator);
76
77 COMPILER_LOG(DEBUG, BALANCE_EXPR) << "\nConsidering expression:";
78 COMPILER_LOG(DEBUG, BALANCE_EXPR) << *this;
79
80 auto instToContinue = NeedsOptimization() ? OptimizeExpression(lastOperator->GetNext()) : lastOperator;
81
82 COMPILER_LOG(DEBUG, BALANCE_EXPR) << "Expression considered.";
83
84 Reset();
85 return instToContinue;
86 }
87
88 /**
89 * Optimizes expression.
90 *
91 * By the end of the algorithm, `operators_.front()` points to the first instruction in expression and
92 * `operators_.back()` points to the last (`operators_.front()` dominates `operators_.back()`).
93 */
OptimizeExpression(Inst * instAfterExpr)94 Inst *BalanceExpressions::OptimizeExpression(Inst *instAfterExpr)
95 {
96 COMPILER_LOG(DEBUG, BALANCE_EXPR) << "Optimizing expression:";
97 AllocateSourcesRec<true>(0, sources_.size() - 1);
98
99 size_t size = operators_.size();
100 operators_.front()->SetNext(operators_[1]);
101 constexpr auto IMM_2 = 2;
102 operators_.back()->SetPrev(operators_[size - IMM_2]);
103 for (size_t i = 1; i < size - 1; i++) {
104 operators_[i]->SetNext(operators_[i + 1]);
105 operators_[i]->SetPrev(operators_[i - 1]);
106 }
107 if (instAfterExpr == nullptr) {
108 GetBB()->AppendRangeInst(operators_.front(), operators_.back());
109 } else {
110 GetBB()->InsertRangeBefore(operators_.front(), operators_.back(), instAfterExpr);
111 }
112
113 SetIsApplied(true);
114 COMPILER_LOG(DEBUG, BALANCE_EXPR) << "\nOptimized expression:";
115 COMPILER_LOG(DEBUG, BALANCE_EXPR) << *this;
116
117 // Need to return pointer to the next instruction in order to correctly continue the cycle:
118 Inst *instToContinueCycle = operators_.front();
119 return instToContinueCycle;
120 }
121
/**
 * Builds an expression over `sources_[firstIdx .. lastIdx]` (inclusive).
 *
 * The range is split in two parts:
 * - the left part gets GetBitFloor(lastIdx - firstIdx + 1) sources;
 * - the right part gets the rest.
 * Each part is turned into an operand via GetOperand(), except a lone trailing source, which is
 * used directly. Both operands are then bound to the next unused operator from `operators_`.
 *
 * Operands are built before the operator is taken, so operators are consumed in post-order. As a
 * result, by the end of the algorithm `operators_` are sorted in execution order
 * (`operators_.front()` is the first, `operators_.back()` is the last).
 *
 * @tparam IS_FIRST_CALL true only for the outermost call. That call takes the last operator of
 *         `operators_`, i.e. the root of the expression, whose users must be preserved.
 * @param firstIdx index of the first source of the range
 * @param lastIdx index of the last source of the range (must be greater than firstIdx)
 * @return the operator computing the whole range
 */
template <bool IS_FIRST_CALL>
Inst *BalanceExpressions::AllocateSourcesRec(size_t firstIdx, size_t lastIdx)
{
    ASSERT(lastIdx > firstIdx);
    COMPILER_LOG(DEBUG, BALANCE_EXPR) << "Allocating operators for sources_[" << firstIdx << " to " << lastIdx << "]";
    // NOTE(review): this only terminates if GetBitFloor(n) returns a power of two strictly *less* than n.
    // With a "<= n" floor, a power-of-two sized range would yield splitIdx == lastIdx, and the lhs call
    // below would recurse on the same range. TODO confirm against the definition in the header.
    size_t memSize = firstIdx + GetBitFloor(lastIdx - firstIdx + 1);
    // Index of the last source of the left part; the `> 0` check only guards the unsigned subtraction.
    size_t splitIdx = memSize > 0 ? memSize - 1 : 0;

    Inst *lhs = GetOperand(firstIdx, splitIdx);
    Inst *rhs = LIKELY((splitIdx + 1) != lastIdx) ? GetOperand(splitIdx + 1, lastIdx) : sources_[splitIdx + 1];
    // `(splitIdx + 1) == lastIdx` means exactly one source is left for the right part, so it is used
    // directly instead of building an operator for it. This situation may occur only with `rhs`.
    ASSERT(firstIdx != splitIdx);

    // Operator allocation: take the next unused operator of the chain.
    ASSERT(operatorsAllocIdx_ < operators_.size());
    Inst *allocatedOperator = operators_[operatorsAllocIdx_];
    operatorsAllocIdx_++;

    // Operator initialization:
    // NOLINTNEXTLINE(readability-braces-around-statements)
    if constexpr (IS_FIRST_CALL) {
        // The outermost call reuses the root (last) operator of the expression, whose users should be saved.
        // Its inputs are dropped explicitly, and it is taken out of the block with EraseInst() instead of
        // RemoveInst(). NOTE(review): presumably EraseInst() leaves the users intact — confirm.
        allocatedOperator->RemoveInputs();
        allocatedOperator->GetBasicBlock()->EraseInst(allocatedOperator);
    } else { // NOLINT(readability-misleading-indentation)
        // Inner operators are simply taken out of their current block.
        allocatedOperator->GetBasicBlock()->RemoveInst(allocatedOperator);
    }
    // Re-home the operator to the block being processed and bind the new operands.
    allocatedOperator->SetBasicBlock(GetBB());
    allocatedOperator->SetInput(0, lhs);
    allocatedOperator->SetInput(1, rhs);

    COMPILER_LOG(DEBUG, BALANCE_EXPR) << "sources_[" << firstIdx << " to " << lastIdx << "] allocated";
    return allocatedOperator;
}
165
GetOperand(size_t firstIdx,size_t lastIdx)166 Inst *BalanceExpressions::GetOperand(size_t firstIdx, size_t lastIdx)
167 {
168 ASSERT(lastIdx > firstIdx);
169 return (lastIdx - firstIdx == 1) ? GenerateElementalOperator(sources_[firstIdx], sources_[lastIdx])
170 : AllocateSourcesRec<false>(firstIdx, lastIdx);
171 }
172
173 /**
174 * Create an operator with direct sources
175 * (i.e. `lhs` and `rhs` are from sources_)
176 */
GenerateElementalOperator(Inst * lhs,Inst * rhs)177 Inst *BalanceExpressions::GenerateElementalOperator(Inst *lhs, Inst *rhs)
178 {
179 ASSERT(lhs && rhs);
180 ASSERT(operatorsAllocIdx_ < operators_.size());
181 Inst *allocatedOperator = operators_[operatorsAllocIdx_];
182 operatorsAllocIdx_++;
183 allocatedOperator->GetBasicBlock()->RemoveInst(allocatedOperator);
184
185 allocatedOperator->SetBasicBlock(GetBB());
186
187 // There is no need to clean users of lhs and rhs because it is cleaned during RemoveInst()
188 // (as soon as every of operator_insts_ is removed before further usage)
189
190 allocatedOperator->SetInput(0, lhs);
191 allocatedOperator->SetInput(1, rhs);
192 COMPILER_LOG(DEBUG, BALANCE_EXPR) << "Created an elemental operator:\n" << *allocatedOperator;
193 return allocatedOperator;
194 }
195
196 /**
197 * Recursively checks inputs.
198 * Fills arrays of source-insts and opreation-insts.
199 */
AnalyzeInputsRec(Inst * inst)200 void BalanceExpressions::AnalyzeInputsRec(Inst *inst)
201 {
202 exprCurDepth_++;
203 ASSERT(inst != nullptr);
204
205 auto lhsInput = inst->GetInput(0).GetInst();
206 auto rhsInput = inst->GetInput(1).GetInst();
207
208 TryExtendChainRec(lhsInput);
209 TryExtendChainRec(rhsInput);
210 operators_.push_back(inst);
211
212 if (exprMaxDepth_ < exprCurDepth_) {
213 exprMaxDepth_ = exprCurDepth_;
214 }
215 exprCurDepth_--;
216 }
217
218 /**
219 * Recursively checks if the instruction should be added in the current expression chain.
220 * If not, the considered instruction is the expression's term (source) and we save it for a later step.
221 */
TryExtendChainRec(Inst * inst)222 void BalanceExpressions::TryExtendChainRec(Inst *inst)
223 {
224 ASSERT(inst);
225 if (inst->GetOpcode() == GetOpcode()) {
226 if (inst->HasSingleUser()) {
227 inst->SetMarker(processedInstMrk_);
228
229 AnalyzeInputsRec(inst);
230
231 return;
232 }
233 }
234 sources_.push_back(inst);
235 }
236
237 /**
238 * Finds optimal depth and compares to the current.
239 * Both of the numbers are represented as pow(x, 2).
240 */
NeedsOptimization()241 bool BalanceExpressions::NeedsOptimization()
242 {
243 if (sources_.size() <= 3U) {
244 return false;
245 }
246 // Avoid large shift exponent for size_t
247 if (exprMaxDepth_ >= std::numeric_limits<size_t>::digits) {
248 return false;
249 }
250 size_t current = 1UL << (exprMaxDepth_);
251 size_t optimal = GetBitCeil(sources_.size());
252 return current > optimal;
253 }
254
Reset()255 void BalanceExpressions::Reset()
256 {
257 sources_.clear();
258 operators_.clear();
259 exprCurDepth_ = 0;
260 exprMaxDepth_ = 0;
261 operatorsAllocIdx_ = 0;
262 SetOpcode(Opcode::INVALID);
263 }
264
Dump(std::ostream * out) const265 void BalanceExpressions::Dump(std::ostream *out) const
266 {
267 (*out) << "Sources:\n";
268 for (auto i : sources_) {
269 (*out) << *i << '\n';
270 }
271
272 (*out) << "Operators:\n";
273 for (auto i : operators_) {
274 (*out) << *i << '\n';
275 }
276 }
277
// Explicit instantiations of the bit helpers for `size_t`, the argument type used by
// AllocateSourcesRec (GetBitFloor) and NeedsOptimization (GetBitCeil) in this file.
template size_t BalanceExpressions::GetBitFloor<size_t>(size_t val);
template size_t BalanceExpressions::GetBitCeil<size_t>(size_t val);
280 } // namespace ark::compiler
281