1 /* 2 * Copyright (c) 2023 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef MAPLE_ME_INCLUDE_CAST_OPT_H 17 #define MAPLE_ME_INCLUDE_CAST_OPT_H 18 #include "mir_nodes.h" 19 #include "me_ir.h" 20 21 namespace maple { 22 // The order matters 23 enum CastKind { 24 CAST_intTrunc = 0, 25 CAST_zext = 1, 26 CAST_sext = 2, 27 CAST_int2fp = 3, 28 CAST_fp2int = 4, 29 CAST_fpTrunc = 5, 30 CAST_fpExt = 6, 31 CAST_retype = 7, 32 CAST_unknown = 8 33 }; 34 35 template <typename T, 36 std::enable_if_t<std::is_base_of<MeExpr, T>::value || std::is_base_of<BaseNode, T>::value, bool> = true> 37 class CastInfo { 38 public: CastInfo(T * expr)39 explicit CastInfo(T *expr) : expr(expr) {} 40 virtual ~CastInfo() = default; GetOp()41 virtual Opcode GetOp() 42 { 43 CHECK_FATAL(false, "NYI"); 44 } GetPrimType()45 PrimType GetPrimType() const 46 { 47 return expr->GetPrimType(); 48 } GetBitsSize()49 virtual size_t GetBitsSize() 50 { 51 CHECK_FATAL(false, "NYI"); 52 } GetOpnd(size_t index)53 virtual T *GetOpnd(size_t index __attribute__((unused))) 54 { 55 CHECK_FATAL(false, "NYI"); 56 } GetOpndType()57 virtual PrimType GetOpndType() 58 { 59 CHECK_FATAL(false, "NYI"); 60 } 61 IsInvalid()62 bool IsInvalid() const 63 { 64 return kind == CAST_unknown; 65 } 66 CastKind kind = CAST_unknown; // CastInfo is invalid if kind is CAST_unknown 67 PrimType srcType = PTY_begin; 68 PrimType dstType = PTY_end; 69 T *expr = nullptr; // expr's type must be MeExpr* or BaseNode* 70 }; 71 72 class BaseNodeCastInfo : public CastInfo<BaseNode> { 73 public: BaseNodeCastInfo(BaseNode * expr)74 explicit BaseNodeCastInfo(BaseNode *expr) : CastInfo(expr) {} 75 ~BaseNodeCastInfo() = default; 76 GetOp()77 Opcode GetOp() override 78 { 79 return expr->GetOpCode(); 80 } 81 GetBitsSize()82 size_t GetBitsSize() override 83 { 84 switch (GetOp()) { 85 case OP_zext: 86 case OP_sext: 87 return static_cast<const ExtractbitsNode *>(expr)->GetBitsSize(); 88 default: 89 CHECK_FATAL(false, "NYI"); 90 break; 91 } 92 } 93 GetOpnd(size_t index)94 BaseNode *GetOpnd(size_t index) override 95 { 96 return expr->Opnd(index); 97 } 98 GetOpndType()99 PrimType GetOpndType() override 100 { 101 switch (GetOp()) { 102 case OP_retype: { 103 return GetOpnd(0)->GetPrimType(); 104 } 105 case OP_cvt: 106 return static_cast<const TypeCvtNode *>(expr)->FromType(); 107 case OP_regread: { 108 const auto *regread = static_cast<const RegreadNode *>(expr); 109 PregIdx regIdx = regread->GetRegIdx(); 110 MIRPreg *preg = theMIRModule->CurFunction()->GetPregItem(regIdx); 111 return preg->GetPrimType(); 112 } 113 case OP_iread: { 114 const auto *iread = static_cast<const IreadNode *>(expr); 115 return iread->GetType()->GetPrimType(); 116 } 117 case OP_dread: { 118 const auto *dread = static_cast<const DreadNode *>(expr); 119 StIdx stIdx = dread->GetStIdx(); 120 MIRSymbol *symbol = theMIRModule->CurFunction()->GetLocalOrGlobalSymbol(stIdx); 121 return symbol->GetType()->GetPrimType(); 122 } 123 default: 124 CHECK_FATAL(false, "NYI"); 125 break; 126 } 127 } 128 }; 129 130 class CastOpt { 131 public: 132 static int IsEliminableCastPair(CastKind firstCastKind, CastKind secondCastKind, PrimType dstType, 133 PrimType midType2, PrimType midType1, PrimType &srcType); 134 template <typename T> 135 static void DoComputeCastInfo(CastInfo<T> &castInfo, bool isMeExpr); 136 static bool IsExplicitCastOp(Opcode op); 137 static bool IsImplicitCastOp(Opcode op); 138 static bool IsCompareOp(Opcode op); 139 }; 140 141 class MapleCastOpt : public CastOpt { 142 public: 143 static void ComputeCastInfo(BaseNodeCastInfo &castInfo); 144 static BaseNode *CreateMapleExprByCastKind(MIRBuilder &mirBuilder, CastKind castKind, PrimType srcType, 145 PrimType dstType, BaseNode *opnd, TyIdx dstTyIdx = TyIdx(0)); 146 static BaseNode *SimplifyCast(MIRBuilder &mirBuilder, BaseNode *expr); 147 static BaseNode *SimplifyCastPair(MIRBuilder &mirBuidler, const BaseNodeCastInfo &firstCastInfo, 148 const BaseNodeCastInfo &secondCastInfo); 149 static BaseNode *SimplifyCastSingle(MIRBuilder &mirBuilder, const BaseNodeCastInfo &castInfo); 150 static BaseNode *TransformCvtU1ToNe(MIRBuilder &mirBuilder, const TypeCvtNode *cvtExpr); 151 }; 152 } // namespace maple 153 #endif 154