1 /*
2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "optimizer/analysis/bounds_analysis.h"
17 #include "optimizer/analysis/dominators_tree.h"
18 #include "optimizer/optimizations/balance_expressions.h"
19 #include "compiler_logger.h"
20
21 namespace ark::compiler {
RunImpl()22 bool BalanceExpressions::RunImpl()
23 {
24 processedInstMrk_ = GetGraph()->NewMarker();
25 for (auto bb : GetGraph()->GetBlocksRPO()) {
26 ProcessBB(bb);
27 }
28 GetGraph()->EraseMarker(processedInstMrk_);
29 return isApplied_;
30 }
31
InvalidateAnalyses()32 void BalanceExpressions::InvalidateAnalyses()
33 {
34 GetGraph()->InvalidateAnalysis<BoundsAnalysis>();
35 GetGraph()->InvalidateAnalysis<DominatorsTree>();
36 }
37
38 /**
39 * Iterate over instructions in reverse order, find every expression-chain by detecting
40 * the final operator of each chain, analyze the chain and optimize it if necessary.
41 */
ProcessBB(BasicBlock * bb)42 void BalanceExpressions::ProcessBB(BasicBlock *bb)
43 {
44 ASSERT(bb != nullptr);
45 SetBB(bb);
46
47 auto it = bb->InstsReverse().begin();
48 for (; it != it.end(); ++it) {
49 ASSERT(*it != nullptr);
50 if ((*it)->SetMarker(processedInstMrk_)) {
51 // The instruction is already processed;
52 continue;
53 }
54 if (SuitableInst(*it)) {
55 // The final operator of the chain is found, start analyzing:
56 auto instToContinueCycle = ProccesExpressionChain(*it);
57 it.SetCurrent(instToContinueCycle);
58 }
59 }
60 }
61
SuitableInst(Inst * inst)62 bool BalanceExpressions::SuitableInst(Inst *inst)
63 {
64 // Floating point operations are not associative:
65 if (inst->IsCommutative() && !IsFloatType(inst->GetType())) {
66 SetOpcode(inst->GetOpcode());
67 return true;
68 }
69 return false;
70 }
71
ProccesExpressionChain(Inst * lastOperator)72 Inst *BalanceExpressions::ProccesExpressionChain(Inst *lastOperator)
73 {
74 ASSERT(lastOperator != nullptr);
75 AnalyzeInputsRec(lastOperator);
76
77 COMPILER_LOG(DEBUG, BALANCE_EXPR) << "\nConsidering expression:";
78 COMPILER_LOG(DEBUG, BALANCE_EXPR) << *this;
79
80 auto instToContinue = NeedsOptimization() ? OptimizeExpression(lastOperator->GetNext()) : lastOperator;
81
82 COMPILER_LOG(DEBUG, BALANCE_EXPR) << "Expression considered.";
83
84 Reset();
85 return instToContinue;
86 }
87
88 /**
89 * Optimizes expression.
90 *
91 * By the end of the algorithm, `operators_.front()` points to the first instruction in expression and
92 * `operators_.back()` points to the last (`operators_.front()` dominates `operators_.back()`).
93 */
OptimizeExpression(Inst * instAfterExpr)94 Inst *BalanceExpressions::OptimizeExpression(Inst *instAfterExpr)
95 {
96 COMPILER_LOG(DEBUG, BALANCE_EXPR) << "Optimizing expression:";
97 AllocateSourcesRec<true>(0, sources_.size() - 1);
98
99 size_t size = operators_.size();
100 operators_.front()->SetNext(operators_[1]);
101 constexpr auto IMM_2 = 2;
102 operators_.back()->SetPrev(operators_[size - IMM_2]);
103 for (size_t i = 1; i < size - 1; i++) {
104 operators_[i]->SetNext(operators_[i + 1]);
105 operators_[i]->SetPrev(operators_[i - 1]);
106 }
107 if (instAfterExpr == nullptr) {
108 GetBB()->AppendRangeInst(operators_.front(), operators_.back());
109 } else {
110 GetBB()->InsertRangeBefore(operators_.front(), operators_.back(), instAfterExpr);
111 }
112
113 SetIsApplied(true);
114 COMPILER_LOG(DEBUG, BALANCE_EXPR) << "\nOptimized expression:";
115 COMPILER_LOG(DEBUG, BALANCE_EXPR) << *this;
116
117 // Need to return pointer to the next instruction in order to correctly continue the cycle:
118 Inst *instToContinueCycle = operators_.front();
119 return instToContinueCycle;
120 }
121
122 /**
123 * Generates expression for sources_ in range from @param first_idx to @param last_idx by splitting them on
124 * two parts, calling itself for each part and binding them to an instruction from operators_.
125 *
126 * By the end of the algorithm, `operators_` are sorted in execution order
127 * (`operators_.front()` is the first, `operators_.back()` is the last).
128 */
129 template <bool IS_FIRST_CALL>
AllocateSourcesRec(size_t firstIdx,size_t lastIdx)130 Inst *BalanceExpressions::AllocateSourcesRec(size_t firstIdx, size_t lastIdx)
131 {
132 COMPILER_LOG(DEBUG, BALANCE_EXPR) << "Allocating operators for sources_[" << firstIdx << " to " << lastIdx << "]";
133 size_t splitIdx = firstIdx + GetBitFloor(lastIdx - firstIdx + 1) - 1;
134
135 Inst *lhs = GetOperand(firstIdx, splitIdx);
136 Inst *rhs = LIKELY((splitIdx + 1) != lastIdx) ? GetOperand(splitIdx + 1, lastIdx) : sources_[splitIdx + 1];
137 // `(split_idx + 1) == last_idx` means an odd number of `sources_` and we are considering
138 // the last (unpaired) source. This situation may occur only with `rhs`.
139 ASSERT(firstIdx != splitIdx);
140
141 // Operator allocation:
142 ASSERT(operatorsAllocIdx_ < operators_.size());
143 Inst *allocatedOperator = operators_[operatorsAllocIdx_];
144 operatorsAllocIdx_++;
145
146 // Operator initialization:
147 // NOLINTNEXTLINE(readability-braces-around-statements)
148 if constexpr (IS_FIRST_CALL) {
149 // The first call allocates and generates at the same time the last operator of expression and
150 // its users should be saved.
151 allocatedOperator->RemoveInputs();
152 allocatedOperator->GetBasicBlock()->EraseInst(allocatedOperator);
153 } else { // NOLINT(readability-misleading-indentation)
154 allocatedOperator->GetBasicBlock()->RemoveInst(allocatedOperator);
155 }
156 allocatedOperator->SetBasicBlock(GetBB());
157 allocatedOperator->SetInput(0, lhs);
158 allocatedOperator->SetInput(1, rhs);
159
160 COMPILER_LOG(DEBUG, BALANCE_EXPR) << "sources_[" << firstIdx << " to " << lastIdx << "] allocated";
161 return allocatedOperator;
162 }
163
GetOperand(size_t firstIdx,size_t lastIdx)164 Inst *BalanceExpressions::GetOperand(size_t firstIdx, size_t lastIdx)
165 {
166 ASSERT(lastIdx > firstIdx);
167 return (lastIdx - firstIdx == 1) ? GenerateElementalOperator(sources_[firstIdx], sources_[lastIdx])
168 : AllocateSourcesRec<false>(firstIdx, lastIdx);
169 }
170
171 /**
172 * Create an operator with direct sources
173 * (i.e. `lhs` and `rhs` are from sources_)
174 */
GenerateElementalOperator(Inst * lhs,Inst * rhs)175 Inst *BalanceExpressions::GenerateElementalOperator(Inst *lhs, Inst *rhs)
176 {
177 ASSERT(lhs && rhs);
178 ASSERT(operatorsAllocIdx_ < operators_.size());
179 Inst *allocatedOperator = operators_[operatorsAllocIdx_];
180 operatorsAllocIdx_++;
181 allocatedOperator->GetBasicBlock()->RemoveInst(allocatedOperator);
182
183 allocatedOperator->SetBasicBlock(GetBB());
184
185 // There is no need to clean users of lhs and rhs because it is cleaned during RemoveInst()
186 // (as soon as every of operator_insts_ is removed before further usage)
187
188 allocatedOperator->SetInput(0, lhs);
189 allocatedOperator->SetInput(1, rhs);
190 COMPILER_LOG(DEBUG, BALANCE_EXPR) << "Created an elemental operator:\n" << *allocatedOperator;
191 return allocatedOperator;
192 }
193
194 /**
195 * Recursively checks inputs.
196 * Fills arrays of source-insts and opreation-insts.
197 */
AnalyzeInputsRec(Inst * inst)198 void BalanceExpressions::AnalyzeInputsRec(Inst *inst)
199 {
200 exprCurDepth_++;
201 ASSERT(inst != nullptr);
202
203 auto lhsInput = inst->GetInput(0).GetInst();
204 auto rhsInput = inst->GetInput(1).GetInst();
205
206 TryExtendChainRec(lhsInput);
207 TryExtendChainRec(rhsInput);
208 operators_.push_back(inst);
209
210 if (exprMaxDepth_ < exprCurDepth_) {
211 exprMaxDepth_ = exprCurDepth_;
212 }
213 exprCurDepth_--;
214 }
215
216 /**
217 * Recursively checks if the instruction should be added in the current expression chain.
218 * If not, the considered instruction is the expression's term (source) and we save it for a later step.
219 */
TryExtendChainRec(Inst * inst)220 void BalanceExpressions::TryExtendChainRec(Inst *inst)
221 {
222 ASSERT(inst);
223 if (inst->GetOpcode() == GetOpcode()) {
224 if (inst->HasSingleUser()) {
225 inst->SetMarker(processedInstMrk_);
226
227 AnalyzeInputsRec(inst);
228
229 return;
230 }
231 }
232 sources_.push_back(inst);
233 }
234
235 /**
236 * Finds optimal depth and compares to the current.
237 * Both of the numbers are represented as pow(x, 2).
238 */
NeedsOptimization()239 bool BalanceExpressions::NeedsOptimization()
240 {
241 if (sources_.size() <= 3U) {
242 return false;
243 }
244 // Avoid large shift exponent for size_t
245 if (exprMaxDepth_ >= std::numeric_limits<size_t>::digits) {
246 return false;
247 }
248 size_t current = 1UL << (exprMaxDepth_);
249 size_t optimal = GetBitCeil(sources_.size());
250 return current > optimal;
251 }
252
Reset()253 void BalanceExpressions::Reset()
254 {
255 sources_.clear();
256 operators_.clear();
257 exprCurDepth_ = 0;
258 exprMaxDepth_ = 0;
259 operatorsAllocIdx_ = 0;
260 SetOpcode(Opcode::INVALID);
261 }
262
Dump(std::ostream * out) const263 void BalanceExpressions::Dump(std::ostream *out) const
264 {
265 (*out) << "Sources:\n";
266 for (auto i : sources_) {
267 (*out) << *i << '\n';
268 }
269
270 (*out) << "Operators:\n";
271 for (auto i : operators_) {
272 (*out) << *i << '\n';
273 }
274 }
275
276 template size_t BalanceExpressions::GetBitFloor<size_t>(size_t val);
277 template size_t BalanceExpressions::GetBitCeil<size_t>(size_t val);
278 } // namespace ark::compiler
279