• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "ext_constantfold.h"
17 #include <climits>
18 
19 namespace maple {
20 // This class is designed to further identify simplification
21 // patterns that have not been covered in ConstantFold.
22 
ExtSimplify(StmtNode * node)23 StmtNode *ExtConstantFold::ExtSimplify(StmtNode *node)
24 {
25     CHECK_NULL_FATAL(node);
26     switch (node->GetOpCode()) {
27         case OP_block:
28             return ExtSimplifyBlock(static_cast<BlockNode *>(node));
29         case OP_if:
30             return ExtSimplifyIf(static_cast<IfStmtNode *>(node));
31         case OP_dassign:
32             return ExtSimplifyDassign(static_cast<DassignNode *>(node));
33         case OP_iassign:
34             return ExtSimplifyIassign(static_cast<IassignNode *>(node));
35         case OP_dowhile:
36         case OP_while:
37             return ExtSimplifyWhile(static_cast<WhileStmtNode *>(node));
38         default:
39             return node;
40     }
41 }
42 
DispatchFold(BaseNode * node)43 BaseNode *ExtConstantFold::DispatchFold(BaseNode *node)
44 {
45     // Not trying all possiblities.
46     // For simplicity, stop looking further down the expression once OP_OP_cior/OP_cand (etc) are seen
47     CHECK_NULL_FATAL(node);
48     switch (node->GetOpCode()) {
49         case OP_cior:
50         case OP_lior:
51             return ExtFoldIor(static_cast<BinaryNode *>(node));
52         case OP_cand:
53         case OP_land:
54             return ExtFoldXand(static_cast<BinaryNode *>(node));
55         case OP_abs:
56         case OP_bnot:
57         case OP_lnot:
58         case OP_neg:
59         case OP_recip:
60         case OP_sqrt:
61             return ExtFoldUnary(static_cast<UnaryNode *>(node));
62         case OP_add:
63         case OP_ashr:
64         case OP_band:
65         case OP_bior:
66         case OP_bxor:
67         case OP_div:
68         case OP_lshr:
69         case OP_max:
70         case OP_min:
71         case OP_mul:
72         case OP_rem:
73         case OP_shl:
74         case OP_sub:
75         case OP_eq:
76         case OP_ne:
77         case OP_ge:
78         case OP_gt:
79         case OP_le:
80         case OP_lt:
81         case OP_cmp:
82             return ExtFoldBinary(static_cast<BinaryNode *>(node));
83         case OP_select:
84             return ExtFoldTernary(static_cast<TernaryNode *>(node));
85         default:
86             return node;
87     }
88 }
89 
ExtFoldUnary(UnaryNode * node)90 BaseNode *ExtConstantFold::ExtFoldUnary(UnaryNode *node)
91 {
92     CHECK_NULL_FATAL(node);
93     BaseNode *result = nullptr;
94     result = DispatchFold(node->Opnd(0));
95     if (result != node->Opnd(0)) {
96         node->SetOpnd(result, 0);
97     }
98     return node;
99 }
100 
ExtFoldBinary(BinaryNode * node)101 BaseNode *ExtConstantFold::ExtFoldBinary(BinaryNode *node)
102 {
103     CHECK_NULL_FATAL(node);
104     BaseNode *result = nullptr;
105     result = DispatchFold(node->Opnd(0));
106     if (result != node->Opnd(0)) {
107         node->SetOpnd(result, 0);
108     }
109     result = DispatchFold(node->Opnd(1));
110     if (result != node->Opnd(1)) {
111         node->SetOpnd(result, 1);
112     }
113     return node;
114 }
115 
ExtFoldTernary(TernaryNode * node)116 BaseNode *ExtConstantFold::ExtFoldTernary(TernaryNode *node)
117 {
118     CHECK_NULL_FATAL(node);
119     BaseNode *result = nullptr;
120     result = DispatchFold(node->Opnd(kFirstOpnd));
121     if (result != node->Opnd(kFirstOpnd)) {
122         node->SetOpnd(result, kFirstOpnd);
123     }
124     result = DispatchFold(node->Opnd(kSecondOpnd));
125     if (result != node->Opnd(kSecondOpnd)) {
126         node->SetOpnd(result, kSecondOpnd);
127     }
128     result = DispatchFold(node->Opnd(kThirdOpnd));
129     if (result != node->Opnd(kThirdOpnd)) {
130         node->SetOpnd(result, kThirdOpnd);
131     }
132     return node;
133 }
134 
ExtFold(BaseNode * node)135 BaseNode *ExtConstantFold::ExtFold(BaseNode *node)
136 {
137     if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) {
138         return nullptr;
139     }
140     return DispatchFold(node);
141 }
142 
ExtFoldIor(BinaryNode * node)143 BaseNode *ExtConstantFold::ExtFoldIor(BinaryNode *node)
144 {
145     CHECK_NULL_FATAL(node);
146     // The target pattern (Cior, Lior):
147     // x == c || x == c+1 || ... || x == c+k
148     // ==> le (x - c), k
149     // where c is int. i
150     // Leave other cases for future including extended simplification of partial expressions
151     std::queue<BaseNode *> operands;
152     std::vector<int64> uniqOperands;
153     operands.push(node);
154     int64 minVal = LLONG_MAX;
155     bool isWorkable = true;
156     BaseNode *lNode = nullptr;
157 
158     while (!operands.empty()) {
159         BaseNode *operand = operands.front();
160         operands.pop();
161         Opcode op = operand->GetOpCode();
162         if (op == OP_cior || op == OP_lior) {
163             operands.push(static_cast<BinaryNode *>(operand)->GetBOpnd(0));
164             operands.push(static_cast<BinaryNode *>(operand)->GetBOpnd(1));
165         } else if (op == OP_eq) {
166             BinaryNode *bNode = static_cast<BinaryNode *>(operand);
167             if (lNode == nullptr) {
168                 if (bNode->Opnd(0)->GetOpCode() == OP_dread || bNode->Opnd(0)->GetOpCode() == OP_iread) {
169                     lNode = bNode->Opnd(0);
170                 } else {
171                     // Consider other cases in future
172                     isWorkable = false;
173                     break;
174                 }
175             }
176 
177             if ((lNode->IsSameContent(bNode->Opnd(0))) && (bNode->Opnd(1)->GetOpCode() == OP_constval) &&
178                 (IsPrimitiveInteger(bNode->Opnd(1)->GetPrimType()))) {
179                 MIRConst *rConstVal = safe_cast<ConstvalNode>(bNode->Opnd(1))->GetConstVal();
180                 int64 rVal = static_cast<MIRIntConst *>(rConstVal)->GetExtValue();
181                 minVal = std::min(minVal, rVal);
182                 uniqOperands.push_back(rVal);
183             } else {
184                 isWorkable = false;
185                 break;
186             }
187         } else {
188             isWorkable = false;
189             break;
190         }
191     }
192 
193     if (isWorkable) {
194         std::sort(uniqOperands.begin(), uniqOperands.end());
195         uniqOperands.erase(std::unique(uniqOperands.begin(), uniqOperands.end()), uniqOperands.end());
196         if ((uniqOperands.size() >= 2) && // operand count is not less than 2
197             (uniqOperands[uniqOperands.size() - 1] == uniqOperands[0] + static_cast<int64>(uniqOperands.size()) - 1)) {
198             PrimType nPrimType = node->GetPrimType();
199             BaseNode *diffVal;
200             ConstvalNode *lowVal = mirModule->GetMIRBuilder()->CreateIntConst(minVal, nPrimType);
201             diffVal = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, nPrimType, lNode, lowVal);
202             PrimType cmpPrimType = (nPrimType == PTY_i64 || nPrimType == PTY_u64) ? PTY_u64 : PTY_u32;
203             MIRType *cmpMirType = (nPrimType == PTY_i64 || nPrimType == PTY_u64)
204                                       ? GlobalTables::GetTypeTable().GetUInt64()
205                                       : GlobalTables::GetTypeTable().GetUInt32();
206             ConstvalNode *deltaVal =
207                 mirModule->GetMIRBuilder()->CreateIntConst(static_cast<int64>(uniqOperands.size()) - 1, cmpPrimType);
208             CompareNode *result;
209             result = mirModule->GetMIRBuilder()->CreateExprCompare(OP_le, *cmpMirType, *cmpMirType, diffVal, deltaVal);
210             return result;
211         } else {
212             return node;
213         }
214     } else {
215         return node;
216     }
217 }
218 
ExtFoldXand(BinaryNode * node)219 BaseNode *ExtConstantFold::ExtFoldXand(BinaryNode *node)
220 {
221     // The target pattern (Cand, Land):
222     // (x & m1) == c1 && (x & m2) == c2 && ... && (x & Mk) == ck
223     // where mi and ci shall be all int constants
224     // ==> (x & M) == C
225 
226     CHECK_NULL_FATAL(node);
227     CHECK_FATAL(node->GetOpCode() == OP_cand || node->GetOpCode() == OP_land,
228                 "Operator is neither OP_cand nor OP_land");
229 
230     BaseNode *lnode = DispatchFold(node->Opnd(0));
231     if (lnode != node->Opnd(0)) {
232         node->SetOpnd(lnode, 0);
233     }
234 
235     BaseNode *rnode = DispatchFold(node->Opnd(1));
236     if (rnode != node->Opnd(1)) {
237         node->SetOpnd(rnode, 1);
238     }
239 
240     // Check if it is of the form of (x & m) == c cand (x & m') == c'
241     if ((lnode->GetOpCode() == OP_eq) && (rnode->GetOpCode() == OP_eq) && (lnode->Opnd(0)->GetOpCode() == OP_band) &&
242         (lnode->Opnd(0)->Opnd(1)->GetOpCode() == OP_constval) &&
243         (IsPrimitiveInteger(lnode->Opnd(0)->Opnd(1)->GetPrimType())) && (lnode->Opnd(1)->GetOpCode() == OP_constval) &&
244         (IsPrimitiveInteger(lnode->Opnd(1)->GetPrimType())) && (rnode->Opnd(0)->GetOpCode() == OP_band) &&
245         (rnode->Opnd(0)->Opnd(1)->GetOpCode() == OP_constval) &&
246         (IsPrimitiveInteger(rnode->Opnd(0)->Opnd(1)->GetPrimType())) && (rnode->Opnd(1)->GetOpCode() == OP_constval) &&
247         (IsPrimitiveInteger(rnode->Opnd(1)->GetPrimType())) &&
248         (lnode->Opnd(0)->Opnd(0)->IsSameContent(rnode->Opnd(0)->Opnd(0)))) {
249         MIRConst *lmConstVal = safe_cast<ConstvalNode>(lnode->Opnd(0)->Opnd(1))->GetConstVal();
250         uint64 lmVal = static_cast<MIRIntConst *>(lmConstVal)->GetExtValue();
251         MIRConst *rmConstVal = safe_cast<ConstvalNode>(rnode->Opnd(0)->Opnd(1))->GetConstVal();
252         uint64 rmVal = static_cast<MIRIntConst *>(rmConstVal)->GetExtValue();
253         MIRConst *lcConstVal = safe_cast<ConstvalNode>(lnode->Opnd(1))->GetConstVal();
254         uint64 lcVal = static_cast<MIRIntConst *>(lcConstVal)->GetExtValue();
255         MIRConst *rcConstVal = safe_cast<ConstvalNode>(rnode->Opnd(1))->GetConstVal();
256         uint64 rcVal = static_cast<MIRIntConst *>(rcConstVal)->GetExtValue();
257 
258         bool isWorkable = true;
259         for (uint64 i = 0; i < k64BitSize; i++) {
260             if ((lmVal & (1UL << i)) == (rmVal & (1UL << i)) && (lcVal & (1UL << i)) != (rcVal & (1UL << i))) {
261                 isWorkable = false;
262                 break;
263             }
264         }
265 
266         if (isWorkable) {
267             uint64 mVal = lmVal | rmVal;
268             uint64 cVal = lcVal | rcVal;
269             PrimType mPrimType = lnode->Opnd(0)->Opnd(1)->GetPrimType();
270             ConstvalNode *mIntConst = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<int64>(mVal), mPrimType);
271             PrimType cPrimType = lnode->Opnd(1)->GetPrimType();
272             ConstvalNode *cIntConst = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<int64>(cVal), cPrimType);
273             BinaryNode *eqNode = static_cast<BinaryNode *>(lnode);
274             BinaryNode *bandNode = static_cast<BinaryNode *>(eqNode->Opnd(0));
275             bandNode->SetOpnd(mIntConst, 1);
276             eqNode->SetOpnd(cIntConst, 1);
277             return eqNode;
278         }
279     }
280     return node;
281 }
282 
ExtSimplifyBlock(BlockNode * node)283 StmtNode *ExtConstantFold::ExtSimplifyBlock(BlockNode *node)
284 {
285     CHECK_NULL_FATAL(node);
286     if (node->GetFirst() == nullptr) {
287         return node;
288     }
289     StmtNode *s = node->GetFirst();
290     do {
291         (void)ExtSimplify(s);
292         s = s->GetNext();
293         ;
294     } while (s != nullptr);
295     return node;
296 }
297 
ExtSimplifyIf(IfStmtNode * node)298 StmtNode *ExtConstantFold::ExtSimplifyIf(IfStmtNode *node)
299 {
300     CHECK_NULL_FATAL(node);
301     (void)ExtSimplify(node->GetThenPart());
302     if (node->GetElsePart()) {
303         (void)ExtSimplify(node->GetElsePart());
304     }
305     BaseNode *origTest = node->Opnd();
306     BaseNode *returnValue = ExtFold(node->Opnd());
307     if (returnValue != origTest) {
308         node->SetOpnd(returnValue, 0);
309     }
310     return node;
311 }
312 
ExtSimplifyDassign(DassignNode * node)313 StmtNode *ExtConstantFold::ExtSimplifyDassign(DassignNode *node)
314 {
315     CHECK_NULL_FATAL(node);
316     BaseNode *returnValue;
317     returnValue = ExtFold(node->GetRHS());
318     if (returnValue != node->GetRHS()) {
319         node->SetRHS(returnValue);
320     }
321     return node;
322 }
323 
ExtSimplifyIassign(IassignNode * node)324 StmtNode *ExtConstantFold::ExtSimplifyIassign(IassignNode *node)
325 {
326     CHECK_NULL_FATAL(node);
327     BaseNode *returnValue;
328     returnValue = ExtFold(node->GetRHS());
329     if (returnValue != node->GetRHS()) {
330         node->SetRHS(returnValue);
331     }
332     return node;
333 }
334 
ExtSimplifyWhile(WhileStmtNode * node)335 StmtNode *ExtConstantFold::ExtSimplifyWhile(WhileStmtNode *node)
336 {
337     CHECK_NULL_FATAL(node);
338     if (node->Opnd(0) == nullptr) {
339         return node;
340     }
341     BaseNode *returnValue = ExtFold(node->Opnd(0));
342     if (returnValue != node->Opnd(0)) {
343         node->SetOpnd(returnValue, 0);
344     }
345     return node;
346 }
347 }  // namespace maple
348