1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "ext_constantfold.h"
17 #include <climits>
18
19 namespace maple {
20 // This class is designed to further identify simplification
21 // patterns that have not been covered in ConstantFold.
22
ExtSimplify(StmtNode * node)23 StmtNode *ExtConstantFold::ExtSimplify(StmtNode *node)
24 {
25 CHECK_NULL_FATAL(node);
26 switch (node->GetOpCode()) {
27 case OP_block:
28 return ExtSimplifyBlock(static_cast<BlockNode *>(node));
29 case OP_if:
30 return ExtSimplifyIf(static_cast<IfStmtNode *>(node));
31 case OP_dassign:
32 return ExtSimplifyDassign(static_cast<DassignNode *>(node));
33 case OP_iassign:
34 return ExtSimplifyIassign(static_cast<IassignNode *>(node));
35 case OP_dowhile:
36 case OP_while:
37 return ExtSimplifyWhile(static_cast<WhileStmtNode *>(node));
38 default:
39 return node;
40 }
41 }
42
DispatchFold(BaseNode * node)43 BaseNode *ExtConstantFold::DispatchFold(BaseNode *node)
44 {
45 // Not trying all possiblities.
46 // For simplicity, stop looking further down the expression once OP_OP_cior/OP_cand (etc) are seen
47 CHECK_NULL_FATAL(node);
48 switch (node->GetOpCode()) {
49 case OP_cior:
50 case OP_lior:
51 return ExtFoldIor(static_cast<BinaryNode *>(node));
52 case OP_cand:
53 case OP_land:
54 return ExtFoldXand(static_cast<BinaryNode *>(node));
55 case OP_abs:
56 case OP_bnot:
57 case OP_lnot:
58 case OP_neg:
59 case OP_recip:
60 case OP_sqrt:
61 return ExtFoldUnary(static_cast<UnaryNode *>(node));
62 case OP_add:
63 case OP_ashr:
64 case OP_band:
65 case OP_bior:
66 case OP_bxor:
67 case OP_div:
68 case OP_lshr:
69 case OP_max:
70 case OP_min:
71 case OP_mul:
72 case OP_rem:
73 case OP_shl:
74 case OP_sub:
75 case OP_eq:
76 case OP_ne:
77 case OP_ge:
78 case OP_gt:
79 case OP_le:
80 case OP_lt:
81 case OP_cmp:
82 return ExtFoldBinary(static_cast<BinaryNode *>(node));
83 case OP_select:
84 return ExtFoldTernary(static_cast<TernaryNode *>(node));
85 default:
86 return node;
87 }
88 }
89
ExtFoldUnary(UnaryNode * node)90 BaseNode *ExtConstantFold::ExtFoldUnary(UnaryNode *node)
91 {
92 CHECK_NULL_FATAL(node);
93 BaseNode *result = nullptr;
94 result = DispatchFold(node->Opnd(0));
95 if (result != node->Opnd(0)) {
96 node->SetOpnd(result, 0);
97 }
98 return node;
99 }
100
ExtFoldBinary(BinaryNode * node)101 BaseNode *ExtConstantFold::ExtFoldBinary(BinaryNode *node)
102 {
103 CHECK_NULL_FATAL(node);
104 BaseNode *result = nullptr;
105 result = DispatchFold(node->Opnd(0));
106 if (result != node->Opnd(0)) {
107 node->SetOpnd(result, 0);
108 }
109 result = DispatchFold(node->Opnd(1));
110 if (result != node->Opnd(1)) {
111 node->SetOpnd(result, 1);
112 }
113 return node;
114 }
115
ExtFoldTernary(TernaryNode * node)116 BaseNode *ExtConstantFold::ExtFoldTernary(TernaryNode *node)
117 {
118 CHECK_NULL_FATAL(node);
119 BaseNode *result = nullptr;
120 result = DispatchFold(node->Opnd(kFirstOpnd));
121 if (result != node->Opnd(kFirstOpnd)) {
122 node->SetOpnd(result, kFirstOpnd);
123 }
124 result = DispatchFold(node->Opnd(kSecondOpnd));
125 if (result != node->Opnd(kSecondOpnd)) {
126 node->SetOpnd(result, kSecondOpnd);
127 }
128 result = DispatchFold(node->Opnd(kThirdOpnd));
129 if (result != node->Opnd(kThirdOpnd)) {
130 node->SetOpnd(result, kThirdOpnd);
131 }
132 return node;
133 }
134
ExtFold(BaseNode * node)135 BaseNode *ExtConstantFold::ExtFold(BaseNode *node)
136 {
137 if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) {
138 return nullptr;
139 }
140 return DispatchFold(node);
141 }
142
ExtFoldIor(BinaryNode * node)143 BaseNode *ExtConstantFold::ExtFoldIor(BinaryNode *node)
144 {
145 CHECK_NULL_FATAL(node);
146 // The target pattern (Cior, Lior):
147 // x == c || x == c+1 || ... || x == c+k
148 // ==> le (x - c), k
149 // where c is int. i
150 // Leave other cases for future including extended simplification of partial expressions
151 std::queue<BaseNode *> operands;
152 std::vector<int64> uniqOperands;
153 operands.push(node);
154 int64 minVal = LLONG_MAX;
155 bool isWorkable = true;
156 BaseNode *lNode = nullptr;
157
158 while (!operands.empty()) {
159 BaseNode *operand = operands.front();
160 operands.pop();
161 Opcode op = operand->GetOpCode();
162 if (op == OP_cior || op == OP_lior) {
163 operands.push(static_cast<BinaryNode *>(operand)->GetBOpnd(0));
164 operands.push(static_cast<BinaryNode *>(operand)->GetBOpnd(1));
165 } else if (op == OP_eq) {
166 BinaryNode *bNode = static_cast<BinaryNode *>(operand);
167 if (lNode == nullptr) {
168 if (bNode->Opnd(0)->GetOpCode() == OP_dread || bNode->Opnd(0)->GetOpCode() == OP_iread) {
169 lNode = bNode->Opnd(0);
170 } else {
171 // Consider other cases in future
172 isWorkable = false;
173 break;
174 }
175 }
176
177 if ((lNode->IsSameContent(bNode->Opnd(0))) && (bNode->Opnd(1)->GetOpCode() == OP_constval) &&
178 (IsPrimitiveInteger(bNode->Opnd(1)->GetPrimType()))) {
179 MIRConst *rConstVal = safe_cast<ConstvalNode>(bNode->Opnd(1))->GetConstVal();
180 int64 rVal = static_cast<MIRIntConst *>(rConstVal)->GetExtValue();
181 minVal = std::min(minVal, rVal);
182 uniqOperands.push_back(rVal);
183 } else {
184 isWorkable = false;
185 break;
186 }
187 } else {
188 isWorkable = false;
189 break;
190 }
191 }
192
193 if (isWorkable) {
194 std::sort(uniqOperands.begin(), uniqOperands.end());
195 uniqOperands.erase(std::unique(uniqOperands.begin(), uniqOperands.end()), uniqOperands.end());
196 if ((uniqOperands.size() >= 2) && // operand count is not less than 2
197 (uniqOperands[uniqOperands.size() - 1] == uniqOperands[0] + static_cast<int64>(uniqOperands.size()) - 1)) {
198 PrimType nPrimType = node->GetPrimType();
199 BaseNode *diffVal;
200 ConstvalNode *lowVal = mirModule->GetMIRBuilder()->CreateIntConst(minVal, nPrimType);
201 diffVal = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, nPrimType, lNode, lowVal);
202 PrimType cmpPrimType = (nPrimType == PTY_i64 || nPrimType == PTY_u64) ? PTY_u64 : PTY_u32;
203 MIRType *cmpMirType = (nPrimType == PTY_i64 || nPrimType == PTY_u64)
204 ? GlobalTables::GetTypeTable().GetUInt64()
205 : GlobalTables::GetTypeTable().GetUInt32();
206 ConstvalNode *deltaVal =
207 mirModule->GetMIRBuilder()->CreateIntConst(static_cast<int64>(uniqOperands.size()) - 1, cmpPrimType);
208 CompareNode *result;
209 result = mirModule->GetMIRBuilder()->CreateExprCompare(OP_le, *cmpMirType, *cmpMirType, diffVal, deltaVal);
210 return result;
211 } else {
212 return node;
213 }
214 } else {
215 return node;
216 }
217 }
218
ExtFoldXand(BinaryNode * node)219 BaseNode *ExtConstantFold::ExtFoldXand(BinaryNode *node)
220 {
221 // The target pattern (Cand, Land):
222 // (x & m1) == c1 && (x & m2) == c2 && ... && (x & Mk) == ck
223 // where mi and ci shall be all int constants
224 // ==> (x & M) == C
225
226 CHECK_NULL_FATAL(node);
227 CHECK_FATAL(node->GetOpCode() == OP_cand || node->GetOpCode() == OP_land,
228 "Operator is neither OP_cand nor OP_land");
229
230 BaseNode *lnode = DispatchFold(node->Opnd(0));
231 if (lnode != node->Opnd(0)) {
232 node->SetOpnd(lnode, 0);
233 }
234
235 BaseNode *rnode = DispatchFold(node->Opnd(1));
236 if (rnode != node->Opnd(1)) {
237 node->SetOpnd(rnode, 1);
238 }
239
240 // Check if it is of the form of (x & m) == c cand (x & m') == c'
241 if ((lnode->GetOpCode() == OP_eq) && (rnode->GetOpCode() == OP_eq) && (lnode->Opnd(0)->GetOpCode() == OP_band) &&
242 (lnode->Opnd(0)->Opnd(1)->GetOpCode() == OP_constval) &&
243 (IsPrimitiveInteger(lnode->Opnd(0)->Opnd(1)->GetPrimType())) && (lnode->Opnd(1)->GetOpCode() == OP_constval) &&
244 (IsPrimitiveInteger(lnode->Opnd(1)->GetPrimType())) && (rnode->Opnd(0)->GetOpCode() == OP_band) &&
245 (rnode->Opnd(0)->Opnd(1)->GetOpCode() == OP_constval) &&
246 (IsPrimitiveInteger(rnode->Opnd(0)->Opnd(1)->GetPrimType())) && (rnode->Opnd(1)->GetOpCode() == OP_constval) &&
247 (IsPrimitiveInteger(rnode->Opnd(1)->GetPrimType())) &&
248 (lnode->Opnd(0)->Opnd(0)->IsSameContent(rnode->Opnd(0)->Opnd(0)))) {
249 MIRConst *lmConstVal = safe_cast<ConstvalNode>(lnode->Opnd(0)->Opnd(1))->GetConstVal();
250 uint64 lmVal = static_cast<MIRIntConst *>(lmConstVal)->GetExtValue();
251 MIRConst *rmConstVal = safe_cast<ConstvalNode>(rnode->Opnd(0)->Opnd(1))->GetConstVal();
252 uint64 rmVal = static_cast<MIRIntConst *>(rmConstVal)->GetExtValue();
253 MIRConst *lcConstVal = safe_cast<ConstvalNode>(lnode->Opnd(1))->GetConstVal();
254 uint64 lcVal = static_cast<MIRIntConst *>(lcConstVal)->GetExtValue();
255 MIRConst *rcConstVal = safe_cast<ConstvalNode>(rnode->Opnd(1))->GetConstVal();
256 uint64 rcVal = static_cast<MIRIntConst *>(rcConstVal)->GetExtValue();
257
258 bool isWorkable = true;
259 for (uint64 i = 0; i < k64BitSize; i++) {
260 if ((lmVal & (1UL << i)) == (rmVal & (1UL << i)) && (lcVal & (1UL << i)) != (rcVal & (1UL << i))) {
261 isWorkable = false;
262 break;
263 }
264 }
265
266 if (isWorkable) {
267 uint64 mVal = lmVal | rmVal;
268 uint64 cVal = lcVal | rcVal;
269 PrimType mPrimType = lnode->Opnd(0)->Opnd(1)->GetPrimType();
270 ConstvalNode *mIntConst = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<int64>(mVal), mPrimType);
271 PrimType cPrimType = lnode->Opnd(1)->GetPrimType();
272 ConstvalNode *cIntConst = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<int64>(cVal), cPrimType);
273 BinaryNode *eqNode = static_cast<BinaryNode *>(lnode);
274 BinaryNode *bandNode = static_cast<BinaryNode *>(eqNode->Opnd(0));
275 bandNode->SetOpnd(mIntConst, 1);
276 eqNode->SetOpnd(cIntConst, 1);
277 return eqNode;
278 }
279 }
280 return node;
281 }
282
ExtSimplifyBlock(BlockNode * node)283 StmtNode *ExtConstantFold::ExtSimplifyBlock(BlockNode *node)
284 {
285 CHECK_NULL_FATAL(node);
286 if (node->GetFirst() == nullptr) {
287 return node;
288 }
289 StmtNode *s = node->GetFirst();
290 do {
291 (void)ExtSimplify(s);
292 s = s->GetNext();
293 ;
294 } while (s != nullptr);
295 return node;
296 }
297
ExtSimplifyIf(IfStmtNode * node)298 StmtNode *ExtConstantFold::ExtSimplifyIf(IfStmtNode *node)
299 {
300 CHECK_NULL_FATAL(node);
301 (void)ExtSimplify(node->GetThenPart());
302 if (node->GetElsePart()) {
303 (void)ExtSimplify(node->GetElsePart());
304 }
305 BaseNode *origTest = node->Opnd();
306 BaseNode *returnValue = ExtFold(node->Opnd());
307 if (returnValue != origTest) {
308 node->SetOpnd(returnValue, 0);
309 }
310 return node;
311 }
312
ExtSimplifyDassign(DassignNode * node)313 StmtNode *ExtConstantFold::ExtSimplifyDassign(DassignNode *node)
314 {
315 CHECK_NULL_FATAL(node);
316 BaseNode *returnValue;
317 returnValue = ExtFold(node->GetRHS());
318 if (returnValue != node->GetRHS()) {
319 node->SetRHS(returnValue);
320 }
321 return node;
322 }
323
ExtSimplifyIassign(IassignNode * node)324 StmtNode *ExtConstantFold::ExtSimplifyIassign(IassignNode *node)
325 {
326 CHECK_NULL_FATAL(node);
327 BaseNode *returnValue;
328 returnValue = ExtFold(node->GetRHS());
329 if (returnValue != node->GetRHS()) {
330 node->SetRHS(returnValue);
331 }
332 return node;
333 }
334
ExtSimplifyWhile(WhileStmtNode * node)335 StmtNode *ExtConstantFold::ExtSimplifyWhile(WhileStmtNode *node)
336 {
337 CHECK_NULL_FATAL(node);
338 if (node->Opnd(0) == nullptr) {
339 return node;
340 }
341 BaseNode *returnValue = ExtFold(node->Opnd(0));
342 if (returnValue != node->Opnd(0)) {
343 node->SetOpnd(returnValue, 0);
344 }
345 return node;
346 }
347 } // namespace maple
348