/*
 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef COMPILER_OPTIMIZER_OPTIMIZATIONS_LOWERING_H
#define COMPILER_OPTIMIZER_OPTIMIZATIONS_LOWERING_H

#include "compiler_logger.h"
#include "compiler_options.h"
#include "optimizer/ir/analysis.h"
#include "optimizer/ir/graph.h"
#include "optimizer/ir/graph_visitor.h"

namespace ark::compiler {
// NOLINTNEXTLINE(fuchsia-multiple-inheritance)
class Lowering : public Optimization, public GraphVisitor {
    using Optimization::Optimization;

public:
    explicit Lowering(Graph *graph) : Optimization(graph) {}

    bool RunImpl() override;

    bool IsEnable() const override
    {
        return g_options.IsCompilerLowering();
    }

    const char *GetPassName() const override
    {
        return "Lowering";
    }

    void InvalidateAnalyses() override;

    static constexpr uint8_t HALF_SIZE = 32;
    static constexpr uint8_t WORD_SIZE = 64;

private:
    /**
     * Utility template classes aimed at simplifying pattern matching over the IR graph.
     * Patterns are trees declared as a type composed of AnyOf, UnaryOp, BinaryOp and Operand.
     * An IR subtree can then be tested for a match by calling the static Capture method.
     * To capture operands from the matched subtree, Operand<Index> should be used, where
     * Index is the operand's index within OperandsCapture.
     *
     * For example, suppose that we want to test whether an IR subtree rooted at some Inst matches an add or sub
     * instruction pattern:
     *
     *   Inst *inst = ...;
     *   using Predicate = AnyOf<BinaryOp<Opcode::Add, Operand<0>, Operand<1>>,
     *                           BinaryOp<Opcode::Sub, Operand<0>, Operand<1>>>;
     *   OperandsCapture<2> capture {};
     *   InstructionsCapture<1> insts {};
     *   bool is_add_or_sub = Predicate::Capture(inst, capture, insts);
     *
     * If inst is a binary instruction with opcode Add or Sub then Capture will return `true`,
     * capture.Get(0) will return the inst's left input and capture.Get(1) its right input.
     */
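
    /*
     * A further illustrative sketch (a hypothetical pattern, not taken verbatim from the pass): matching a
     * multiplication whose only user is a negation, roughly the shape LowerNegateMultiply is interested in.
     * Flags::S restricts the inner node to a single user, and every matched instruction is recorded in the
     * InstructionsCapture in match order:
     *
     *   using NegMul = UnaryOp<Opcode::Neg, BinaryOp<Opcode::Mul, Operand<0>, Operand<1>, Flags::S>>;
     *   OperandsCapture<2> operands {};
     *   InstructionsCapture<2> insts {};
     *   if (NegMul::Capture(inst, operands, insts) && operands.HaveCommonType()) {
     *       // operands.Get(0) and operands.Get(1) are the multiply inputs;
     *       // insts holds the Mul followed by the Neg.
     *   }
     */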

    // Flags altering matching behavior.
    enum Flags {
        NONE = 0,
        COMMUTATIVE = 1,  // binary operation is commutative
        C = COMMUTATIVE,
        SINGLE_USER = 2,  // operation must have only a single user
        S = SINGLE_USER
    };

    static bool IsSet(uint64_t flags, Flags flag)
    {
        return (flags & flag) != 0;
    }

    static bool IsNotSet(uint64_t flags, Flags flag)
    {
        return (flags & flag) == 0;
    }

    template <size_t MAX_OPERANDS>
    class OperandsCapture {
    public:
        Inst *Get(size_t i)
        {
            ASSERT(i < MAX_OPERANDS);
            return operands_[i];
        }

        void Set(size_t i, Inst *inst)
        {
            ASSERT(i < MAX_OPERANDS);
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
            operands_[i] = inst;
#pragma GCC diagnostic pop
        }

        // Returns true if all non-constant operands have the same common type (obtained using GetCommonType) as all
        // other operands.
        bool HaveCommonType() const
        {
            auto nonConstType = DataType::LAST;
            for (size_t i = 0; i < MAX_OPERANDS; i++) {
                if (operands_[i]->GetOpcode() != Opcode::Constant) {
                    nonConstType = GetCommonType(operands_[i]->GetType());
                    break;
                }
            }
            // all operands are constants
            if (nonConstType == DataType::LAST) {
                nonConstType = operands_[0]->GetType();
            }
            for (size_t i = 0; i < MAX_OPERANDS; i++) {
                if (operands_[i]->GetOpcode() == Opcode::Constant) {
                    if (GetCommonType(operands_[i]->GetType()) != GetCommonType(nonConstType)) {
                        return false;
                    }
                } else if (nonConstType != GetCommonType(operands_[i]->GetType())) {
                    return false;
                }
            }
            return true;
        }

    private:
        std::array<Inst *, MAX_OPERANDS> operands_;
    };
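
    // A distinction worth noting between the two capture classes: OperandsCapture::HaveCommonType above compares
    // every operand only up to GetCommonType, while InstructionsCapture::HaveSameType below additionally requires
    // the non-constant instructions to have exactly the same type; constants are checked via GetCommonType in both.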

    template <size_t MAX_INSTS>
    class InstructionsCapture {
    public:
        void Add(Inst *inst)
        {
            ASSERT(currentIdx_ < MAX_INSTS);
            insts_[currentIdx_++] = inst;
        }

        size_t GetCurrentIndex() const
        {
            return currentIdx_;
        }

        void SetCurrentIndex(size_t idx)
        {
            ASSERT(idx < MAX_INSTS);
            currentIdx_ = idx;
        }

        // Returns true if all non-constant operands have exactly the same type and all
        // constant arguments have the same common type (obtained using GetCommonType) as all other operands.
        bool HaveSameType() const
        {
            ASSERT(currentIdx_ == MAX_INSTS);
            auto nonConstType = DataType::LAST;
            for (size_t i = 0; i < MAX_INSTS; i++) {
                if (insts_[i]->GetOpcode() != Opcode::Constant) {
                    nonConstType = insts_[i]->GetType();
                    break;
                }
            }
            // all operands are constants
            if (nonConstType == DataType::LAST) {
                nonConstType = insts_[0]->GetType();
            }
            for (size_t i = 0; i < MAX_INSTS; i++) {
                if (insts_[i]->GetOpcode() == Opcode::Constant) {
                    if (GetCommonType(insts_[i]->GetType()) != GetCommonType(nonConstType)) {
                        return false;
                    }
                } else if (nonConstType != insts_[i]->GetType()) {
                    return false;
                }
            }
            return true;
        }

        InstructionsCapture &ResetIndex()
        {
            currentIdx_ = 0;
            return *this;
        }

    private:
        std::array<Inst *, MAX_INSTS> insts_ {};
        size_t currentIdx_ = 0;
    };

    template <Opcode OPCODE, typename L, typename R, uint64_t FLAGS = Flags::NONE>
    struct BinaryOp {
        template <size_t MAX_OPERANDS, size_t MAX_INSTS>
        static bool Capture(Inst *inst, OperandsCapture<MAX_OPERANDS> &args, InstructionsCapture<MAX_INSTS> &insts)
        {
            constexpr auto INPUTS_NUM = 2;
            // NOLINTNEXTLINE(readability-magic-numbers)
            if (inst->GetOpcode() != OPCODE || inst->GetInputsCount() != INPUTS_NUM ||
                (IsSet(FLAGS, Flags::SINGLE_USER) && !inst->HasSingleUser())) {
                return false;
            }
            if (L::Capture(inst->GetInput(0).GetInst(), args, insts) &&
                R::Capture(inst->GetInput(1).GetInst(), args, insts)) {
                insts.Add(inst);
                return true;
            }
            if (IsSet(FLAGS, Flags::COMMUTATIVE) && L::Capture(inst->GetInput(1).GetInst(), args, insts) &&
                R::Capture(inst->GetInput(0).GetInst(), args, insts)) {
                insts.Add(inst);
                return true;
            }
            return false;
        }
    };

    template <Opcode OPCODE, typename T, uint64_t FLAGS = Flags::NONE>
    struct UnaryOp {
        template <size_t MAX_OPERANDS, size_t MAX_INSTS>
        static bool Capture(Inst *inst, OperandsCapture<MAX_OPERANDS> &args, InstructionsCapture<MAX_INSTS> &insts)
        {
            // NOLINTNEXTLINE(readability-magic-numbers)
            bool matched = inst->GetOpcode() == OPCODE && inst->GetInputsCount() == 1 &&
                           (IsNotSet(FLAGS, Flags::SINGLE_USER) || inst->HasSingleUser()) &&
                           T::Capture(inst->GetInput(0).GetInst(), args, insts);
            if (matched) {
                insts.Add(inst);
            }
            return matched;
        }
    };

    template <size_t IDX>
    struct Operand {
        template <size_t MAX_OPERANDS, size_t MAX_INSTS>
        static bool Capture(Inst *inst, OperandsCapture<MAX_OPERANDS> &args,
                            [[maybe_unused]] InstructionsCapture<MAX_INSTS> &insts)
        {
            static_assert(IDX < MAX_OPERANDS, "Operand's index should not exceed OperandsCapture size");

            args.Set(IDX, inst);
            return true;
        }
    };

    template <typename T, typename... Args>
    struct AnyOf {
        template <size_t MAX_OPERANDS, size_t MAX_INSTS>
        static bool Capture(Inst *inst, OperandsCapture<MAX_OPERANDS> &args, InstructionsCapture<MAX_INSTS> &insts)
        {
            size_t instIdx = insts.GetCurrentIndex();
            if (T::Capture(inst, args, insts)) {
                return true;
            }
            insts.SetCurrentIndex(instIdx);
            return AnyOf<Args...>::Capture(inst, args, insts);
        }
    };

    template <typename T>
    struct AnyOf<T> {
        template <size_t MAX_OPERANDS, size_t MAX_INSTS>
        static bool Capture(Inst *inst, OperandsCapture<MAX_OPERANDS> &args, InstructionsCapture<MAX_INSTS> &insts)
        {
            return T::Capture(inst, args, insts);
        }
    };

    template <bool ENABLED, typename T>
    struct MatchIf : public T {
    };

    template <typename T>
    struct MatchIf<false, T> {
        template <size_t MAX_OPERANDS, size_t MAX_INSTS>
        // NOLINTNEXTLINE(readability-named-parameter)
        static bool Capture(Inst *, OperandsCapture<MAX_OPERANDS> &, InstructionsCapture<MAX_INSTS> &)
        {
            return false;
        }
    };

    using SRC0 = Operand<0>;
    using SRC1 = Operand<1>;
    using SRC2 = Operand<2U>;

    template <typename L, typename R, uint64_t F = Flags::NONE>
    using ADD = BinaryOp<Opcode::Add, L, R, F | Flags::COMMUTATIVE>;
    template <typename L, typename R, uint64_t F = Flags::NONE>
    using SUB = BinaryOp<Opcode::Sub, L, R, F>;
    template <typename L, typename R, uint64_t F = Flags::NONE>
    using MUL = BinaryOp<Opcode::Mul, L, R, F | Flags::COMMUTATIVE>;
    template <typename T, uint64_t F = Flags::NONE>
    using NEG = UnaryOp<Opcode::Neg, T, F>;
    template <typename T, uint64_t F = Flags::SINGLE_USER>
    using NOT = UnaryOp<Opcode::Not, T, F>;
    template <typename T, uint64_t F = Flags::SINGLE_USER>
    using SHRI = UnaryOp<Opcode::ShrI, T, F>;
    template <typename T, uint64_t F = Flags::SINGLE_USER>
    using ASHRI = UnaryOp<Opcode::AShrI, T, F>;
    template <typename T, uint64_t F = Flags::SINGLE_USER>
    using SHLI = UnaryOp<Opcode::ShlI, T, F>;
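
    /*
     * Illustrative composition of the aliases above (a hypothetical pattern, not necessarily the exact one the
     * pass builds): an Add one of whose inputs is a single-user shift by an immediate, roughly the shape
     * LowerBinaryOperationWithShiftedOperand targets on architectures with "operation with shifted operand"
     * encodings:
     *
     *   using AddShifted = ADD<SRC0, AnyOf<SHLI<SRC1>, SHRI<SRC1>, ASHRI<SRC1>>>;
     *   OperandsCapture<2> operands {};
     *   InstructionsCapture<2> insts {};
     *   bool matched = AddShifted::Capture(inst, operands, insts.ResetIndex());
     *
     * Because ADD sets Flags::COMMUTATIVE, the shifted input may appear on either side of the Add.
     */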

    const ArenaVector<BasicBlock *> &GetBlocksToVisit() const override
    {
        return GetGraph()->GetBlocksRPO();
    }

    static void VisitAdd([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitSub([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitCast([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitCastValueToAnyType([[maybe_unused]] GraphVisitor *v, Inst *inst);

    template <Opcode OPC>
    static void VisitBitwiseBinaryOperation([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitOr(GraphVisitor *v, Inst *inst);
    static void VisitAnd(GraphVisitor *v, Inst *inst);
    static void VisitXor(GraphVisitor *v, Inst *inst);

    static void VisitAndNot([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitXorNot([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitOrNot([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitSaveState([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitSafePoint([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitSaveStateOsr([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitSaveStateDeoptimize([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitBoundsCheck([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitLoadArray([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitLoadCompressedStringChar([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitStoreArray([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitLoad([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitStore([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitReturn([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitShr([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitAShr([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitShl([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitIfImm([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitMul([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitDiv([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitMod([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitNeg([[maybe_unused]] GraphVisitor *v, Inst *inst);
    static void VisitDeoptimizeIf(GraphVisitor *v, Inst *inst);
    static void VisitLoadFromConstantPool(GraphVisitor *v, Inst *inst);
    static void VisitCompare(GraphVisitor *v, Inst *inst);

#include "optimizer/ir/visitor.inc"

    static void InsertInstruction(Inst *inst, Inst *newInst)
    {
        inst->InsertBefore(newInst);
        inst->ReplaceUsers(newInst);
        newInst->GetBasicBlock()->GetGraph()->GetEventWriter().EventLowering(GetOpcodeString(inst->GetOpcode()),
                                                                             inst->GetId(), inst->GetPc());
        COMPILER_LOG(DEBUG, LOWERING) << "Lowering is applied for " << GetOpcodeString(inst->GetOpcode());
    }

    template <size_t MAX_OPERANDS>
    static void SetInputsAndInsertInstruction(OperandsCapture<MAX_OPERANDS> &operands, Inst *inst, Inst *newInst);

    static constexpr Opcode GetInstructionWithShiftedOperand(Opcode opcode);
    static constexpr Opcode GetInstructionWithInvertedOperand(Opcode opcode);
    static ShiftType GetShiftTypeByOpcode(Opcode opcode);
    static Inst *GetCheckInstAndGetConstInput(Inst *inst);
    static ShiftOpcode ConvertOpcode(Opcode newOpcode);

    static void LowerMemInstScale(Inst *inst);
    static void LowerShift(Inst *inst);
    static bool ConstantFitsCompareImm(Inst *cst, uint32_t size, ConditionCode cc);
    static Inst *LowerAddSub(Inst *inst);
    template <Opcode OPCODE>
    static void LowerMulDivMod(Inst *inst);
    static Inst *LowerMultiplyAddSub(Inst *inst);
    static Inst *LowerNegateMultiply(Inst *inst);
    static void LowerLogicWithInvertedOperand(Inst *inst);
    static bool LowerCastValueToAnyTypeWithConst(Inst *inst);
    template <typename T, size_t MAX_OPERANDS>
    static Inst *LowerOperationWithShiftedOperand(Inst *inst, OperandsCapture<MAX_OPERANDS> &operands, Inst *shiftInst,
                                                  Opcode newOpcode);
    template <Opcode OPCODE, bool IS_COMMUTATIVE = true>
    static Inst *LowerBinaryOperationWithShiftedOperand(Inst *inst);
    template <Opcode OPCODE>
    static void LowerUnaryOperationWithShiftedOperand(Inst *inst);
    static Inst *LowerLogic(Inst *inst);
    template <typename LowLevelType>
    static void LowerConstArrayIndex(Inst *inst, Opcode lowLevelOpcode);
    static void LowerStateInst(SaveStateInst *saveState);
    static void LowerReturnInst(FixedInputsInst1 *ret);
    // We'd like to swap inputs only to make the second operand an immediate
    static bool BetterToSwapCompareInputs(Inst *cmp);
    // Optimize the order of input arguments to reduce accumulator usage (bytecode optimizer only).
    static void OptimizeIfInput(compiler::Inst *ifInst);
    static void JoinFcmpInst(IfImmInst *inst, CmpInst *input);
    void LowerIf(IfImmInst *inst);
    static void InPlaceLowerIfImm(IfImmInst *inst, Inst *input, Inst *cst, ConditionCode cc, DataType::Type inputType);
    static void LowerIfImmToIf(IfImmInst *inst, Inst *input, ConditionCode cc, DataType::Type inputType);
    static void LowerToDeoptimizeCompare(Inst *inst);
    static bool TryReplaceDivPowerOfTwo(GraphVisitor *v, Inst *inst);
    static bool TryReplaceDivModNonPowerOfTwo(GraphVisitor *v, Inst *inst);
    static bool TryReplaceModPowerOfTwo(GraphVisitor *v, Inst *inst);
    static void ReplaceSignedModPowerOfTwo(GraphVisitor *v, Inst *inst, uint64_t absValue);
    static void ReplaceUnsignedModPowerOfTwo(GraphVisitor *v, Inst *inst, uint64_t absValue);
    static void ReplaceSignedDivPowerOfTwo(GraphVisitor *v, Inst *inst, int64_t sValue);
    static void ReplaceUnsignedDivPowerOfTwo(GraphVisitor *v, Inst *inst, uint64_t uValue);
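
    // Background for the power-of-two helpers above (the general strength-reduction identities, not a statement
    // about the exact instruction sequence this pass emits): for unsigned x and k >= 1,
    //   x / 2^k == x >> k      and      x % 2^k == x & (2^k - 1).
    // For a signed x the quotient must round toward zero, so a bias is added before shifting, e.g. for 64-bit values
    //   bias = unsigned(x >> 63) >> (64 - k)    // equals 2^k - 1 when x < 0, otherwise 0
    //   x / 2^k == (x + bias) >> k              // arithmetic shift
    // and the remainder then follows as x - (x / 2^k) * 2^k.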

private:
    SaveStateBridgesBuilder ssb_;
};
}  // namespace ark::compiler

#endif  // COMPILER_OPTIMIZER_OPTIMIZATIONS_LOWERING_H