• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef MAPLE_ME_INCLUDE_CAST_OPT_H
17 #define MAPLE_ME_INCLUDE_CAST_OPT_H
18 #include "mir_nodes.h"
19 #include "me_ir.h"
20 
21 namespace maple {
// Category of a cast operation.
// The order matters: the enumerators carry fixed numeric values, kept explicit on purpose.
// NOTE(review): presumably other code (e.g. CastOpt::IsEliminableCastPair) depends on these
// exact values/ordering — confirm before renumbering or inserting new kinds.
enum CastKind {
    CAST_intTrunc = 0,  // "intTrunc": integer truncation
    CAST_zext = 1,      // "zext": zero extension
    CAST_sext = 2,      // "sext": sign extension
    CAST_int2fp = 3,    // "int2fp": integer to floating point
    CAST_fp2int = 4,    // "fp2int": floating point to integer
    CAST_fpTrunc = 5,   // "fpTrunc": floating-point truncation
    CAST_fpExt = 6,     // "fpExt": floating-point extension
    CAST_retype = 7,    // "retype": reinterpretation via retype
    CAST_unknown = 8    // not a recognised cast; a CastInfo with this kind is invalid
};
34 
35 template <typename T,
36           std::enable_if_t<std::is_base_of<MeExpr, T>::value || std::is_base_of<BaseNode, T>::value, bool> = true>
37 class CastInfo {
38 public:
CastInfo(T * expr)39     explicit CastInfo(T *expr) : expr(expr) {}
40     virtual ~CastInfo() = default;
GetOp()41     virtual Opcode GetOp()
42     {
43         CHECK_FATAL(false, "NYI");
44     }
GetPrimType()45     PrimType GetPrimType() const
46     {
47         return expr->GetPrimType();
48     }
GetBitsSize()49     virtual size_t GetBitsSize()
50     {
51         CHECK_FATAL(false, "NYI");
52     }
GetOpnd(size_t index)53     virtual T *GetOpnd(size_t index __attribute__((unused)))
54     {
55         CHECK_FATAL(false, "NYI");
56     }
GetOpndType()57     virtual PrimType GetOpndType()
58     {
59         CHECK_FATAL(false, "NYI");
60     }
61 
IsInvalid()62     bool IsInvalid() const
63     {
64         return kind == CAST_unknown;
65     }
66     CastKind kind = CAST_unknown;  // CastInfo is invalid if kind is CAST_unknown
67     PrimType srcType = PTY_begin;
68     PrimType dstType = PTY_end;
69     T *expr = nullptr;  // expr's type must be MeExpr* or BaseNode*
70 };
71 
72 class BaseNodeCastInfo : public CastInfo<BaseNode> {
73 public:
BaseNodeCastInfo(BaseNode * expr)74     explicit BaseNodeCastInfo(BaseNode *expr) : CastInfo(expr) {}
75     ~BaseNodeCastInfo() = default;
76 
GetOp()77     Opcode GetOp() override
78     {
79         return expr->GetOpCode();
80     }
81 
GetBitsSize()82     size_t GetBitsSize() override
83     {
84         switch (GetOp()) {
85             case OP_zext:
86             case OP_sext:
87                 return static_cast<const ExtractbitsNode *>(expr)->GetBitsSize();
88             default:
89                 CHECK_FATAL(false, "NYI");
90                 break;
91         }
92     }
93 
GetOpnd(size_t index)94     BaseNode *GetOpnd(size_t index) override
95     {
96         return expr->Opnd(index);
97     }
98 
GetOpndType()99     PrimType GetOpndType() override
100     {
101         switch (GetOp()) {
102             case OP_retype: {
103                 return GetOpnd(0)->GetPrimType();
104             }
105             case OP_cvt:
106                 return static_cast<const TypeCvtNode *>(expr)->FromType();
107             case OP_regread: {
108                 const auto *regread = static_cast<const RegreadNode *>(expr);
109                 PregIdx regIdx = regread->GetRegIdx();
110                 MIRPreg *preg = theMIRModule->CurFunction()->GetPregItem(regIdx);
111                 return preg->GetPrimType();
112             }
113             case OP_iread: {
114                 const auto *iread = static_cast<const IreadNode *>(expr);
115                 return iread->GetType()->GetPrimType();
116             }
117             case OP_dread: {
118                 const auto *dread = static_cast<const DreadNode *>(expr);
119                 StIdx stIdx = dread->GetStIdx();
120                 MIRSymbol *symbol = theMIRModule->CurFunction()->GetLocalOrGlobalSymbol(stIdx);
121                 return symbol->GetType()->GetPrimType();
122             }
123             default:
124                 CHECK_FATAL(false, "NYI");
125                 break;
126         }
127     }
128 };
129 
130 class CastOpt {
131 public:
132     static int IsEliminableCastPair(CastKind firstCastKind, CastKind secondCastKind, PrimType dstType,
133                                     PrimType midType2, PrimType midType1, PrimType &srcType);
134     template <typename T>
135     static void DoComputeCastInfo(CastInfo<T> &castInfo, bool isMeExpr);
136     static bool IsExplicitCastOp(Opcode op);
137     static bool IsImplicitCastOp(Opcode op);
138     static bool IsCompareOp(Opcode op);
139 };
140 
141 class MapleCastOpt : public CastOpt {
142 public:
143     static void ComputeCastInfo(BaseNodeCastInfo &castInfo);
144     static BaseNode *CreateMapleExprByCastKind(MIRBuilder &mirBuilder, CastKind castKind, PrimType srcType,
145                                                PrimType dstType, BaseNode *opnd, TyIdx dstTyIdx = TyIdx(0));
146     static BaseNode *SimplifyCast(MIRBuilder &mirBuilder, BaseNode *expr);
147     static BaseNode *SimplifyCastPair(MIRBuilder &mirBuidler, const BaseNodeCastInfo &firstCastInfo,
148                                       const BaseNodeCastInfo &secondCastInfo);
149     static BaseNode *SimplifyCastSingle(MIRBuilder &mirBuilder, const BaseNodeCastInfo &castInfo);
150     static BaseNode *TransformCvtU1ToNe(MIRBuilder &mirBuilder, const TypeCvtNode *cvtExpr);
151 };
152 }  // namespace maple
153 #endif
154