• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMPILER_OPTIMIZER_OPTIMIZATIONS_LOWERING_H
17 #define COMPILER_OPTIMIZER_OPTIMIZATIONS_LOWERING_H
18 
19 #include "compiler_logger.h"
20 #include "compiler_options.h"
21 #include "optimizer/ir/analysis.h"
22 #include "optimizer/ir/graph.h"
23 #include "optimizer/ir/graph_visitor.h"
24 
25 namespace ark::compiler {
26 // NOLINTNEXTLINE(fuchsia-multiple-inheritance)
27 class Lowering : public Optimization, public GraphVisitor {
28     using Optimization::Optimization;
29 
30 public:
Lowering(Graph * graph)31     explicit Lowering(Graph *graph) : Optimization(graph) {}
32 
33     bool RunImpl() override;
34 
IsEnable()35     bool IsEnable() const override
36     {
37         return g_options.IsCompilerLowering();
38     }
39 
GetPassName()40     const char *GetPassName() const override
41     {
42         return "Lowering";
43     }
44 
45     void InvalidateAnalyses() override;
46 
47     static constexpr uint8_t HALF_SIZE = 32;
48     static constexpr uint8_t WORD_SIZE = 64;
49 
50 private:
51     /**
52      * Utility template classes aimed to simplify pattern matching over IR-graph.
53      * Patterns are trees declared as a type using Any, UnaryOp, BinaryOp and Operand composition.
54      * Then IR-subtree could be tested for matching by calling static Capture method.
55      * To capture operands from matched subtree Operand<Index> should be used, where
56      * Index is an operand's index within OperandsCapture.
57      *
58      * For example, suppose that we want to test if IR-subtree rooted by some Inst matches add or sub instruction
59      * pattern:
60      *
61      * Inst* inst = ...;
62      * using Predicate = Any<BinaryOp<Opcode::Add, Operand<0>, Operand<1>,
63      *                       BinaryOp<Opcode::Sub, Operand<0>, Operand<1>>;
64      * OperandsCapture<2> capture{};
65      * bool is_add_or_sub = Predicate::Capture(inst, capture);
66      *
67      * If inst is a binary instruction with opcode Add or Sub then Capture will return `true`,
68      * capture.Get(0) will return left inst's input and capture.Get(1) will return right inst's input.
69      */
70 
71     // Flags altering matching behavior.
72     enum Flags {
73         NONE = 0,
74         COMMUTATIVE = 1,  // binary operation is commutative
75         C = COMMUTATIVE,
76         SINGLE_USER = 2,  // operation must have only single user
77         S = SINGLE_USER
78     };
79 
IsSet(uint64_t flags,Flags flag)80     static bool IsSet(uint64_t flags, Flags flag)
81     {
82         return (flags & flag) != 0;
83     }
84 
IsNotSet(uint64_t flags,Flags flag)85     static bool IsNotSet(uint64_t flags, Flags flag)
86     {
87         return (flags & flag) == 0;
88     }
89 
90     template <size_t MAX_OPERANDS>
91     class OperandsCapture {
92     public:
Get(size_t i)93         Inst *Get(size_t i)
94         {
95             ASSERT(i < MAX_OPERANDS);
96             return operands_[i];
97         }
98 
Set(size_t i,Inst * inst)99         void Set(size_t i, Inst *inst)
100         {
101             ASSERT(i < MAX_OPERANDS);
102 #pragma GCC diagnostic push
103 #pragma GCC diagnostic ignored "-Warray-bounds"
104             operands_[i] = inst;
105 #pragma GCC diagnostic pop
106         }
107 
108         // Returns true if all non-constant operands have the same common type (obtained using GetCommonType) as all
109         // other operands.
HaveCommonType()110         bool HaveCommonType() const
111         {
112             auto nonConstType = DataType::LAST;
113             for (size_t i = 0; i < MAX_OPERANDS; i++) {
114                 if (operands_[i]->GetOpcode() != Opcode::Constant) {
115                     nonConstType = GetCommonType(operands_[i]->GetType());
116                     break;
117                 }
118             }
119             // all operands are constants
120             if (nonConstType == DataType::LAST) {
121                 nonConstType = operands_[0]->GetType();
122             }
123             for (size_t i = 0; i < MAX_OPERANDS; i++) {
124                 if (operands_[i]->GetOpcode() == Opcode::Constant) {
125                     if (GetCommonType(operands_[i]->GetType()) != GetCommonType(nonConstType)) {
126                         return false;
127                     }
128                 } else if (nonConstType != GetCommonType(operands_[i]->GetType())) {
129                     return false;
130                 }
131             }
132             return true;
133         }
134 
135     private:
136         std::array<Inst *, MAX_OPERANDS> operands_;
137     };
138 
139     template <size_t MAX_INSTS>
140     class InstructionsCapture {
141     public:
Add(Inst * inst)142         void Add(Inst *inst)
143         {
144             ASSERT(currentIdx_ < MAX_INSTS);
145             insts_[currentIdx_++] = inst;
146         }
147 
GetCurrentIndex()148         size_t GetCurrentIndex() const
149         {
150             return currentIdx_;
151         }
152 
SetCurrentIndex(size_t idx)153         void SetCurrentIndex(size_t idx)
154         {
155             ASSERT(idx < MAX_INSTS);
156             currentIdx_ = idx;
157         }
158 
159         // Returns true if all non-constant operands have exactly the same type and all
160         // constant arguments have the same common type (obtained using GetCommonType) as all other operands.
HaveSameType()161         bool HaveSameType() const
162         {
163             ASSERT(currentIdx_ == MAX_INSTS);
164             auto nonConstType = DataType::LAST;
165             for (size_t i = 0; i < MAX_INSTS; i++) {
166                 if (insts_[i]->GetOpcode() != Opcode::Constant) {
167                     nonConstType = insts_[i]->GetType();
168                     break;
169                 }
170             }
171             // all operands are constants
172             if (nonConstType == DataType::LAST) {
173                 nonConstType = insts_[0]->GetType();
174             }
175             for (size_t i = 0; i < MAX_INSTS; i++) {
176                 if (insts_[i]->GetOpcode() == Opcode::Constant) {
177                     if (GetCommonType(insts_[i]->GetType()) != GetCommonType(nonConstType)) {
178                         return false;
179                     }
180                 } else if (nonConstType != insts_[i]->GetType()) {
181                     return false;
182                 }
183             }
184             return true;
185         }
186 
ResetIndex()187         InstructionsCapture &ResetIndex()
188         {
189             currentIdx_ = 0;
190             return *this;
191         }
192 
193     private:
194         std::array<Inst *, MAX_INSTS> insts_ {};
195         size_t currentIdx_ = 0;
196     };
197 
198     template <Opcode OPCODE, typename L, typename R, uint64_t FLAGS = Flags::NONE>
199     struct BinaryOp {
200         template <size_t MAX_OPERANDS, size_t MAX_INSTS>
CaptureBinaryOp201         static bool Capture(Inst *inst, OperandsCapture<MAX_OPERANDS> &args, InstructionsCapture<MAX_INSTS> &insts)
202         {
203             constexpr auto INPUTS_NUM = 2;
204             // NOLINTNEXTLINE(readability-magic-numbers)
205             if (inst->GetOpcode() != OPCODE || inst->GetInputsCount() != INPUTS_NUM ||
206                 (IsSet(FLAGS, Flags::SINGLE_USER) && !inst->HasSingleUser())) {
207                 return false;
208             }
209             if (L::Capture(inst->GetInput(0).GetInst(), args, insts) &&
210                 R::Capture(inst->GetInput(1).GetInst(), args, insts)) {
211                 insts.Add(inst);
212                 return true;
213             }
214             if (IsSet(FLAGS, Flags::COMMUTATIVE) && L::Capture(inst->GetInput(1).GetInst(), args, insts) &&
215                 R::Capture(inst->GetInput(0).GetInst(), args, insts)) {
216                 insts.Add(inst);
217                 return true;
218             }
219             return false;
220         }
221     };
222 
223     template <Opcode OPCODE, typename T, uint64_t FLAGS = Flags::NONE>
224     struct UnaryOp {
225         template <size_t MAX_OPERANDS, size_t MAX_INSTS>
CaptureUnaryOp226         static bool Capture(Inst *inst, OperandsCapture<MAX_OPERANDS> &args, InstructionsCapture<MAX_INSTS> &insts)
227         {
228             // NOLINTNEXTLINE(readability-magic-numbers)
229             bool matched = inst->GetOpcode() == OPCODE && inst->GetInputsCount() == 1 &&
230                            (IsNotSet(FLAGS, Flags::SINGLE_USER) || inst->HasSingleUser()) &&
231                            T::Capture(inst->GetInput(0).GetInst(), args, insts);
232             if (matched) {
233                 insts.Add(inst);
234             }
235             return matched;
236         }
237     };
238 
239     template <size_t IDX>
240     struct Operand {
241         template <size_t MAX_OPERANDS, size_t MAX_INSTS>
CaptureOperand242         static bool Capture(Inst *inst, OperandsCapture<MAX_OPERANDS> &args,
243                             [[maybe_unused]] InstructionsCapture<MAX_INSTS> &insts)
244         {
245             static_assert(IDX < MAX_OPERANDS, "Operand's index should not exceed OperandsCapture size");
246 
247             args.Set(IDX, inst);
248             return true;
249         }
250     };
251 
252     template <typename T, typename... Args>
253     struct AnyOf {
254         template <size_t MAX_OPERANDS, size_t MAX_INSTS>
CaptureAnyOf255         static bool Capture(Inst *inst, OperandsCapture<MAX_OPERANDS> &args, InstructionsCapture<MAX_INSTS> &insts)
256         {
257             size_t instIdx = insts.GetCurrentIndex();
258             if (T::Capture(inst, args, insts)) {
259                 return true;
260             }
261             insts.SetCurrentIndex(instIdx);
262             return AnyOf<Args...>::Capture(inst, args, insts);
263         }
264     };
265 
266     template <typename T>
267     struct AnyOf<T> {
268         template <size_t MAX_OPERANDS, size_t MAX_INSTS>
269         static bool Capture(Inst *inst, OperandsCapture<MAX_OPERANDS> &args, InstructionsCapture<MAX_INSTS> &insts)
270         {
271             return T::Capture(inst, args, insts);
272         }
273     };
274 
275     template <bool ENABLED, typename T>
276     struct MatchIf : public T {
277     };
278 
279     template <typename T>
280     struct MatchIf<false, T> {
281         template <size_t MAX_OPERANDS, size_t MAX_INSTS>
282         // NOLINTNEXTLINE(readability-named-parameter)
283         static bool Capture(Inst *, OperandsCapture<MAX_OPERANDS> &, InstructionsCapture<MAX_INSTS> &)
284         {
285             return false;
286         }
287     };
288 
289     using SRC0 = Operand<0>;
290     using SRC1 = Operand<1>;
291     using SRC2 = Operand<2U>;
292 
293     template <typename L, typename R, uint64_t F = Flags::NONE>
294     using ADD = BinaryOp<Opcode::Add, L, R, F | Flags::COMMUTATIVE>;
295     template <typename L, typename R, uint64_t F = Flags::NONE>
296     using SUB = BinaryOp<Opcode::Sub, L, R, F>;
297     template <typename L, typename R, uint64_t F = Flags::NONE>
298     using MUL = BinaryOp<Opcode::Mul, L, R, F | Flags::COMMUTATIVE>;
299     template <typename T, uint64_t F = Flags::NONE>
300     using NEG = UnaryOp<Opcode::Neg, T, F>;
301     template <typename T, uint64_t F = Flags::SINGLE_USER>
302     using NOT = UnaryOp<Opcode::Not, T, F>;
303     template <typename T, uint64_t F = Flags::SINGLE_USER>
304     using SHRI = UnaryOp<Opcode::ShrI, T, F>;
305     template <typename T, uint64_t F = Flags::SINGLE_USER>
306     using ASHRI = UnaryOp<Opcode::AShrI, T, F>;
307     template <typename T, uint64_t F = Flags::SINGLE_USER>
308     using SHLI = UnaryOp<Opcode::ShlI, T, F>;
309 
310     const ArenaVector<BasicBlock *> &GetBlocksToVisit() const override
311     {
312         return GetGraph()->GetBlocksRPO();
313     }
314 
315     static void VisitAdd([[maybe_unused]] GraphVisitor *v, Inst *inst);
316     static void VisitSub([[maybe_unused]] GraphVisitor *v, Inst *inst);
317     static void VisitCast([[maybe_unused]] GraphVisitor *v, Inst *inst);
318     static void VisitCastValueToAnyType([[maybe_unused]] GraphVisitor *v, Inst *inst);
319 
320     template <Opcode OPC>
321     static void VisitBitwiseBinaryOperation([[maybe_unused]] GraphVisitor *v, Inst *inst);
322     static void VisitOr(GraphVisitor *v, Inst *inst);
323     static void VisitAnd(GraphVisitor *v, Inst *inst);
324     static void VisitXor(GraphVisitor *v, Inst *inst);
325 
326     static void VisitAndNot([[maybe_unused]] GraphVisitor *v, Inst *inst);
327     static void VisitXorNot([[maybe_unused]] GraphVisitor *v, Inst *inst);
328     static void VisitOrNot([[maybe_unused]] GraphVisitor *v, Inst *inst);
329     static void VisitSaveState([[maybe_unused]] GraphVisitor *v, Inst *inst);
330     static void VisitSafePoint([[maybe_unused]] GraphVisitor *v, Inst *inst);
331     static void VisitSaveStateOsr([[maybe_unused]] GraphVisitor *v, Inst *inst);
332     static void VisitSaveStateDeoptimize([[maybe_unused]] GraphVisitor *v, Inst *inst);
333     static void VisitBoundsCheck([[maybe_unused]] GraphVisitor *v, Inst *inst);
334     static void VisitLoadArray([[maybe_unused]] GraphVisitor *v, Inst *inst);
335     static void VisitLoadCompressedStringChar([[maybe_unused]] GraphVisitor *v, Inst *inst);
336     static void VisitStoreArray([[maybe_unused]] GraphVisitor *v, Inst *inst);
337     static void VisitLoad([[maybe_unused]] GraphVisitor *v, Inst *inst);
338     static void VisitStore([[maybe_unused]] GraphVisitor *v, Inst *inst);
339     static void VisitReturn([[maybe_unused]] GraphVisitor *v, Inst *inst);
340     static void VisitShr([[maybe_unused]] GraphVisitor *v, Inst *inst);
341     static void VisitAShr([[maybe_unused]] GraphVisitor *v, Inst *inst);
342     static void VisitShl([[maybe_unused]] GraphVisitor *v, Inst *inst);
343     static void VisitIfImm([[maybe_unused]] GraphVisitor *v, Inst *inst);
344     static void VisitMul([[maybe_unused]] GraphVisitor *v, Inst *inst);
345     static void VisitDiv([[maybe_unused]] GraphVisitor *v, Inst *inst);
346     static void VisitMod([[maybe_unused]] GraphVisitor *v, Inst *inst);
347     static void VisitNeg([[maybe_unused]] GraphVisitor *v, Inst *inst);
348     static void VisitDeoptimizeIf(GraphVisitor *v, Inst *inst);
349     static void VisitLoadFromConstantPool(GraphVisitor *v, Inst *inst);
350     static void VisitCompare(GraphVisitor *v, Inst *inst);
351 
352 #include "optimizer/ir/visitor.inc"
353 
354     static void InsertInstruction(Inst *inst, Inst *newInst)
355     {
356         inst->InsertBefore(newInst);
357         inst->ReplaceUsers(newInst);
358         newInst->GetBasicBlock()->GetGraph()->GetEventWriter().EventLowering(GetOpcodeString(inst->GetOpcode()),
359                                                                              inst->GetId(), inst->GetPc());
360         COMPILER_LOG(DEBUG, LOWERING) << "Lowering is applied for " << GetOpcodeString(inst->GetOpcode());
361     }
362 
363     template <size_t MAX_OPERANDS>
364     static void SetInputsAndInsertInstruction(OperandsCapture<MAX_OPERANDS> &operands, Inst *inst, Inst *newInst);
365 
366     static constexpr Opcode GetInstructionWithShiftedOperand(Opcode opcode);
367     static constexpr Opcode GetInstructionWithInvertedOperand(Opcode opcode);
368     static ShiftType GetShiftTypeByOpcode(Opcode opcode);
369     static Inst *GetCheckInstAndGetConstInput(Inst *inst);
370     static ShiftOpcode ConvertOpcode(Opcode newOpcode);
371 
372     static void LowerMemInstScale(Inst *inst);
373     static void LowerShift(Inst *inst);
374     static bool ConstantFitsCompareImm(Inst *cst, uint32_t size, ConditionCode cc);
375     static Inst *LowerAddSub(Inst *inst);
376     template <Opcode OPCODE>
377     static void LowerMulDivMod(Inst *inst);
378     static Inst *LowerMultiplyAddSub(Inst *inst);
379     static Inst *LowerNegateMultiply(Inst *inst);
380     static void LowerLogicWithInvertedOperand(Inst *inst);
381     static bool LowerCastValueToAnyTypeWithConst(Inst *inst);
382     template <typename T, size_t MAX_OPERANDS>
383     static Inst *LowerOperationWithShiftedOperand(Inst *inst, OperandsCapture<MAX_OPERANDS> &operands, Inst *shiftInst,
384                                                   Opcode newOpcode);
385     template <Opcode OPCODE, bool IS_COMMUTATIVE = true>
386     static Inst *LowerBinaryOperationWithShiftedOperand(Inst *inst);
387     template <Opcode OPCODE>
388     static void LowerUnaryOperationWithShiftedOperand(Inst *inst);
389     static Inst *LowerLogic(Inst *inst);
390     template <typename LowLevelType>
391     static void LowerConstArrayIndex(Inst *inst, Opcode lowLevelOpcode);
392     static void LowerStateInst(SaveStateInst *saveState);
393     static void LowerReturnInst(FixedInputsInst1 *ret);
394     // We'd like to swap only to make second operand immediate
395     static bool BetterToSwapCompareInputs(Inst *cmp);
396     // Optimize order of input arguments for decreasing using accumulator (Bytecodeoptimizer only).
397     static void OptimizeIfInput(compiler::Inst *ifInst);
398     static void JoinFcmpInst(IfImmInst *inst, CmpInst *input);
399     void LowerIf(IfImmInst *inst);
400     static void InPlaceLowerIfImm(IfImmInst *inst, Inst *input, Inst *cst, ConditionCode cc, DataType::Type inputType);
401     static void LowerIfImmToIf(IfImmInst *inst, Inst *input, ConditionCode cc, DataType::Type inputType);
402     static void LowerToDeoptimizeCompare(Inst *inst);
403     static bool TryReplaceDivPowerOfTwo(GraphVisitor *v, Inst *inst);
404     static bool TryReplaceDivModNonPowerOfTwo(GraphVisitor *v, Inst *inst);
405     static bool TryReplaceModPowerOfTwo(GraphVisitor *v, Inst *inst);
406     static void ReplaceSignedModPowerOfTwo(GraphVisitor *v, Inst *inst, uint64_t absValue);
407     static void ReplaceUnsignedModPowerOfTwo(GraphVisitor *v, Inst *inst, uint64_t absValue);
408     static void ReplaceSignedDivPowerOfTwo(GraphVisitor *v, Inst *inst, int64_t sValue);
409     static void ReplaceUnsignedDivPowerOfTwo(GraphVisitor *v, Inst *inst, uint64_t uValue);
410 
411 private:
412     SaveStateBridgesBuilder ssb_;
413 };
414 }  // namespace ark::compiler
415 
416 #endif  // COMPILER_OPTIMIZER_OPTIMIZATIONS_LOWERING_H
417