1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "constantfold.h"
17 #include <cmath>
18 #include <cfloat>
19 #include <climits>
20 #include <type_traits>
21 #include "mpl_logging.h"
22 #include "mir_function.h"
23 #include "mir_builder.h"
24 #include "global_tables.h"
25 #include "me_option.h"
26 #include "maple_phase_manager.h"
27
28 namespace maple {
29
30 namespace {
31
32 constexpr uint64 kJsTypeNumber = 4;
33 constexpr uint64 kJsTypeNumberInHigh32Bit = kJsTypeNumber << 32; // set high 32 bit as JSTYPE_NUMBER
34 constexpr uint32 kByteSizeOfBit64 = 8; // byte number for 64 bit
35 constexpr uint32 kBitSizePerByte = 8;
36 constexpr maple::int32 kMaxOffset = INT_MAX - 8;
37
38 enum CompareRes : int64 { kLess = -1, kEqual = 0, kGreater = 1 };
39
operator *(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)40 std::optional<IntVal> operator*(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
41 {
42 if (!v1 && !v2) {
43 return std::nullopt;
44 }
45
46 // Perform all calculations in terms of the maximum available signed type.
47 // The value will be truncated for an appropriate type when constant is created in PariToExpr function
48 // replace with PTY_i128 when IntVal supports 128bit calculation
49 return v1 && v2 ? v1->Mul(*v2, PTY_i64) : IntVal(0, PTY_i64);
50 }
51
52 // Perform all calculations in terms of the maximum available signed type.
53 // The value will be truncated for an appropriate type when constant is created in PariToExpr function
54 // replace with PTY_i128 when IntVal supports 128bit calculation
AddSub(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2,bool isAdd)55 std::optional<IntVal> AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd)
56 {
57 if (!v1 && !v2) {
58 return std::nullopt;
59 }
60
61 if (v1 && v2) {
62 return isAdd ? v1->Add(*v2, PTY_i64) : v1->Sub(*v2, PTY_i64);
63 }
64
65 if (v1) {
66 return v1->TruncOrExtend(PTY_i64);
67 }
68
69 // !v1 && v2
70 return isAdd ? v2->TruncOrExtend(PTY_i64) : -(v2->TruncOrExtend(PTY_i64));
71 }
72
operator +(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)73 std::optional<IntVal> operator+(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
74 {
75 return AddSub(v1, v2, true);
76 }
77
operator -(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)78 std::optional<IntVal> operator-(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
79 {
80 return AddSub(v1, v2, false);
81 }
82
83 } // anonymous namespace
84
// This phase optimizes code by simplifying constant expressions:
// each constant expression is evaluated at compile time and replaced
// by its computed value, saving that work at run time.
//
89 //
// The main procedure is as follows:
// A. Analyze the expression type
// B. Analyze the operator type
// C. Replace the expression with the result of the operation
94
// Returns true if the set bits of x form one contiguous run of 1's starting at
// bit 0, i.e. x == 2^k - 1 for some k >= 1. Returns false for x == 0.
static bool ContiguousBitsOf1(uint64 x)
{
    // For a low-bit mask, adding 1 carries through every set bit, so x and
    // x + 1 share no set bit (for all-ones, x + 1 wraps to 0).
    return x != 0 && (x & (x + 1)) == 0;
}
104
// Returns true if exactly one bit of num is set; zero is not a power of two.
inline bool IsPowerOf2(uint64 num)
{
    // Clearing the lowest set bit of a power of two leaves nothing behind.
    return num != 0 && (num & (num - 1)) == 0;
}
112
NewBinaryNode(BinaryNode * old,Opcode op,PrimType primType,BaseNode * lhs,BaseNode * rhs) const113 BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs,
114 BaseNode *rhs) const
115 {
116 CHECK_NULL_FATAL(old);
117 BinaryNode *result = nullptr;
118 if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == lhs && old->Opnd(1) == rhs) {
119 result = old;
120 } else {
121 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(op, primType, lhs, rhs);
122 }
123 return result;
124 }
125
NewUnaryNode(UnaryNode * old,Opcode op,PrimType primType,BaseNode * expr) const126 UnaryNode *ConstantFold::NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const
127 {
128 CHECK_NULL_FATAL(old);
129 UnaryNode *result = nullptr;
130 if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == expr) {
131 result = old;
132 } else {
133 result = mirModule->CurFuncCodeMemPool()->New<UnaryNode>(op, primType, expr);
134 }
135 return result;
136 }
137
PairToExpr(PrimType resultType,const std::pair<BaseNode *,std::optional<IntVal>> & pair) const138 BaseNode *ConstantFold::PairToExpr(PrimType resultType, const std::pair<BaseNode *, std::optional<IntVal>> &pair) const
139 {
140 CHECK_NULL_FATAL(pair.first);
141 BaseNode *result = pair.first;
142 if (!pair.second || *pair.second == 0) {
143 return result;
144 }
145 if (pair.first->GetOpCode() == OP_neg && !pair.second->GetSignBit()) {
146 // -a, 5 -> 5 - a
147 ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(pair.second->GetExtValue(), resultType);
148 BaseNode *r = static_cast<UnaryNode *>(pair.first)->Opnd(0);
149 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, val, r);
150 } else {
151 if ((!pair.second->GetSignBit() && pair.second->GetSXTValue(GetPrimTypeBitSize(resultType)) > 0) ||
152 pair.second->GetSXTValue() == INT64_MIN) {
153 // +-a, 5 -> a + 5
154 ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(pair.second->GetExtValue(), resultType);
155 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_add, resultType, pair.first, val);
156 } else {
157 // +-a, -5 -> a + -5
158 ConstvalNode *val =
159 mirModule->GetMIRBuilder()->CreateIntConst((-pair.second.value()).GetExtValue(), resultType);
160 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, pair.first, val);
161 }
162 }
163 return result;
164 }
165
FoldBase(BaseNode * node) const166 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBase(BaseNode *node) const
167 {
168 return std::make_pair(node, std::nullopt);
169 }
170
Simplify(StmtNode * node)171 StmtNode *ConstantFold::Simplify(StmtNode *node)
172 {
173 CHECK_NULL_FATAL(node);
174 switch (node->GetOpCode()) {
175 case OP_dassign:
176 case OP_maydassign:
177 return SimplifyDassign(static_cast<DassignNode *>(node));
178 case OP_iassign:
179 return SimplifyIassign(static_cast<IassignNode *>(node));
180 case OP_block:
181 return SimplifyBlock(static_cast<BlockNode *>(node));
182 case OP_if:
183 return SimplifyIf(static_cast<IfStmtNode *>(node));
184 case OP_dowhile:
185 case OP_while:
186 return SimplifyWhile(static_cast<WhileStmtNode *>(node));
187 case OP_switch:
188 return SimplifySwitch(static_cast<SwitchNode *>(node));
189 case OP_eval:
190 case OP_throw:
191 case OP_free:
192 case OP_decref:
193 case OP_incref:
194 case OP_decrefreset:
195 case OP_regassign:
196 CASE_OP_ASSERT_NONNULL
197 case OP_igoto:
198 return SimplifyUnary(static_cast<UnaryStmtNode *>(node));
199 case OP_brfalse:
200 case OP_brtrue:
201 return SimplifyCondGoto(static_cast<CondGotoNode *>(node));
202 case OP_return:
203 case OP_syncenter:
204 case OP_syncexit:
205 case OP_call:
206 case OP_virtualcall:
207 case OP_superclasscall:
208 case OP_interfacecall:
209 case OP_customcall:
210 case OP_polymorphiccall:
211 case OP_intrinsiccall:
212 case OP_xintrinsiccall:
213 case OP_intrinsiccallwithtype:
214 case OP_callassigned:
215 case OP_virtualcallassigned:
216 case OP_superclasscallassigned:
217 case OP_interfacecallassigned:
218 case OP_customcallassigned:
219 case OP_polymorphiccallassigned:
220 case OP_intrinsiccallassigned:
221 case OP_intrinsiccallwithtypeassigned:
222 case OP_xintrinsiccallassigned:
223 case OP_callinstant:
224 case OP_callinstantassigned:
225 case OP_virtualcallinstant:
226 case OP_virtualcallinstantassigned:
227 case OP_superclasscallinstant:
228 case OP_superclasscallinstantassigned:
229 case OP_interfacecallinstant:
230 case OP_interfacecallinstantassigned:
231 CASE_OP_ASSERT_BOUNDARY
232 return SimplifyNary(static_cast<NaryStmtNode *>(node));
233 case OP_icall:
234 case OP_icallassigned:
235 case OP_icallproto:
236 case OP_icallprotoassigned:
237 return SimplifyIcall(static_cast<IcallNode *>(node));
238 case OP_asm:
239 return SimplifyAsm(static_cast<AsmNode *>(node));
240 default:
241 return node;
242 }
243 }
244
DispatchFold(BaseNode * node)245 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::DispatchFold(BaseNode *node)
246 {
247 CHECK_NULL_FATAL(node);
248 switch (node->GetOpCode()) {
249 case OP_sizeoftype:
250 return FoldSizeoftype(static_cast<SizeoftypeNode *>(node));
251 case OP_abs:
252 case OP_bnot:
253 case OP_lnot:
254 case OP_neg:
255 case OP_recip:
256 case OP_sqrt:
257 return FoldUnary(static_cast<UnaryNode *>(node));
258 case OP_ceil:
259 case OP_floor:
260 case OP_round:
261 case OP_trunc:
262 case OP_cvt:
263 return FoldTypeCvt(static_cast<TypeCvtNode *>(node));
264 case OP_sext:
265 case OP_zext:
266 case OP_extractbits:
267 return FoldExtractbits(static_cast<ExtractbitsNode *>(node));
268 case OP_iaddrof:
269 case OP_iread:
270 return FoldIread(static_cast<IreadNode *>(node));
271 case OP_add:
272 case OP_ashr:
273 case OP_band:
274 case OP_bior:
275 case OP_bxor:
276 case OP_cand:
277 case OP_cior:
278 case OP_div:
279 case OP_land:
280 case OP_lior:
281 case OP_lshr:
282 case OP_max:
283 case OP_min:
284 case OP_mul:
285 case OP_rem:
286 case OP_shl:
287 case OP_sub:
288 return FoldBinary(static_cast<BinaryNode *>(node));
289 case OP_eq:
290 case OP_ne:
291 case OP_ge:
292 case OP_gt:
293 case OP_le:
294 case OP_lt:
295 case OP_cmp:
296 return FoldCompare(static_cast<CompareNode *>(node));
297 case OP_depositbits:
298 return FoldDepositbits(static_cast<DepositbitsNode *>(node));
299 case OP_select:
300 return FoldTernary(static_cast<TernaryNode *>(node));
301 case OP_array:
302 return FoldArray(static_cast<ArrayNode *>(node));
303 case OP_retype:
304 return FoldRetype(static_cast<RetypeNode *>(node));
305 case OP_gcmallocjarray:
306 case OP_gcpermallocjarray:
307 return FoldGcmallocjarray(static_cast<JarrayMallocNode *>(node));
308 default:
309 return FoldBase(static_cast<BaseNode *>(node));
310 }
311 }
312
Negate(BaseNode * node) const313 BaseNode *ConstantFold::Negate(BaseNode *node) const
314 {
315 CHECK_NULL_FATAL(node);
316 return mirModule->CurFuncCodeMemPool()->New<UnaryNode>(OP_neg, PrimType(node->GetPrimType()), node);
317 }
318
Negate(UnaryNode * node) const319 BaseNode *ConstantFold::Negate(UnaryNode *node) const
320 {
321 CHECK_NULL_FATAL(node);
322 BaseNode *result = nullptr;
323 if (node->GetOpCode() == OP_neg) {
324 result = static_cast<BaseNode *>(node->Opnd(0));
325 } else {
326 BaseNode *n = static_cast<BaseNode *>(node);
327 result = NewUnaryNode(node, OP_neg, node->GetPrimType(), n);
328 }
329 return result;
330 }
331
Negate(const ConstvalNode * node) const332 BaseNode *ConstantFold::Negate(const ConstvalNode *node) const
333 {
334 CHECK_NULL_FATAL(node);
335 ConstvalNode *copy = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
336 CHECK_NULL_FATAL(copy);
337 copy->GetConstVal()->Neg();
338 return copy;
339 }
340
NegateTree(BaseNode * node) const341 BaseNode *ConstantFold::NegateTree(BaseNode *node) const
342 {
343 CHECK_NULL_FATAL(node);
344 if (node->IsUnaryNode()) {
345 return Negate(static_cast<UnaryNode *>(node));
346 } else if (node->GetOpCode() == OP_constval) {
347 return Negate(static_cast<ConstvalNode *>(node));
348 } else {
349 return Negate(static_cast<BaseNode *>(node));
350 }
351 }
352
FoldIntConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRIntConst & intConst0,const MIRIntConst & intConst1) const353 MIRIntConst *ConstantFold::FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
354 const MIRIntConst &intConst0,
355 const MIRIntConst &intConst1) const
356 {
357 uint64 result = 0;
358
359 bool greater = intConst0.GetValue().Greater(intConst1.GetValue(), opndType);
360 bool equal = intConst0.GetValue().Equal(intConst1.GetValue(), opndType);
361 bool less = intConst0.GetValue().Less(intConst1.GetValue(), opndType);
362
363 switch (opcode) {
364 case OP_eq: {
365 result = equal;
366 break;
367 }
368 case OP_ge: {
369 result = greater || equal;
370 break;
371 }
372 case OP_gt: {
373 result = greater;
374 break;
375 }
376 case OP_le: {
377 result = less || equal;
378 break;
379 }
380 case OP_lt: {
381 result = less;
382 break;
383 }
384 case OP_ne: {
385 result = !equal;
386 break;
387 }
388 case OP_cmp: {
389 if (greater) {
390 result = kGreater;
391 } else if (equal) {
392 result = kEqual;
393 } else if (less) {
394 result = kLess;
395 }
396 break;
397 }
398 default:
399 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstComparison");
400 break;
401 }
402 // determine the type
403 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
404 // form the constant
405 MIRIntConst *constValue = nullptr;
406 if (type.GetPrimType() == PTY_dyni32) {
407 constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
408 constValue->SetValue(kJsTypeNumberInHigh32Bit | result);
409 } else {
410 constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
411 }
412 return constValue;
413 }
414
FoldIntConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const415 ConstvalNode *ConstantFold::FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
416 const ConstvalNode &const0, const ConstvalNode &const1) const
417 {
418 const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
419 const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
420 CHECK_NULL_FATAL(intConst0);
421 CHECK_NULL_FATAL(intConst1);
422 MIRIntConst *constValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
423 // form the ConstvalNode
424 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
425 resultConst->SetPrimType(resultType);
426 resultConst->SetConstVal(constValue);
427 return resultConst;
428 }
429
FoldIntConstBinaryMIRConst(Opcode opcode,PrimType resultType,const MIRIntConst * intConst0,const MIRIntConst * intConst1) const430 MIRConst *ConstantFold::FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *intConst0,
431 const MIRIntConst *intConst1) const
432 {
433 IntVal intVal0 = intConst0->GetValue();
434 IntVal intVal1 = intConst1->GetValue();
435 IntVal result(0, resultType);
436
437 switch (opcode) {
438 case OP_add: {
439 result = intVal0.Add(intVal1, resultType);
440 break;
441 }
442 case OP_sub: {
443 result = intVal0.Sub(intVal1, resultType);
444 break;
445 }
446 case OP_mul: {
447 result = intVal0.Mul(intVal1, resultType);
448 break;
449 }
450 case OP_div: {
451 result = intVal0.Div(intVal1, resultType);
452 break;
453 }
454 case OP_rem: {
455 result = intVal0.Rem(intVal1, resultType);
456 break;
457 }
458 case OP_ashr: {
459 result = intVal0.AShr(intVal1.GetZXTValue() % GetPrimTypeBitSize(resultType), resultType);
460 break;
461 }
462 case OP_lshr: {
463 result = intVal0.LShr(intVal1.GetZXTValue() % GetPrimTypeBitSize(resultType), resultType);
464 break;
465 }
466 case OP_shl: {
467 result = intVal0.Shl(intVal1.GetZXTValue() % GetPrimTypeBitSize(resultType), resultType);
468 break;
469 }
470 case OP_max: {
471 result = Max(intVal0, intVal1, resultType);
472 break;
473 }
474 case OP_min: {
475 result = Min(intVal0, intVal1, resultType);
476 break;
477 }
478 case OP_band: {
479 result = intVal0.And(intVal1, resultType);
480 break;
481 }
482 case OP_bior: {
483 result = intVal0.Or(intVal1, resultType);
484 break;
485 }
486 case OP_bxor: {
487 result = intVal0.Xor(intVal1, resultType);
488 break;
489 }
490 case OP_cand:
491 case OP_land: {
492 result = IntVal(intVal0.GetExtValue() && intVal1.GetExtValue(), resultType);
493 break;
494 }
495 case OP_cior:
496 case OP_lior: {
497 result = IntVal(intVal0.GetExtValue() || intVal1.GetExtValue(), resultType);
498 break;
499 }
500 case OP_depositbits: {
501 // handled in FoldDepositbits
502 DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstBinary");
503 break;
504 }
505 default:
506 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstBinary");
507 break;
508 }
509 // determine the type
510 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
511 // form the constant
512 MIRIntConst *constValue = nullptr;
513 if (type.GetPrimType() == PTY_dyni32) {
514 constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
515 constValue->SetValue(kJsTypeNumberInHigh32Bit | result.GetExtValue());
516 } else {
517 constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result.GetExtValue(), type);
518 }
519 return constValue;
520 }
521
FoldIntConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const522 ConstvalNode *ConstantFold::FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
523 const ConstvalNode &const1) const
524 {
525 const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
526 const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
527 CHECK_NULL_FATAL(intConst0);
528 CHECK_NULL_FATAL(intConst1);
529 MIRConst *constValue = FoldIntConstBinaryMIRConst(opcode, resultType, intConst0, intConst1);
530 // form the ConstvalNode
531 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
532 resultConst->SetPrimType(resultType);
533 resultConst->SetConstVal(constValue);
534 return resultConst;
535 }
536
FoldFPConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const537 ConstvalNode *ConstantFold::FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
538 const ConstvalNode &const1) const
539 {
540 DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
541 const MIRDoubleConst *doubleConst0 = nullptr;
542 const MIRDoubleConst *doubleConst1 = nullptr;
543 const MIRFloatConst *floatConst0 = nullptr;
544 const MIRFloatConst *floatConst1 = nullptr;
545 bool useDouble = (const0.GetPrimType() == PTY_f64);
546 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
547 resultConst->SetPrimType(resultType);
548 if (useDouble) {
549 doubleConst0 = safe_cast<MIRDoubleConst>(const0.GetConstVal());
550 doubleConst1 = safe_cast<MIRDoubleConst>(const1.GetConstVal());
551 CHECK_NULL_FATAL(doubleConst0);
552 CHECK_NULL_FATAL(doubleConst1);
553 } else {
554 floatConst0 = safe_cast<MIRFloatConst>(const0.GetConstVal());
555 floatConst1 = safe_cast<MIRFloatConst>(const1.GetConstVal());
556 CHECK_NULL_FATAL(floatConst0);
557 CHECK_NULL_FATAL(floatConst1);
558 }
559 float constValueFloat = 0.0;
560 double constValueDouble = 0.0;
561 switch (opcode) {
562 case OP_add: {
563 if (useDouble) {
564 constValueDouble = doubleConst0->GetValue() + doubleConst1->GetValue();
565 } else {
566 constValueFloat = floatConst0->GetValue() + floatConst1->GetValue();
567 }
568 break;
569 }
570 case OP_sub: {
571 if (useDouble) {
572 constValueDouble = doubleConst0->GetValue() - doubleConst1->GetValue();
573 } else {
574 constValueFloat = floatConst0->GetValue() - floatConst1->GetValue();
575 }
576 break;
577 }
578 case OP_mul: {
579 if (useDouble) {
580 constValueDouble = doubleConst0->GetValue() * doubleConst1->GetValue();
581 } else {
582 constValueFloat = floatConst0->GetValue() * floatConst1->GetValue();
583 }
584 break;
585 }
586 case OP_div: {
587 // for floats div by 0 is well defined
588 if (useDouble) {
589 constValueDouble = doubleConst0->GetValue() / doubleConst1->GetValue();
590 } else {
591 constValueFloat = floatConst0->GetValue() / floatConst1->GetValue();
592 }
593 break;
594 }
595 case OP_max: {
596 if (useDouble) {
597 constValueDouble = (doubleConst0->GetValue() >= doubleConst1->GetValue()) ? doubleConst0->GetValue()
598 : doubleConst1->GetValue();
599 } else {
600 constValueFloat = (floatConst0->GetValue() >= floatConst1->GetValue()) ? floatConst0->GetValue()
601 : floatConst1->GetValue();
602 }
603 break;
604 }
605 case OP_min: {
606 if (useDouble) {
607 constValueDouble = (doubleConst0->GetValue() <= doubleConst1->GetValue()) ? doubleConst0->GetValue()
608 : doubleConst1->GetValue();
609 } else {
610 constValueFloat = (floatConst0->GetValue() <= floatConst1->GetValue()) ? floatConst0->GetValue()
611 : floatConst1->GetValue();
612 }
613 break;
614 }
615 case OP_rem:
616 case OP_ashr:
617 case OP_lshr:
618 case OP_shl:
619 case OP_band:
620 case OP_bior:
621 case OP_bxor:
622 case OP_cand:
623 case OP_land:
624 case OP_cior:
625 case OP_lior:
626 case OP_depositbits: {
627 DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstBinary");
628 break;
629 }
630 default:
631 DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstBinary");
632 break;
633 }
634 if (resultType == PTY_f64) {
635 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValueDouble));
636 } else {
637 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValueFloat));
638 }
639 return resultConst;
640 }
641
ConstValueEqual(int64 leftValue,int64 rightValue) const642 bool ConstantFold::ConstValueEqual(int64 leftValue, int64 rightValue) const
643 {
644 return (leftValue == rightValue);
645 }
646
ConstValueEqual(float leftValue,float rightValue) const647 bool ConstantFold::ConstValueEqual(float leftValue, float rightValue) const
648 {
649 return fabs(leftValue - rightValue) <= FLT_MIN;
650 }
651
ConstValueEqual(double leftValue,double rightValue) const652 bool ConstantFold::ConstValueEqual(double leftValue, double rightValue) const
653 {
654 return fabs(leftValue - rightValue) <= DBL_MIN;
655 }
656
657 template <typename T>
FullyEqual(T leftValue,T rightValue) const658 bool ConstantFold::FullyEqual(T leftValue, T rightValue) const
659 {
660 if (isinf(leftValue) && isinf(rightValue)) {
661 // (inf == inf), add the judgement here in case of the subtraction between float type inf
662 return true;
663 } else {
664 return ConstValueEqual(leftValue, rightValue);
665 }
666 }
667
668 template <typename T>
ComparisonResult(Opcode op,T * leftConst,T * rightConst) const669 int64 ConstantFold::ComparisonResult(Opcode op, T *leftConst, T *rightConst) const
670 {
671 typename T::value_type leftValue = leftConst->GetValue();
672 typename T::value_type rightValue = rightConst->GetValue();
673 int64 result = 0;
674 ;
675 switch (op) {
676 case OP_eq: {
677 result = FullyEqual(leftValue, rightValue);
678 break;
679 }
680 case OP_ge: {
681 result = (leftValue > rightValue) || FullyEqual(leftValue, rightValue);
682 break;
683 }
684 case OP_gt: {
685 result = (leftValue > rightValue);
686 break;
687 }
688 case OP_le: {
689 result = (leftValue < rightValue) || FullyEqual(leftValue, rightValue);
690 break;
691 }
692 case OP_lt: {
693 result = (leftValue < rightValue);
694 break;
695 }
696 case OP_ne: {
697 result = !FullyEqual(leftValue, rightValue);
698 break;
699 }
700 case OP_cmpl:
701 case OP_cmpg: {
702 if (std::isnan(leftValue) || std::isnan(rightValue)) {
703 result = (op == OP_cmpg) ? kGreater : kLess;
704 break;
705 }
706 }
707 [[clang::fallthrough]];
708 case OP_cmp: {
709 if (leftValue > rightValue) {
710 result = kGreater;
711 } else if (FullyEqual(leftValue, rightValue)) {
712 result = kEqual;
713 } else {
714 result = kLess;
715 }
716 break;
717 }
718 default:
719 DEBUG_ASSERT(false, "Unknown opcode for Comparison");
720 break;
721 }
722 return result;
723 }
724
FoldFPConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRConst & leftConst,const MIRConst & rightConst) const725 MIRIntConst *ConstantFold::FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
726 const MIRConst &leftConst, const MIRConst &rightConst) const
727 {
728 int64 result = 0;
729 bool useDouble = (opndType == PTY_f64);
730 if (useDouble) {
731 result =
732 ComparisonResult(opcode, safe_cast<MIRDoubleConst>(&leftConst), safe_cast<MIRDoubleConst>(&rightConst));
733 } else {
734 result = ComparisonResult(opcode, safe_cast<MIRFloatConst>(&leftConst), safe_cast<MIRFloatConst>(&rightConst));
735 }
736 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
737 MIRIntConst *resultConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
738 return resultConst;
739 }
740
FoldFPConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const741 ConstvalNode *ConstantFold::FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
742 const ConstvalNode &const0, const ConstvalNode &const1) const
743 {
744 DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
745 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
746 resultConst->SetPrimType(resultType);
747 resultConst->SetConstVal(
748 FoldFPConstComparisonMIRConst(opcode, resultType, opndType, *const0.GetConstVal(), *const1.GetConstVal()));
749 return resultConst;
750 }
751
FoldConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRConst & const0,const MIRConst & const1)752 MIRConst *ConstantFold::FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
753 const MIRConst &const0, const MIRConst &const1)
754 {
755 MIRConst *returnValue = nullptr;
756 if (IsPrimitiveInteger(opndType) || IsPrimitiveDynInteger(opndType)) {
757 const auto *intConst0 = safe_cast<MIRIntConst>(&const0);
758 const auto *intConst1 = safe_cast<MIRIntConst>(&const1);
759 ASSERT_NOT_NULL(intConst0);
760 ASSERT_NOT_NULL(intConst1);
761 returnValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
762 } else if (opndType == PTY_f32 || opndType == PTY_f64) {
763 returnValue = FoldFPConstComparisonMIRConst(opcode, resultType, opndType, const0, const1);
764 } else {
765 DEBUG_ASSERT(false, "Unhandled case for FoldConstComparisonMIRConst");
766 }
767 return returnValue;
768 }
769
FoldConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const770 ConstvalNode *ConstantFold::FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
771 const ConstvalNode &const0, const ConstvalNode &const1) const
772 {
773 ConstvalNode *returnValue = nullptr;
774 if (IsPrimitiveInteger(opndType) || IsPrimitiveDynInteger(opndType)) {
775 returnValue = FoldIntConstComparison(opcode, resultType, opndType, const0, const1);
776 } else if (opndType == PTY_f32 || opndType == PTY_f64) {
777 returnValue = FoldFPConstComparison(opcode, resultType, opndType, const0, const1);
778 } else {
779 DEBUG_ASSERT(false, "Unhandled case for FoldConstComparison");
780 }
781 return returnValue;
782 }
783
FoldConstComparisonReverse(Opcode opcode,PrimType resultType,PrimType opndType,BaseNode & l,BaseNode & r)784 CompareNode *ConstantFold::FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType,
785 BaseNode &l, BaseNode &r)
786 {
787 CompareNode *result = nullptr;
788 Opcode op = opcode;
789 switch (opcode) {
790 case OP_gt: {
791 op = OP_lt;
792 break;
793 }
794 case OP_lt: {
795 op = OP_gt;
796 break;
797 }
798 case OP_ge: {
799 op = OP_le;
800 break;
801 }
802 case OP_le: {
803 op = OP_ge;
804 break;
805 }
806 case OP_eq: {
807 break;
808 }
809 case OP_ne: {
810 break;
811 }
812 default:
813 DEBUG_ASSERT(false, "Unknown opcode for FoldConstComparisonReverse");
814 break;
815 }
816
817 result =
818 mirModule->CurFuncCodeMemPool()->New<CompareNode>(Opcode(op), PrimType(resultType), PrimType(opndType), &r, &l);
819 return result;
820 }
821
FoldConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const822 ConstvalNode *ConstantFold::FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
823 const ConstvalNode &const1) const
824 {
825 ConstvalNode *returnValue = nullptr;
826 if (IsPrimitiveInteger(resultType) || IsPrimitiveDynInteger(resultType)) {
827 returnValue = FoldIntConstBinary(opcode, resultType, const0, const1);
828 } else if (resultType == PTY_f32 || resultType == PTY_f64) {
829 returnValue = FoldFPConstBinary(opcode, resultType, const0, const1);
830 } else {
831 DEBUG_ASSERT(false, "Unhandled case for FoldConstBinary");
832 }
833 return returnValue;
834 }
835
FoldIntConstUnary(Opcode opcode,PrimType resultType,const ConstvalNode * constNode) const836 ConstvalNode *ConstantFold::FoldIntConstUnary(Opcode opcode, PrimType resultType, const ConstvalNode *constNode) const
837 {
838 CHECK_NULL_FATAL(constNode);
839 const MIRIntConst *cst = safe_cast<MIRIntConst>(constNode->GetConstVal());
840 ASSERT_NOT_NULL(cst);
841 IntVal result = cst->GetValue().TruncOrExtend(resultType);
842 switch (opcode) {
843 case OP_abs: {
844 if (result.IsSigned() && result.GetSignBit()) {
845 result = -result;
846 }
847 break;
848 }
849 case OP_bnot: {
850 result = ~result;
851 break;
852 }
853 case OP_lnot: {
854 result = {result == 0, resultType};
855 break;
856 }
857 case OP_neg: {
858 result = -result;
859 break;
860 }
861 case OP_sext: // handled in FoldExtractbits
862 case OP_zext: // handled in FoldExtractbits
863 case OP_extractbits: // handled in FoldExtractbits
864 case OP_recip:
865 case OP_sqrt: {
866 DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstUnary");
867 break;
868 }
869 default:
870 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstUnary");
871 break;
872 }
873 // determine the type
874 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
875 // form the constant
876 MIRIntConst *constValue = nullptr;
877 if (type.GetPrimType() == PTY_dyni32) {
878 constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
879 constValue->SetValue(kJsTypeNumberInHigh32Bit | result.GetExtValue());
880 } else {
881 constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result.GetExtValue(), type);
882 }
883 // form the ConstvalNode
884 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
885 resultConst->SetPrimType(type.GetPrimType());
886 resultConst->SetConstVal(constValue);
887 return resultConst;
888 }
889
890 template <typename T>
FoldFPConstUnary(Opcode opcode,PrimType resultType,ConstvalNode * constNode) const891 ConstvalNode *ConstantFold::FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const
892 {
893 CHECK_NULL_FATAL(constNode);
894 typename T::value_type constValue = 0.0;
895 T *fpCst = static_cast<T *>(constNode->GetConstVal());
896 switch (opcode) {
897 case OP_recip:
898 case OP_neg:
899 case OP_abs: {
900 if (OP_recip == opcode) {
901 constValue = typename T::value_type(1.0 / fpCst->GetValue());
902 } else if (OP_neg == opcode) {
903 constValue = typename T::value_type(-fpCst->GetValue());
904 } else if (OP_abs == opcode) {
905 constValue = typename T::value_type(fabs(fpCst->GetValue()));
906 }
907 break;
908 }
909 case OP_sqrt: {
910 constValue = typename T::value_type(sqrt(fpCst->GetValue()));
911 break;
912 }
913 case OP_bnot:
914 case OP_lnot:
915 case OP_sext:
916 case OP_zext:
917 case OP_extractbits: {
918 DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstUnary");
919 break;
920 }
921 default:
922 DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstUnary");
923 break;
924 }
925 auto *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
926 resultConst->SetPrimType(resultType);
927 if (resultType == PTY_f32) {
928 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValue));
929 } else {
930 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValue));
931 }
932 return resultConst;
933 }
934
FoldConstUnary(Opcode opcode,PrimType resultType,ConstvalNode * constNode) const935 ConstvalNode *ConstantFold::FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const
936 {
937 ConstvalNode *returnValue = nullptr;
938 if (IsPrimitiveInteger(resultType) || IsPrimitiveDynInteger(resultType)) {
939 returnValue = FoldIntConstUnary(opcode, resultType, constNode);
940 } else if (resultType == PTY_f32) {
941 returnValue = FoldFPConstUnary<MIRFloatConst>(opcode, resultType, constNode);
942 } else if (resultType == PTY_f64) {
943 returnValue = FoldFPConstUnary<MIRDoubleConst>(opcode, resultType, constNode);
944 } else if (PTY_f128 == resultType) {
945 } else {
946 DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary");
947 }
948 return returnValue;
949 }
950
FoldSizeoftype(SizeoftypeNode * node) const951 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldSizeoftype(SizeoftypeNode *node) const
952 {
953 CHECK_NULL_FATAL(node);
954 BaseNode *result = node;
955 MIRType *argType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx());
956 if (argType->GetKind() == kTypeScalar) {
957 MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(node->GetPrimType());
958 uint32 size = GetPrimTypeSize(argType->GetPrimType());
959 ConstvalNode *constValueNode = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
960 constValueNode->SetPrimType(node->GetPrimType());
961 constValueNode->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(size, resultType));
962 result = constValueNode;
963 }
964 return std::make_pair(result, std::nullopt);
965 }
966
FoldRetype(RetypeNode * node)967 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldRetype(RetypeNode *node)
968 {
969 CHECK_NULL_FATAL(node);
970 BaseNode *result = node;
971 std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
972 if (node->Opnd(0) != p.first) {
973 RetypeNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
974 CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldRetype");
975 newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
976 result = newRetNode;
977 }
978 return std::make_pair(result, std::nullopt);
979 }
980
FoldGcmallocjarray(JarrayMallocNode * node)981 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldGcmallocjarray(JarrayMallocNode *node)
982 {
983 CHECK_NULL_FATAL(node);
984 BaseNode *result = node;
985 std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
986 if (node->Opnd(0) != p.first) {
987 JarrayMallocNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
988 CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldGcmallocjarray");
989 newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
990 result = newRetNode;
991 }
992 return std::make_pair(result, std::nullopt);
993 }
994
FoldUnary(UnaryNode * node)995 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldUnary(UnaryNode *node)
996 {
997 CHECK_NULL_FATAL(node);
998 BaseNode *result = nullptr;
999 std::optional<IntVal> sum = std::nullopt;
1000 std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1001 ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1002 if (cst != nullptr) {
1003 result = FoldConstUnary(node->GetOpCode(), node->GetPrimType(), cst);
1004 } else {
1005 bool isInt = IsPrimitiveInteger(node->GetPrimType());
1006 // The neg node will be recreated regardless of whether the folding is successful or not. And the neg node's
1007 // primType will be set to opnd type. There will be problems in some cases. For example:
1008 // before cf:
1009 // neg i32 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))
1010 // after cf:
1011 // neg u1 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f)) # wrong!
1012 // As a workaround, we exclude u1 opnd type
1013 if (isInt && node->GetOpCode() == OP_neg && p.first->GetPrimType() != PTY_u1) {
1014 result = NegateTree(p.first);
1015 if (result->GetOpCode() == OP_neg) {
1016 PrimType origPtyp = node->GetPrimType();
1017 PrimType newPtyp = result->GetPrimType();
1018 if (newPtyp == origPtyp) {
1019 if (static_cast<UnaryNode *>(result)->Opnd(0) == node->Opnd(0)) {
1020 // NegateTree returned an UnaryNode quivalent to `n`, so keep the
1021 // original UnaryNode to preserve identity
1022 result = node;
1023 }
1024 } else {
1025 if (GetPrimTypeSize(newPtyp) != GetPrimTypeSize(origPtyp)) {
1026 // do not fold explicit cvt
1027 return std::make_pair(node, std::nullopt);
1028 } else {
1029 result->SetPrimType(origPtyp);
1030 }
1031 }
1032 }
1033 if (p.second) {
1034 sum = -(*p.second);
1035 }
1036 } else {
1037 result =
1038 NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(), PairToExpr(node->Opnd(0)->GetPrimType(), p));
1039 }
1040 }
1041 return std::make_pair(result, sum);
1042 }
1043
FloatToIntOverflow(float fval,PrimType totype)1044 static bool FloatToIntOverflow(float fval, PrimType totype)
1045 {
1046 static const float safeFloatMaxToInt32 = 2147483520.0f; // 2^31 - 128
1047 static const float safeFloatMinToInt32 = -2147483520.0f;
1048 static const float safeFloatMaxToInt64 = 9223372036854775680.0f; // 2^63 - 128
1049 static const float safeFloatMinToInt64 = -9223372036854775680.0f;
1050 if (!std::isfinite(fval)) {
1051 return true;
1052 }
1053 if (totype == PTY_i64 || totype == PTY_u64) {
1054 if (fval < safeFloatMinToInt64 || fval > safeFloatMaxToInt64) {
1055 return true;
1056 }
1057 } else {
1058 if (fval < safeFloatMinToInt32 || fval > safeFloatMaxToInt32) {
1059 return true;
1060 }
1061 }
1062 return false;
1063 }
1064
DoubleToIntOverflow(double dval,PrimType totype)1065 static bool DoubleToIntOverflow(double dval, PrimType totype)
1066 {
1067 static const double safeDoubleMaxToInt32 = 2147482624.0; // 2^31 - 1024
1068 static const double safeDoubleMinToInt32 = -2147482624.0;
1069 static const double safeDoubleMaxToInt64 = 9223372036854774784.0; // 2^63 - 1024
1070 static const double safeDoubleMinToInt64 = -9223372036854774784.0;
1071 if (!std::isfinite(dval)) {
1072 return true;
1073 }
1074 if (totype == PTY_i64 || totype == PTY_u64) {
1075 if (dval < safeDoubleMinToInt64 || dval > safeDoubleMaxToInt64) {
1076 return true;
1077 }
1078 } else {
1079 if (dval < safeDoubleMinToInt32 || dval > safeDoubleMaxToInt32) {
1080 return true;
1081 }
1082 }
1083 return false;
1084 }
1085
// Folds `ceil` of a floating constant into an integer constant of toType.
// The constant is read as f32 when fromType is PTY_f32 and as f64 otherwise.
// Returns nullptr when the rounded value fails FloatToIntOverflow/DoubleToIntOverflow.
ConstvalNode *ConstantFold::FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    int64 ceilValue = 0;
    if (fromType != PTY_f32) {
        const MIRDoubleConst *doubleConst = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(doubleConst);
        double rounded = ceil(doubleConst->GetValue());
        if (DoubleToIntOverflow(rounded, toType)) {
            return nullptr;
        }
        ceilValue = static_cast<int64>(rounded);
    } else {
        const MIRFloatConst *floatConst = safe_cast<MIRFloatConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(floatConst);
        float rounded = ceil(floatConst->GetValue());
        if (FloatToIntOverflow(rounded, toType)) {
            return nullptr;
        }
        ceilValue = static_cast<int64>(rounded);
    }
    resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(ceilValue, resultType));
    return resultConst;
}
1112
1113 template <class T>
CalIntValueFromFloatValue(T value,const MIRType & resultType) const1114 T ConstantFold::CalIntValueFromFloatValue(T value, const MIRType &resultType) const
1115 {
1116 DEBUG_ASSERT(kByteSizeOfBit64 >= resultType.GetSize(), "unsupported type");
1117 size_t shiftNum = (kByteSizeOfBit64 - resultType.GetSize()) * kBitSizePerByte;
1118 bool isSigned = IsSignedInteger(resultType.GetPrimType());
1119 int64 max = LLONG_MAX >> shiftNum;
1120 uint64 umax = ULLONG_MAX >> shiftNum;
1121 int64 min = isSigned ? (LLONG_MIN >> shiftNum) : 0;
1122 if (isSigned && (value > max)) {
1123 return static_cast<T>(max);
1124 } else if (!isSigned && (value > umax)) {
1125 return static_cast<T>(umax);
1126 } else if (value < min) {
1127 return static_cast<T>(min);
1128 }
1129 return value;
1130 }
1131
// Converts the floating constant `cst` to an integer constant of toType.
// `cst` is read as f32 when fromType is PTY_f32 and as f64 otherwise.
// With isFloor set, the value is first rounded toward negative infinity; otherwise the final int64
// cast truncates it. Returns nullptr when FloatToIntOverflow/DoubleToIntOverflow rejects the value.
// Accepted values are clamped to toType's bounds by CalIntValueFromFloatValue.
MIRConst *ConstantFold::FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor) const
{
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType != PTY_f32) {
        const auto &doubleConst = static_cast<const MIRDoubleConst &>(cst);
        double doubleValue = doubleConst.GetValue();
        if (isFloor) {
            doubleValue = floor(doubleConst.GetValue());
        }
        if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        }
        doubleValue = CalIntValueFromFloatValue(doubleValue, resultType);
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(doubleValue), resultType);
    }
    const auto &floatConst = static_cast<const MIRFloatConst &>(cst);
    float floatValue = floatConst.GetValue();
    if (isFloor) {
        floatValue = floor(floatConst.GetValue());
    }
    if (FloatToIntOverflow(floatValue, toType)) {
        return nullptr;
    }
    floatValue = CalIntValueFromFloatValue(floatValue, resultType);
    return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(floatValue), resultType);
}
1159
// Folds `floor` of a floating constant into a ConstvalNode of toType.
// Returns nullptr when FoldFloorMIRConst cannot fold the value (overflow). This matches FoldCeil
// and FoldTrunk, and avoids handing back a ConstvalNode whose const value is null.
ConstvalNode *ConstantFold::FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *floorConst = FoldFloorMIRConst(*cst.GetConstVal(), fromType, toType);
    if (floorConst == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(floorConst);
    return resultConst;
}
1167
// Folds a rounding conversion between floating and integer constants.
//  - f32/f64 -> integer: rounds to nearest (round()). Returns nullptr when the rounded value
//    fails FloatToIntOverflow/DoubleToIntOverflow.
//  - integer -> f32/f64: succeeds only when the integer converts to the floating type exactly,
//    i.e. it survives the round trip back to the integer type. Otherwise returns nullptr.
// Every other type combination returns nullptr.
MIRConst *ConstantFold::FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
    // 2^63 and 2^64 are exactly representable in both float and double. A 64-bit integer converted to
    // a floating type can round up to one of them, and converting that value back to int64/uint64 is
    // undefined behavior. The round-trip check is therefore only done strictly below these bounds.
    constexpr float kFloatTwoPow63 = 9223372036854775808.0f;
    constexpr float kFloatTwoPow64 = 18446744073709551616.0f;
    constexpr double kDoubleTwoPow63 = 9223372036854775808.0;
    constexpr double kDoubleTwoPow64 = 18446744073709551616.0;
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const auto &constValue = static_cast<const MIRFloatConst &>(cst);
        float floatValue = round(constValue.GetValue());
        if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        }
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(floatValue), resultType);
    } else if (fromType == PTY_f64) {
        const auto &constValue = static_cast<const MIRDoubleConst &>(cst);
        double doubleValue = round(constValue.GetValue());
        if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        }
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(doubleValue), resultType);
    } else if (toType == PTY_f32 && IsPrimitiveInteger(fromType)) {
        const auto &constValue = static_cast<const MIRIntConst &>(cst);
        if (IsSignedInteger(fromType)) {
            int64 fromValue = constValue.GetExtValue();
            float floatValue = round(static_cast<float>(fromValue));
            if (floatValue < kFloatTwoPow63 && static_cast<int64>(floatValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
            }
        } else {
            uint64 fromValue = constValue.GetExtValue();
            float floatValue = round(static_cast<float>(fromValue));
            if (floatValue < kFloatTwoPow64 && static_cast<uint64>(floatValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
            }
        }
    } else if (toType == PTY_f64 && IsPrimitiveInteger(fromType)) {
        const auto &constValue = static_cast<const MIRIntConst &>(cst);
        if (IsSignedInteger(fromType)) {
            int64 fromValue = constValue.GetExtValue();
            double doubleValue = round(static_cast<double>(fromValue));
            if (doubleValue < kDoubleTwoPow63 && static_cast<int64>(doubleValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
            }
        } else {
            uint64 fromValue = constValue.GetExtValue();
            double doubleValue = round(static_cast<double>(fromValue));
            if (doubleValue < kDoubleTwoPow64 && static_cast<uint64>(doubleValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
            }
        }
    }
    return nullptr;
}
1218
// Folds `round` of a constant into a ConstvalNode of toType.
// Returns nullptr when FoldRoundMIRConst cannot fold the value (overflow, or an inexact
// integer-to-float conversion). This avoids handing back a ConstvalNode whose const value is null.
ConstvalNode *ConstantFold::FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *roundConst = FoldRoundMIRConst(*cst.GetConstVal(), fromType, toType);
    if (roundConst == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(roundConst);
    return resultConst;
}
1226
// Folds `trunc` (round toward zero) of a floating constant into an integer constant of toType.
// The constant is read as f32 when fromType is PTY_f32 and as f64 otherwise.
// Returns nullptr when the truncated value fails FloatToIntOverflow/DoubleToIntOverflow.
ConstvalNode *ConstantFold::FoldTrunk(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    int64 truncValue = 0;
    if (fromType != PTY_f32) {
        const MIRDoubleConst *doubleConst = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(doubleConst);
        double truncated = trunc(doubleConst->GetValue());
        if (DoubleToIntOverflow(truncated, toType)) {
            return nullptr;
        }
        truncValue = static_cast<int64>(truncated);
    } else {
        const MIRFloatConst *floatConst = safe_cast<MIRFloatConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(floatConst);
        float truncated = trunc(floatConst->GetValue());
        if (FloatToIntOverflow(truncated, toType)) {
            return nullptr;
        }
        truncValue = static_cast<int64>(truncated);
    }
    resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(truncValue, resultType));
    return resultConst;
}
1253
// Converts the constant `cst` from fromType to toType. Returns nullptr when it is not folded.
//  - dynamic types: never folded;
//  - int -> int: widening sign-/zero-extends according to fromType's signedness;
//    same-width or narrowing re-creates the value in toType;
//  - float -> float: f64 -> f32 narrows, otherwise the value is widened to f64;
//  - float -> int: truncating conversion via FoldFloorMIRConst (isFloor = false);
//  - int -> float: exact conversion via FoldRoundMIRConst.
// Any other combination is fatal.
MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
    if (IsPrimitiveDynType(fromType) || IsPrimitiveDynType(toType)) {
        // do not fold
        return nullptr;
    }
    const bool fromInt = IsPrimitiveInteger(fromType);
    const bool toInt = IsPrimitiveInteger(toType);
    const bool fromFloat = IsPrimitiveFloat(fromType);
    const bool toFloat = IsPrimitiveFloat(toType);

    if (fromInt && toInt) {
        // GetPrimTypeBitSize(PTY_u1) reports 8, but a u1 value carries a single bit.
        const uint32 fromSize = (fromType == PTY_u1) ? 1 : GetPrimTypeBitSize(fromType);
        const uint32 toSize = (toType == PTY_u1) ? 1 : GetPrimTypeBitSize(toType);
        const MIRIntConst *intConst = safe_cast<MIRIntConst>(cst);
        ASSERT_NOT_NULL(intConst);
        if (toSize <= fromSize) {
            MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(toType);
            return GlobalTables::GetIntConstTable().GetOrCreateIntConst(intConst->GetExtValue(), type);
        }
        const Opcode extendOp = IsSignedInteger(fromType) ? OP_sext : OP_zext;
        return FoldSignExtendMIRConst(extendOp, toType, fromSize, intConst->GetValue().TruncOrExtend(fromType));
    }

    if (fromFloat && toFloat) {
        if (GetPrimTypeBitSize(toType) < GetPrimTypeBitSize(fromType)) {
            DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 32, "We suppot F32 and F64"); // just support 32 or 64
            const MIRDoubleConst *doubleConst = safe_cast<MIRDoubleConst>(cst);
            ASSERT_NOT_NULL(doubleConst);
            float narrowed = static_cast<float>(doubleConst->GetValue());
            return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(narrowed);
        }
        DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 64, "We suppot F32 and F64"); // just support 32 or 64
        const MIRFloatConst *floatConst = safe_cast<MIRFloatConst>(cst);
        ASSERT_NOT_NULL(floatConst);
        double widened = static_cast<double>(floatConst->GetValue());
        return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(widened);
    }

    if (fromFloat && toInt) {
        return FoldFloorMIRConst(cst, fromType, toType, false);
    }
    if (fromInt && toFloat) {
        return FoldRoundMIRConst(cst, fromType, toType);
    }
    CHECK_FATAL(false, "Unexpected case in ConstFoldTypeCvt");
    return nullptr;
}
1315
// Folds a `cvt` of a constant. The result node's primitive type is taken from the folded constant's
// type. Returns nullptr when FoldTypeCvtMIRConst declines to fold.
ConstvalNode *ConstantFold::FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *converted = FoldTypeCvtMIRConst(*cst.GetConstVal(), fromType, toType);
    if (converted == nullptr) {
        return nullptr;
    }
    auto *convertedNode = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    convertedNode->SetPrimType(converted->GetType().GetPrimType());
    convertedNode->SetConstVal(converted);
    return convertedNode;
}
1327
1328 // return a primType with bit size >= bitSize (and the nearest one),
1329 // and its signed/float type is the same as ptyp
GetNearestSizePtyp(uint8 bitSize,PrimType ptyp)1330 PrimType GetNearestSizePtyp(uint8 bitSize, PrimType ptyp)
1331 {
1332 bool isSigned = IsSignedInteger(ptyp);
1333 bool isFloat = IsPrimitiveFloat(ptyp);
1334 if (bitSize == 1) { // 1 bit
1335 return PTY_u1;
1336 }
1337 if (bitSize <= 8) { // 8 bit
1338 return isSigned ? PTY_i8 : PTY_u8;
1339 }
1340 if (bitSize <= 16) { // 16 bit
1341 return isSigned ? PTY_i16 : PTY_u16;
1342 }
1343 if (bitSize <= 32) { // 32 bit
1344 return isFloat ? PTY_f32 : (isSigned ? PTY_i32 : PTY_u32);
1345 }
1346 if (bitSize <= 64) { // 64 bit
1347 return isFloat ? PTY_f64 : (isSigned ? PTY_i64 : PTY_u64);
1348 }
1349 if (bitSize <= 128) { // 128 bit
1350 return isFloat ? PTY_f128 : (isSigned ? PTY_i128 : PTY_u128);
1351 }
1352 return ptyp;
1353 }
1354
GetIntPrimTypeMax(PrimType ptyp)1355 size_t GetIntPrimTypeMax(PrimType ptyp)
1356 {
1357 switch (ptyp) {
1358 case PTY_u1:
1359 return 1;
1360 case PTY_u8:
1361 return UINT8_MAX;
1362 case PTY_i8:
1363 return INT8_MAX;
1364 case PTY_u16:
1365 return UINT16_MAX;
1366 case PTY_i16:
1367 return INT16_MAX;
1368 case PTY_u32:
1369 return UINT32_MAX;
1370 case PTY_i32:
1371 return INT32_MAX;
1372 case PTY_u64:
1373 return UINT64_MAX;
1374 case PTY_i64:
1375 return INT64_MAX;
1376 default:
1377 CHECK_FATAL(false, "NYI");
1378 }
1379 }
1380
GetIntPrimTypeMin(PrimType ptyp)1381 ssize_t GetIntPrimTypeMin(PrimType ptyp)
1382 {
1383 if (IsUnsignedInteger(ptyp)) {
1384 return 0;
1385 }
1386 switch (ptyp) {
1387 case PTY_i8:
1388 return INT8_MIN;
1389 case PTY_i16:
1390 return INT16_MIN;
1391 case PTY_i32:
1392 return INT32_MIN;
1393 case PTY_i64:
1394 return INT64_MIN;
1395 default:
1396 CHECK_FATAL(false, "NYI");
1397 }
1398 }
1399
1400 // return a primtype to represent value range of expr
GetExprValueRangePtyp(BaseNode * expr)1401 PrimType GetExprValueRangePtyp(BaseNode *expr)
1402 {
1403 PrimType ptyp = expr->GetPrimType();
1404 Opcode op = expr->GetOpCode();
1405 if (expr->IsLeaf()) {
1406 return ptyp;
1407 }
1408 if (kOpcodeInfo.IsTypeCvt(op)) {
1409 auto *node = static_cast<TypeCvtNode *>(expr);
1410 if (GetPrimTypeSize(node->FromType()) < GetPrimTypeSize(node->GetPrimType())) {
1411 return GetExprValueRangePtyp(expr->Opnd(0));
1412 }
1413 return ptyp;
1414 }
1415 if (op == OP_sext || op == OP_zext || op == OP_extractbits) {
1416 auto *node = static_cast<ExtractbitsNode *>(expr);
1417 uint8 size = node->GetBitsSize();
1418 return GetNearestSizePtyp(size, expr->GetPrimType());
1419 }
1420 // find max size primtype of opnds.
1421 size_t maxTypeSize = 1;
1422 size_t ptypSize = GetPrimTypeSize(ptyp);
1423 for (size_t i = 0; i < expr->GetNumOpnds(); ++i) {
1424 PrimType opndPtyp = GetExprValueRangePtyp(expr->Opnd(i));
1425 size_t opndSize = GetPrimTypeSize(opndPtyp);
1426 if (ptypSize <= opndSize) {
1427 return ptyp;
1428 }
1429 if (maxTypeSize < opndSize) {
1430 maxTypeSize = opndSize;
1431 constexpr size_t intMaxSize = 8;
1432 if (maxTypeSize == intMaxSize) {
1433 break;
1434 }
1435 }
1436 }
1437 return GetNearestSizePtyp(maxTypeSize, ptyp);
1438 }
1439
FoldTypeCvt(TypeCvtNode * node)1440 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldTypeCvt(TypeCvtNode *node)
1441 {
1442 CHECK_NULL_FATAL(node);
1443 BaseNode *result = nullptr;
1444 std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1445 ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1446 PrimType destPtyp = node->GetPrimType();
1447 PrimType fromPtyp = node->FromType();
1448 if (cst != nullptr) {
1449 switch (node->GetOpCode()) {
1450 case OP_ceil: {
1451 result = FoldCeil(*cst, fromPtyp, destPtyp);
1452 break;
1453 }
1454 case OP_cvt: {
1455 result = FoldTypeCvt(*cst, fromPtyp, destPtyp);
1456 break;
1457 }
1458 case OP_floor: {
1459 result = FoldFloor(*cst, fromPtyp, destPtyp);
1460 break;
1461 }
1462 case OP_round: {
1463 result = FoldRound(*cst, fromPtyp, destPtyp);
1464 break;
1465 }
1466 case OP_trunc: {
1467 result = FoldTrunk(*cst, fromPtyp, destPtyp);
1468 break;
1469 }
1470 default:
1471 DEBUG_ASSERT(false, "Unexpected opcode in TypeCvtNodeConstFold");
1472 break;
1473 }
1474 } else if (node->GetOpCode() == OP_cvt && GetPrimTypeSize(fromPtyp) == GetPrimTypeSize(destPtyp)) {
1475 if ((IsPossible64BitAddress(fromPtyp) && IsPossible64BitAddress(destPtyp)) ||
1476 (IsPossible32BitAddress(fromPtyp) && IsPossible32BitAddress(destPtyp)) ||
1477 (IsPrimitivePureScalar(fromPtyp) && IsPrimitivePureScalar(destPtyp) &&
1478 GetPrimTypeSize(fromPtyp) == GetPrimTypeSize(destPtyp))) {
1479 // the cvt is redundant
1480 return std::make_pair(p.first, p.second ? IntVal(*p.second, node->GetPrimType()) : p.second);
1481 }
1482 } else if (node->GetOpCode() == OP_cvt && p.second.has_value() && *p.second != 0 &&
1483 p.second->GetExtValue() < INT_MAX && p.second->GetSXTValue() > -kMaxOffset &&
1484 IsPrimitiveInteger(node->GetPrimType()) && IsSignedInteger(node->FromType()) &&
1485 GetPrimTypeSize(node->GetPrimType()) > GetPrimTypeSize(node->FromType())) {
1486 bool simplifyCvt = false;
1487 // Integer values must not be allowd to wrap, for pointer arithmetic, including array indexing
1488 // Therefore, if dest-type of cvt expr is pointer type, like : cvt ptr/ref/a64 (op (expr, constval))
1489 // we assume that result of cvt opnd will never wrap
1490 if (IsPossible64BitAddress(node->GetPrimType()) && node->GetPrimType() != PTY_u64) {
1491 simplifyCvt = true;
1492 } else {
1493 // Ensure : op(expr, constval) is not overflow
1494 // Example :
1495 // given a expr like : r = (i + constVal), and its mplIR may be:
1496 // cvt i64 i32 (add i32 (cvt i32 i16 (dread i16 i), constval i32 constVal))
1497 // expr primtype(i32) range : |I32_MIN--------------------0--------------------I32_MAX|
1498 // expr value range(i16) : |I16_MIN------0------I16_MAX|
1499 // value range difference : |-------------| |-------------|
1500 // constVal value range : [I32_MIN - I16MIN, I32_MAX - I16_MAX], (i + constVal) will never pos/neg
1501 // overflow.
1502 PrimType valuePtyp = GetExprValueRangePtyp(p.first);
1503 PrimType exprPtyp = p.first->GetPrimType();
1504 int64 constVal = p.second->GetExtValue();
1505 // cannot only check (min <= constval && constval <= max) here.
1506 // If constval = -1, it will be SIZE_T_MAX when converted to size_t
1507 if ((constVal < 0 && constVal >= GetIntPrimTypeMin(exprPtyp) - GetIntPrimTypeMin(valuePtyp)) ||
1508 (constVal > 0 &&
1509 static_cast<size_t>(constVal) <= GetIntPrimTypeMax(exprPtyp) - GetIntPrimTypeMax(valuePtyp))) {
1510 simplifyCvt = true;
1511 }
1512 }
1513 if (simplifyCvt) {
1514 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, node->GetPrimType(), node->FromType(),
1515 p.first);
1516 return std::make_pair(result, p.second->TruncOrExtend(node->GetPrimType()));
1517 }
1518 }
1519 if (result == nullptr) {
1520 BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1521 if (e != node->Opnd(0)) {
1522 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(
1523 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->FromType()), e);
1524 } else {
1525 result = node;
1526 }
1527 }
1528 return std::make_pair(result, std::nullopt);
1529 }
1530
FoldSignExtendMIRConst(Opcode opcode,PrimType resultType,uint8 size,const IntVal & val) const1531 MIRConst *ConstantFold::FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const
1532 {
1533 uint64 result = opcode == OP_sext ? val.GetSXTValue(size) : val.GetZXTValue(size);
1534 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
1535 MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
1536 return constValue;
1537 }
1538
FoldSignExtend(Opcode opcode,PrimType resultType,uint8 size,const ConstvalNode & cst) const1539 ConstvalNode *ConstantFold::FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size,
1540 const ConstvalNode &cst) const
1541 {
1542 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1543 const auto *intCst = safe_cast<MIRIntConst>(cst.GetConstVal());
1544 ASSERT_NOT_NULL(intCst);
1545 IntVal val = intCst->GetValue().TruncOrExtend(size, opcode == OP_sext);
1546 MIRConst *toConst = FoldSignExtendMIRConst(opcode, resultType, size, val);
1547 resultConst->SetPrimType(toConst->GetType().GetPrimType());
1548 resultConst->SetConstVal(toConst);
1549 return resultConst;
1550 }
1551
1552 // check if truncation is redundant due to dread or iread having same effect
ExtractbitsRedundant(ExtractbitsNode * x,MIRFunction * f)1553 static bool ExtractbitsRedundant(ExtractbitsNode *x, MIRFunction *f)
1554 {
1555 if (GetPrimTypeSize(x->GetPrimType()) == k8ByteSize) {
1556 return false; // this is trying to be conservative
1557 }
1558 BaseNode *opnd = x->Opnd(0);
1559 MIRType *mirType = nullptr;
1560 if (opnd->GetOpCode() == OP_dread) {
1561 DreadNode *dread = static_cast<DreadNode *>(opnd);
1562 MIRSymbol *sym = f->GetLocalOrGlobalSymbol(dread->GetStIdx());
1563 mirType = sym->GetType();
1564 if (dread->GetFieldID() != 0) {
1565 MIRStructType *structType = dynamic_cast<MIRStructType *>(mirType);
1566 if (structType == nullptr) {
1567 return false;
1568 }
1569 mirType = structType->GetFieldType(dread->GetFieldID());
1570 }
1571 } else if (opnd->GetOpCode() == OP_iread) {
1572 IreadNode *iread = static_cast<IreadNode *>(opnd);
1573 MIRPtrType *ptrType =
1574 dynamic_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx()));
1575 if (ptrType == nullptr) {
1576 return false;
1577 }
1578 mirType = ptrType->GetPointedType();
1579 if (iread->GetFieldID() != 0) {
1580 MIRStructType *structType = dynamic_cast<MIRStructType *>(mirType);
1581 if (structType == nullptr) {
1582 return false;
1583 }
1584 mirType = structType->GetFieldType(iread->GetFieldID());
1585 }
1586 } else {
1587 return false;
1588 }
1589 return IsPrimitiveInteger(mirType->GetPrimType()) && mirType->GetSize() * kBitSizePerByte == x->GetBitsSize() &&
1590 IsSignedInteger(mirType->GetPrimType()) == IsSignedInteger(x->GetPrimType());
1591 }
1592
1593 // sext and zext also handled automatically
FoldExtractbits(ExtractbitsNode * node)1594 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldExtractbits(ExtractbitsNode *node)
1595 {
1596 CHECK_NULL_FATAL(node);
1597 BaseNode *result = nullptr;
1598 uint8 offset = node->GetBitsOffset();
1599 uint8 size = node->GetBitsSize();
1600 Opcode opcode = node->GetOpCode();
1601 std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1602 ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1603 if (cst != nullptr && (opcode == OP_sext || opcode == OP_zext)) {
1604 result = FoldSignExtend(opcode, node->GetPrimType(), size, *cst);
1605 return std::make_pair(result, std::nullopt);
1606 }
1607 BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1608 if (e != node->Opnd(0)) {
1609 result = mirModule->CurFuncCodeMemPool()->New<ExtractbitsNode>(opcode, PrimType(node->GetPrimType()), offset,
1610 size, e);
1611 } else {
1612 result = node;
1613 }
1614 // check for consecutive and redundant extraction of same bits
1615 BaseNode *opnd = result->Opnd(0);
1616 Opcode opndOp = opnd->GetOpCode();
1617 if (opndOp == OP_extractbits || opndOp == OP_sext || opndOp == OP_zext) {
1618 uint8 opndOffset = static_cast<ExtractbitsNode *>(opnd)->GetBitsOffset();
1619 uint8 opndSize = static_cast<ExtractbitsNode *>(opnd)->GetBitsSize();
1620 if (offset == opndOffset && size == opndSize) {
1621 result->SetOpnd(opnd->Opnd(0), 0); // delete the redundant extraction
1622 }
1623 }
1624 if (offset == 0 && size >= k8ByteSize && IsPowerOf2(size)) {
1625 if (ExtractbitsRedundant(static_cast<ExtractbitsNode *>(result), mirModule->CurFunction())) {
1626 return std::make_pair(result->Opnd(0), std::nullopt);
1627 }
1628 }
1629 return std::make_pair(result, std::nullopt);
1630 }
1631
FoldIread(IreadNode * node)1632 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldIread(IreadNode *node)
1633 {
1634 CHECK_NULL_FATAL(node);
1635 std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1636 BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1637 node->SetOpnd(e, 0);
1638 BaseNode *result = node;
1639 if (e->GetOpCode() != OP_addrof) {
1640 return std::make_pair(result, std::nullopt);
1641 }
1642
1643 AddrofNode *addrofNode = static_cast<AddrofNode *>(e);
1644 MIRSymbol *msy = mirModule->CurFunction()->GetLocalOrGlobalSymbol(addrofNode->GetStIdx());
1645 TyIdx typeId = msy->GetTyIdx();
1646 CHECK_FATAL(GlobalTables::GetTypeTable().GetTypeTable().empty() == false, "container check");
1647 MIRType *msyType = GlobalTables::GetTypeTable().GetTypeTable()[typeId];
1648 if (addrofNode->GetFieldID() != 0 && (msyType->GetKind() == kTypeStruct || msyType->GetKind() == kTypeClass)) {
1649 msyType = static_cast<MIRStructType *>(msyType)->GetFieldType(addrofNode->GetFieldID());
1650 }
1651 MIRPtrType *ptrType = static_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx()));
1652 // If the high level type of iaddrof/iread doesn't match
1653 // the type of addrof's rhs, this optimization cannot be done.
1654 if (ptrType->GetPointedType() != msyType) {
1655 return std::make_pair(result, std::nullopt);
1656 }
1657
1658 Opcode op = node->GetOpCode();
1659 FieldID fieldID = node->GetFieldID();
1660 if (op == OP_iaddrof) {
1661 AddrofNode *newAddrof = addrofNode->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
1662 CHECK_NULL_FATAL(newAddrof);
1663 newAddrof->SetFieldID(newAddrof->GetFieldID() + fieldID);
1664 result = newAddrof;
1665 } else if (op == OP_iread) {
1666 result = mirModule->CurFuncCodeMemPool()->New<AddrofNode>(OP_dread, node->GetPrimType(), addrofNode->GetStIdx(),
1667 node->GetFieldID() + addrofNode->GetFieldID());
1668 }
1669 return std::make_pair(result, std::nullopt);
1670 }
1671
IntegerOpIsOverflow(Opcode op,PrimType primType,int64 cstA,int64 cstB)1672 bool ConstantFold::IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB)
1673 {
1674 switch (op) {
1675 case OP_add: {
1676 int64 res = static_cast<int64>(static_cast<uint64>(cstA) + static_cast<uint64>(cstB));
1677 if (IsUnsignedInteger(primType)) {
1678 return static_cast<uint64>(res) < static_cast<uint64>(cstA);
1679 }
1680 auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1681 return (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1682 static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag) &&
1683 (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1684 static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag);
1685 }
1686 case OP_sub: {
1687 if (IsUnsignedInteger(primType)) {
1688 return cstA < cstB;
1689 }
1690 int64 res = static_cast<int64>(static_cast<uint64>(cstA) - static_cast<uint64>(cstB));
1691 auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1692 return (static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag !=
1693 static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag) &&
1694 (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1695 static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag);
1696 }
1697 default: {
1698 return false;
1699 }
1700 }
1701 }
1702
// Folds a binary expression after recursively folding both operands.
//
// The returned pair (expr, sum) stands for `expr + sum` when `sum` holds a value, and for `expr`
// alone otherwise; callers turn it back into a tree with PairToExpr. Keeping the additive constant
// separate lets constants from nested add/sub/mul operands be combined here.
// Handled shapes:
//   - both operands constant: evaluated with FoldConstBinary, except integer div/rem that
//     IsDivSafe rejects, which is kept as a binary node;
//   - integer constant on the left or on the right: algebraic identities (x+0, x*1, x&0, x|-1, ...),
//     re-association of (X + konst) into `sum`, and `(X >> c) & mask` -> extractbits;
//   - integer add/sub of two non-constants: the constant parts of both sides are merged in `sum`.
std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBinary(BinaryNode *node)
{
    CHECK_NULL_FATAL(node);
    BaseNode *result = nullptr;
    std::optional<IntVal> sum = std::nullopt;
    Opcode op = node->GetOpCode();
    PrimType primType = node->GetPrimType();
    // Operand types are read before folding; these are the types handed to PairToExpr below.
    PrimType lPrimTypes = node->Opnd(0)->GetPrimType();
    PrimType rPrimTypes = node->Opnd(1)->GetPrimType();
    std::pair<BaseNode *, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
    std::pair<BaseNode *, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
    BaseNode *l = lp.first;
    BaseNode *r = rp.first;
    ConstvalNode *lConst = safe_cast<ConstvalNode>(l);
    ConstvalNode *rConst = safe_cast<ConstvalNode>(r);
    bool isInt = IsPrimitiveInteger(primType);

    if (lConst != nullptr && rConst != nullptr) {
        // Both operands are constants: evaluate the operation now.
        MIRConst *lConstVal = lConst->GetConstVal();
        MIRConst *rConstVal = rConst->GetConstVal();
        ASSERT_NOT_NULL(lConstVal);
        ASSERT_NOT_NULL(rConstVal);
        // Don't fold div by 0, for floats div by 0 is well defined.
        if ((op == OP_div || op == OP_rem) && isInt &&
            !IsDivSafe(static_cast<MIRIntConst &>(*lConstVal), static_cast<MIRIntConst &>(*rConstVal), primType)) {
            result = NewBinaryNode(node, op, primType, lConst, rConst);
        } else {
            // 4 + 2 -> return a pair(result = ConstValNode(6), sum = nullopt)
            // Create a new ConstvalNode for 6 and leave sum empty. This simplifies the
            // logic since the alternative is to return pair(result = nullptr, sum = 6).
            // Doing so would introduce many nullptr checks in the code. See previous
            // commits that implemented that logic for a comparison.
            result = FoldConstBinary(op, primType, *lConst, *rConst);
        }
    } else if (lConst != nullptr && isInt) {
        // Integer constant on the left: cst op X.
        MIRIntConst *mcst = safe_cast<MIRIntConst>(lConst->GetConstVal());
        ASSERT_NOT_NULL(mcst);
        PrimType cstTyp = mcst->GetType().GetPrimType();
        IntVal cst = mcst->GetValue();
        if (op == OP_add) {
            // cst + (X + k) -> the pair [X, cst + k]
            sum = cst + rp.second;
            result = r;
        } else if (op == OP_sub && r->GetPrimType() != PTY_u1) {
            // cst - (X + k) -> the pair [neg(X), cst - k]
            // We exclude u1 type for fixing the following wrong example:
            // before cf:
            // sub i32 (constval i32 17, eq u1 i32 (dread i32 %i, constval i32 16)))
            // after cf:
            // add i32 (cvt i32 u1 (neg u1 (eq u1 i32 (dread i32 %i, constval i32 16))), constval i32 17))
            sum = cst - rp.second;
            result = NegateTree(r);
        } else if ((op == OP_mul || op == OP_div || op == OP_rem || op == OP_ashr || op == OP_lshr || op == OP_shl ||
                    op == OP_band || op == OP_cand || op == OP_land) &&
                   cst == 0) {
            // 0 * X -> 0
            // 0 / X -> 0
            // 0 % X -> 0
            // 0 >> X -> 0
            // 0 << X -> 0
            // 0 & X -> 0
            // 0 && X -> 0
            result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
        } else if (op == OP_mul && cst == 1) {
            // 1 * X --> X
            sum = rp.second;
            result = r;
        } else if (op == OP_bior && cst == -1) {
            // (-1) | X -> -1
            result = mirModule->GetMIRBuilder()->CreateIntConst(-1, cstTyp);
        } else if (op == OP_mul && rp.second.has_value() && *rp.second != 0) {
            // lConst * (X + konst) -> the pair [(lConst*X), (lConst*konst)]
            sum = cst * rp.second;
            if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) {
                // Widen X to the result type.
                // NOTE(review): the cvt source type is hard-coded as PTY_i32 rather than
                // rp.first->GetPrimType(); confirm this is intended for unsigned or sub-32-bit X.
                rp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, rp.first);
            }
            result = NewBinaryNode(node, OP_mul, primType, lConst, rp.first);
        } else if (op == OP_lior || op == OP_cior) {
            if (cst != 0) {
                // 5 || X -> 1
                result = mirModule->GetMIRBuilder()->CreateIntConst(1, cstTyp);
            } else {
                // when cst is zero
                // 0 || X -> (X != 0);
                result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
                    OP_ne, primType, r->GetPrimType(), r,
                    mirModule->GetMIRBuilder()->CreateIntConst(0, r->GetPrimType()));
            }
        } else if ((op == OP_cand || op == OP_land) && cst != 0) {
            // 5 && X -> (X != 0)
            result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
                OP_ne, primType, r->GetPrimType(), r, mirModule->GetMIRBuilder()->CreateIntConst(0, r->GetPrimType()));
        } else if ((op == OP_bior || op == OP_bxor) && cst == 0) {
            // 0 | X -> X
            // 0 ^ X -> X
            sum = rp.second;
            result = r;
        } else {
            result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
        }
        // The chosen replacement may have a different type than the original expression.
        if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
            result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
        }
    } else if (rConst != nullptr && isInt) {
        // Integer constant on the right: X op cst.
        MIRIntConst *mcst = safe_cast<MIRIntConst>(rConst->GetConstVal());
        ASSERT_NOT_NULL(mcst);
        PrimType cstTyp = mcst->GetType().GetPrimType();
        IntVal cst = mcst->GetValue();
        if (op == OP_add) {
            // (X + k) + cst -> the pair [X, k + cst], unless k + cst overflows in the constant's type,
            // in which case the expression is kept as an explicit add.
            if (lp.second && IntegerOpIsOverflow(op, cstTyp, lp.second->GetExtValue(), cst.GetExtValue())) {
                result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
            } else {
                result = l;
                sum = lp.second + cst;
            }
        } else if (op == OP_sub && (!cst.IsSigned() || !cst.IsMinValue())) {
            // (X + k) - cst -> the pair [X, k - cst]; a signed minimum constant is excluded
            // (presumably because its negation is not representable).
            {
                result = l;
                sum = lp.second - cst;
            }
        } else if ((op == OP_mul || op == OP_band || op == OP_cand || op == OP_land) && cst == 0) {
            // X * 0 -> 0
            // X & 0 -> 0
            // X && 0 -> 0
            result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
        } else if ((op == OP_mul || op == OP_div) && cst == 1) {
            // case [X * 1 -> X]
            // case [X / 1 = X]
            sum = lp.second;
            result = l;
        } else if (op == OP_mul && lp.second.has_value() && *lp.second != 0 && lp.second->GetSXTValue() > -kMaxOffset) {
            // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)]
            // NOTE(review): konst is only bounded from below by -kMaxOffset here, while the
            // mirrored left-constant case has no bound at all; confirm the intended limit.
            sum = lp.second * cst;
            if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) {
                // NOTE(review): cvt source type hard-coded as PTY_i32 (see the left-constant case).
                lp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, lp.first);
            }
            result = NewBinaryNode(node, OP_mul, primType, lp.first, rConst);
        } else if (op == OP_band && cst == -1) {
            // X & (-1) -> X
            sum = lp.second;
            result = l;
        } else if (op == OP_band && ContiguousBitsOf1(cst.GetZXTValue()) &&
                   (!lp.second.has_value() || lp.second == 0)) {
            // (X >> shrAmt) & mask, with mask a run of 1 bits, may become an extractbits of X.
            bool fold2extractbits = false;
            if (l->GetOpCode() == OP_ashr || l->GetOpCode() == OP_lshr) {
                BinaryNode *shrNode = static_cast<BinaryNode *>(l);
                if (shrNode->Opnd(1)->GetOpCode() == OP_constval) {
                    ConstvalNode *shrOpnd = static_cast<ConstvalNode *>(shrNode->Opnd(1));
                    int64 shrAmt = static_cast<MIRIntConst *>(shrOpnd->GetConstVal())->GetExtValue();
                    uint64 ucst = cst.GetZXTValue();
                    // bsize = position of the highest set bit of the mask, plus one.
                    uint64 bsize = 0;
                    do {
                        bsize++;
                        ucst >>= 1;
                    } while (ucst != 0);
                    // Only when the extracted field fits inside the result type.
                    if (shrAmt + bsize <= GetPrimTypeSize(primType) * kBitSizePerByte) {
                        fold2extractbits = true;
                        // change to use extractbits
                        result = mirModule->GetMIRBuilder()->CreateExprExtractbits(
                            OP_extractbits, GetUnsignedPrimType(primType), shrAmt, bsize, shrNode->Opnd(0));
                        sum = std::nullopt;
                    }
                }
            }
            if (!fold2extractbits) {
                result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
                sum = std::nullopt;
            }
        } else if (op == OP_bior && cst == -1) {
            // X | (-1) -> -1
            result = mirModule->GetMIRBuilder()->CreateIntConst(-1, cstTyp);
        } else if ((op == OP_lior || op == OP_cior)) {
            if (cst == 0) {
                // X || 0 -> X
                sum = lp.second;
                result = l;
            } else if (!cst.GetSignBit()) {
                // X || 5 -> 1 (constants with the sign bit set are left as an explicit node)
                result = mirModule->GetMIRBuilder()->CreateIntConst(1, cstTyp);
            } else {
                result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
            }
        } else if ((op == OP_ashr || op == OP_lshr || op == OP_shl || op == OP_bior || op == OP_bxor) && cst == 0) {
            // X >> 0 -> X
            // X << 0 -> X
            // X | 0 -> X
            // X ^ 0 -> X
            sum = lp.second;
            result = l;
        } else if (op == OP_bxor && cst == 1 && primType != PTY_u1) {
            // bxor i32 (
            //   cvt i32 u1 (regread u1 %13),
            //   constValue i32 1),
            // is rewritten as cvt i32 u1 (bxor u1 (regread u1 %13, constval u1 1)) when the
            // left operand is a cvt from u1; otherwise the bxor is rebuilt unchanged.
            result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
            if (l->GetOpCode() == OP_cvt) {
                TypeCvtNode *cvtNode = static_cast<TypeCvtNode *>(l);
                if (cvtNode->Opnd(0)->GetPrimType() == PTY_u1) {
                    BaseNode *base = cvtNode->Opnd(0);
                    BaseNode *constValue = mirModule->GetMIRBuilder()->CreateIntConst(1, base->GetPrimType());
                    std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(base);
                    BinaryNode *temp = NewBinaryNode(node, op, PTY_u1, PairToExpr(base->GetPrimType(), p), constValue);
                    result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_u1, temp);
                }
            }
        } else if (op == OP_rem && cst == 1) {
            // X % 1 -> 0
            result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
        } else {
            result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
        }
        // The chosen replacement may have a different type than the original expression.
        if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
            result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
        }
    } else if (isInt && (op == OP_add || op == OP_sub)) {
        // Neither side is constant: combine the constant parts of both sides in `sum`.
        if (op == OP_add) {
            result = NewBinaryNode(node, op, primType, l, r);
            sum = lp.second + rp.second;
        } else if (node->Opnd(1)->GetOpCode() == OP_sub && r->GetOpCode() == OP_neg) {
            // if fold is (x - (y - z)) -> (x - neg(z)) - y
            // (x - neg(z)) Could cross the int limit
            // return node
            result = node;
        } else {
            result = NewBinaryNode(node, op, primType, l, r);
            sum = lp.second - rp.second;
        }
    } else {
        // Non-integer or non-additive: rebuild with the folded operands.
        result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
    }
    return std::make_pair(result, sum);
}
1932
SimplifyDoubleCompare(CompareNode & compareNode) const1933 BaseNode *ConstantFold::SimplifyDoubleCompare(CompareNode &compareNode) const
1934 {
1935 // See arm manual B.cond(P2993) and FCMP(P1091)
1936 CompareNode *node = &compareNode;
1937 BaseNode *result = node;
1938 BaseNode *l = node->Opnd(0);
1939 BaseNode *r = node->Opnd(1);
1940 if (node->GetOpCode() == OP_ne || node->GetOpCode() == OP_eq) {
1941 if (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) {
1942 ConstvalNode *constNode = static_cast<ConstvalNode *>(r);
1943 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1944 const CompareNode *compNode = static_cast<CompareNode *>(l);
1945 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1946 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), compNode->GetOpndType(),
1947 compNode->Opnd(0), compNode->Opnd(1));
1948 }
1949 } else if (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval) {
1950 ConstvalNode *constNode = static_cast<ConstvalNode *>(l);
1951 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1952 const CompareNode *compNode = static_cast<CompareNode *>(r);
1953 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1954 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), compNode->GetOpndType(),
1955 compNode->Opnd(0), compNode->Opnd(1));
1956 }
1957 } else if (node->GetOpCode() == OP_ne && r->GetOpCode() == OP_constval) {
1958 // ne (u1 x, constValue 0) <==> x
1959 ConstvalNode *constNode = static_cast<ConstvalNode *>(r);
1960 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1961 BaseNode *opnd = l;
1962 do {
1963 if (opnd->GetPrimType() == PTY_u1 || (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1964 result = opnd;
1965 break;
1966 } else if (opnd->GetOpCode() == OP_cvt) {
1967 TypeCvtNode *cvtNode = static_cast<TypeCvtNode *>(opnd);
1968 opnd = cvtNode->Opnd(0);
1969 } else {
1970 opnd = nullptr;
1971 }
1972 } while (opnd != nullptr);
1973 }
1974 } else if (node->GetOpCode() == OP_eq && r->GetOpCode() == OP_constval) {
1975 ConstvalNode *constNode = static_cast<ConstvalNode *>(r);
1976 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero() &&
1977 (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1978 result = l;
1979 result->SetOpCode(l->GetOpCode() == OP_ne ? OP_eq : OP_ne);
1980 }
1981 }
1982 } else if (node->GetOpCode() == OP_gt || node->GetOpCode() == OP_lt) {
1983 if (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) {
1984 ConstvalNode *constNode = static_cast<ConstvalNode *>(r);
1985 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1986 const CompareNode *compNode = static_cast<CompareNode *>(l);
1987 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1988 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), compNode->GetOpndType(),
1989 compNode->Opnd(0), compNode->Opnd(1));
1990 }
1991 } else if (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval) {
1992 ConstvalNode *constNode = static_cast<ConstvalNode *>(l);
1993 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1994 const CompareNode *compNode = static_cast<CompareNode *>(r);
1995 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1996 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), compNode->GetOpndType(),
1997 compNode->Opnd(1), compNode->Opnd(0));
1998 }
1999 }
2000 }
2001 return result;
2002 }
2003
FoldCompare(CompareNode * node)2004 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldCompare(CompareNode *node)
2005 {
2006 CHECK_NULL_FATAL(node);
2007 BaseNode *result = nullptr;
2008 std::pair<BaseNode *, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
2009 std::pair<BaseNode *, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
2010 ConstvalNode *lConst = safe_cast<ConstvalNode>(lp.first);
2011 ConstvalNode *rConst = safe_cast<ConstvalNode>(rp.first);
2012 Opcode opcode = node->GetOpCode();
2013 if (lConst != nullptr && rConst != nullptr && !IsPrimitiveDynType(node->GetOpndType())) {
2014 result = FoldConstComparison(node->GetOpCode(), node->GetPrimType(), node->GetOpndType(), *lConst, *rConst);
2015 } else if (lConst != nullptr && rConst == nullptr && opcode != OP_cmp &&
2016 lConst->GetConstVal()->GetKind() == kConstInt) {
2017 BaseNode *l = lp.first;
2018 BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
2019 result = FoldConstComparisonReverse(opcode, node->GetPrimType(), node->GetOpndType(), *l, *r);
2020 } else {
2021 BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), lp);
2022 BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
2023 if (l != node->Opnd(0) || r != node->Opnd(1)) {
2024 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
2025 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->GetOpndType()), l, r);
2026 } else {
2027 result = node;
2028 }
2029 auto *compareNode = static_cast<CompareNode *>(result);
2030 CHECK_NULL_FATAL(compareNode);
2031 result = SimplifyDoubleCompare(*compareNode);
2032 }
2033 return std::make_pair(result, std::nullopt);
2034 }
2035
Fold(BaseNode * node)2036 BaseNode *ConstantFold::Fold(BaseNode *node)
2037 {
2038 if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) {
2039 return nullptr;
2040 }
2041 std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node);
2042 BaseNode *result = PairToExpr(node->GetPrimType(), p);
2043 if (result == node) {
2044 result = nullptr;
2045 }
2046 return result;
2047 }
2048
FoldDepositbits(DepositbitsNode * node)2049 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldDepositbits(DepositbitsNode *node)
2050 {
2051 CHECK_NULL_FATAL(node);
2052 BaseNode *result = nullptr;
2053 uint8 bitsOffset = node->GetBitsOffset();
2054 uint8 bitsSize = node->GetBitsSize();
2055 std::pair<BaseNode *, std::optional<IntVal>> leftPair = DispatchFold(node->Opnd(0));
2056 std::pair<BaseNode *, std::optional<IntVal>> rightPair = DispatchFold(node->Opnd(1));
2057 ConstvalNode *leftConst = safe_cast<ConstvalNode>(leftPair.first);
2058 ConstvalNode *rightConst = safe_cast<ConstvalNode>(rightPair.first);
2059 if (leftConst != nullptr && rightConst != nullptr) {
2060 MIRIntConst *intConst0 = safe_cast<MIRIntConst>(leftConst->GetConstVal());
2061 MIRIntConst *intConst1 = safe_cast<MIRIntConst>(rightConst->GetConstVal());
2062 ASSERT_NOT_NULL(intConst0);
2063 ASSERT_NOT_NULL(intConst1);
2064 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
2065 resultConst->SetPrimType(node->GetPrimType());
2066 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(node->GetPrimType());
2067 MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
2068 uint64 op0ExtractVal = 0;
2069 uint64 op1ExtractVal = 0;
2070 uint64 mask0 = (1LLU << (bitsSize + bitsOffset)) - 1;
2071 uint64 mask1 = (1LLU << bitsOffset) - 1;
2072 uint64 op0Mask = ~(mask0 ^ mask1);
2073 op0ExtractVal = (intConst0->GetExtValue() & op0Mask);
2074 op1ExtractVal = (intConst1->GetExtValue() << bitsOffset) & ((1ULL << (bitsSize + bitsOffset)) - 1);
2075 constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(
2076 static_cast<int64>(op0ExtractVal | op1ExtractVal), constValue->GetType());
2077 resultConst->SetConstVal(constValue);
2078 result = resultConst;
2079 } else {
2080 BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), leftPair);
2081 BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rightPair);
2082 if (l != node->Opnd(0) || r != node->Opnd(1)) {
2083 result = mirModule->CurFuncCodeMemPool()->New<DepositbitsNode>(
2084 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), bitsOffset, bitsSize, l, r);
2085 } else {
2086 result = node;
2087 }
2088 }
2089 return std::make_pair(result, std::nullopt);
2090 }
2091
FoldArray(ArrayNode * node)2092 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldArray(ArrayNode *node)
2093 {
2094 CHECK_NULL_FATAL(node);
2095 BaseNode *result = nullptr;
2096 size_t i = 0;
2097 bool isFolded = false;
2098 ArrayNode *arrNode = mirModule->CurFuncCodeMemPool()->New<ArrayNode>(*mirModule, PrimType(node->GetPrimType()),
2099 node->GetTyIdx(), node->GetBoundsCheck());
2100 for (i = 0; i < node->GetNopndSize(); i++) {
2101 std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->GetNopndAt(i));
2102 BaseNode *tmpNode = PairToExpr(node->GetNopndAt(i)->GetPrimType(), p);
2103 if (tmpNode != node->GetNopndAt(i)) {
2104 isFolded = true;
2105 }
2106 arrNode->GetNopnd().push_back(tmpNode);
2107 arrNode->SetNumOpnds(arrNode->GetNumOpnds() + 1);
2108 }
2109 if (isFolded) {
2110 result = arrNode;
2111 } else {
2112 result = node;
2113 }
2114 return std::make_pair(result, std::nullopt);
2115 }
2116
FoldTernary(TernaryNode * node)2117 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldTernary(TernaryNode *node)
2118 {
2119 CHECK_NULL_FATAL(node);
2120 constexpr size_t kFirst = 0;
2121 constexpr size_t kSecond = 1;
2122 constexpr size_t kThird = 2;
2123 BaseNode *result = node;
2124 std::vector<PrimType> primTypes;
2125 std::vector<std::pair<BaseNode *, std::optional<IntVal>>> p;
2126 for (size_t i = 0; i < node->NumOpnds(); i++) {
2127 BaseNode *tempNopnd = node->Opnd(i);
2128 CHECK_NULL_FATAL(tempNopnd);
2129 primTypes.push_back(tempNopnd->GetPrimType());
2130 p.push_back(DispatchFold(tempNopnd));
2131 }
2132 if (node->GetOpCode() == OP_select) {
2133 ConstvalNode *const0 = safe_cast<ConstvalNode>(p[kFirst].first);
2134 if (const0 != nullptr) {
2135 MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0->GetConstVal());
2136 ASSERT_NOT_NULL(intConst0);
2137 // Selecting the first value if not 0, selecting the second value otherwise.
2138 if (!intConst0->IsZero()) {
2139 result = PairToExpr(primTypes[kSecond], p[kSecond]);
2140 } else {
2141 result = PairToExpr(primTypes[kThird], p[kThird]);
2142 }
2143 } else {
2144 ConstvalNode *const1 = safe_cast<ConstvalNode>(p[kSecond].first);
2145 ConstvalNode *const2 = safe_cast<ConstvalNode>(p[kThird].first);
2146 if (const1 != nullptr && const2 != nullptr) {
2147 MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1->GetConstVal());
2148 MIRIntConst *intConst2 = safe_cast<MIRIntConst>(const2->GetConstVal());
2149 double dconst1 = 0.0;
2150 double dconst2 = 0.0;
2151 // for fpconst
2152 if (intConst1 == nullptr || intConst2 == nullptr) {
2153 PrimType ptyp = const1->GetPrimType();
2154 if (ptyp == PTY_f64) {
2155 MIRDoubleConst *dConst1 = safe_cast<MIRDoubleConst>(const1->GetConstVal());
2156 dconst1 = dConst1->GetValue();
2157 MIRDoubleConst *dConst2 = safe_cast<MIRDoubleConst>(const2->GetConstVal());
2158 dconst2 = dConst2->GetValue();
2159 } else if (ptyp == PTY_f32) {
2160 MIRFloatConst *fConst1 = safe_cast<MIRFloatConst>(const1->GetConstVal());
2161 dconst1 = static_cast<double>(fConst1->GetFloatValue());
2162 MIRFloatConst *fConst2 = safe_cast<MIRFloatConst>(const2->GetConstVal());
2163 dconst2 = static_cast<double>(fConst2->GetFloatValue());
2164 }
2165 } else {
2166 dconst1 = static_cast<double>(intConst1->GetExtValue());
2167 dconst2 = static_cast<double>(intConst2->GetExtValue());
2168 }
2169 PrimType foldedPrimType = primTypes[kSecond];
2170 if (!IsPrimitiveInteger(foldedPrimType)) {
2171 foldedPrimType = primTypes[kThird];
2172 }
2173 if (dconst1 == 1.0 && dconst2 == 0.0) {
2174 if (IsPrimitiveInteger(foldedPrimType)) {
2175 result = PairToExpr(foldedPrimType, p[0]);
2176 } else {
2177 result = PairToExpr(primTypes[0], p[0]);
2178 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, foldedPrimType, primTypes[0],
2179 result);
2180 }
2181 return std::make_pair(result, std::nullopt);
2182 }
2183 if (dconst1 == 0.0 && dconst2 == 1.0) {
2184 BaseNode *lnot = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
2185 OP_eq, primTypes[0], primTypes[0], PairToExpr(primTypes[0], p[0]),
2186 mirModule->GetMIRBuilder()->CreateIntConst(0, primTypes[0]));
2187 std::pair<BaseNode *, std::optional<IntVal>> pairTemp = DispatchFold(lnot);
2188 if (IsPrimitiveInteger(foldedPrimType)) {
2189 result = PairToExpr(foldedPrimType, pairTemp);
2190 } else {
2191 result = PairToExpr(primTypes[0], pairTemp);
2192 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, foldedPrimType, primTypes[0],
2193 result);
2194 }
2195 return std::make_pair(result, std::nullopt);
2196 }
2197 }
2198 }
2199 }
2200 BaseNode *e0 = PairToExpr(primTypes[kFirst], p[kFirst]);
2201 BaseNode *e1 = PairToExpr(primTypes[kSecond], p[kSecond]);
2202 BaseNode *e2 = PairToExpr(primTypes[kThird], p[kThird]); // count up to 3 for ternary node
2203 if (e0 != node->Opnd(kFirst) || e1 != node->Opnd(kSecond) || e2 != node->Opnd(kThird)) {
2204 result = mirModule->CurFuncCodeMemPool()->New<TernaryNode>(Opcode(node->GetOpCode()),
2205 PrimType(node->GetPrimType()), e0, e1, e2);
2206 }
2207 return std::make_pair(result, std::nullopt);
2208 }
2209
SimplifyDassign(DassignNode * node)2210 StmtNode *ConstantFold::SimplifyDassign(DassignNode *node)
2211 {
2212 CHECK_NULL_FATAL(node);
2213 BaseNode *returnValue = nullptr;
2214 returnValue = Fold(node->GetRHS());
2215 if (returnValue != nullptr) {
2216 node->SetRHS(returnValue);
2217 }
2218 return node;
2219 }
2220
SimplifyIassignWithAddrofBaseNode(IassignNode & node,const AddrofNode & base)2221 StmtNode *ConstantFold::SimplifyIassignWithAddrofBaseNode(IassignNode &node, const AddrofNode &base)
2222 {
2223 auto *mirTypeOfIass = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node.GetTyIdx());
2224 if (!mirTypeOfIass->IsMIRPtrType()) {
2225 return &node;
2226 }
2227 auto *iassPtType = static_cast<MIRPtrType *>(mirTypeOfIass);
2228
2229 MIRSymbol *lhsSym = mirModule->CurFunction()->GetLocalOrGlobalSymbol(base.GetStIdx());
2230 TyIdx lhsTyIdx = lhsSym->GetTyIdx();
2231 if (base.GetFieldID() != 0) {
2232 auto *mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(lhsTyIdx);
2233 if (!mirType->IsStructType()) {
2234 return &node;
2235 }
2236 lhsTyIdx = static_cast<MIRStructType *>(mirType)->GetFieldType(base.GetFieldID())->GetTypeIndex();
2237 }
2238 if (iassPtType->GetPointedTyIdx() == lhsTyIdx) {
2239 DassignNode *dassignNode = mirModule->CurFuncCodeMemPool()->New<DassignNode>();
2240 dassignNode->SetStIdx(base.GetStIdx());
2241 dassignNode->SetRHS(node.GetRHS());
2242 dassignNode->SetFieldID(base.GetFieldID() + node.GetFieldID());
2243 // reuse stmtid to maintain stmtFreqs if profileUse is on
2244 dassignNode->SetStmtID(node.GetStmtID());
2245 return dassignNode;
2246 }
2247 return &node;
2248 }
2249
SimplifyIassignWithIaddrofBaseNode(IassignNode & node,const IaddrofNode & base)2250 StmtNode *ConstantFold::SimplifyIassignWithIaddrofBaseNode(IassignNode &node, const IaddrofNode &base)
2251 {
2252 auto *mirTypeOfIass = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node.GetTyIdx());
2253 if (!mirTypeOfIass->IsMIRPtrType()) {
2254 return &node;
2255 }
2256 auto *iassPtType = static_cast<MIRPtrType *>(mirTypeOfIass);
2257
2258 if (base.GetFieldID() == 0) {
2259 // this iaddrof is redundant
2260 node.SetAddrExpr(base.Opnd(0));
2261 return &node;
2262 }
2263
2264 auto *mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(base.GetTyIdx());
2265 if (!mirType->IsMIRPtrType()) {
2266 return &node;
2267 }
2268 auto *iaddrofPtType = static_cast<MIRPtrType *>(mirType);
2269
2270 MIRStructType *lhsStructTy =
2271 static_cast<MIRStructType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iaddrofPtType->GetPointedTyIdx()));
2272 TyIdx lhsTyIdx = lhsStructTy->GetFieldType(base.GetFieldID())->GetTypeIndex();
2273 if (iassPtType->GetPointedTyIdx() == lhsTyIdx) {
2274 // eliminate the iaddrof by updating the iassign's fieldID and tyIdx
2275 node.SetFieldID(node.GetFieldID() + base.GetFieldID());
2276 node.SetTyIdx(base.GetTyIdx());
2277 node.SetOpnd(base.Opnd(0), 0);
2278 // recursive call for the new iassign
2279 return SimplifyIassign(&node);
2280 }
2281 return &node;
2282 }
2283
SimplifyIassign(IassignNode * node)2284 StmtNode *ConstantFold::SimplifyIassign(IassignNode *node)
2285 {
2286 CHECK_NULL_FATAL(node);
2287 BaseNode *returnValue = nullptr;
2288 returnValue = Fold(node->Opnd(0));
2289 if (returnValue != nullptr) {
2290 node->SetOpnd(returnValue, 0);
2291 }
2292 returnValue = Fold(node->GetRHS());
2293 if (returnValue != nullptr) {
2294 node->SetRHS(returnValue);
2295 }
2296 auto *mirTypeOfIass = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx());
2297 if (!mirTypeOfIass->IsMIRPtrType()) {
2298 return node;
2299 }
2300
2301 auto *opnd = node->Opnd(0);
2302 ASSERT_NOT_NULL(opnd);
2303 switch (opnd->GetOpCode()) {
2304 case OP_addrof: {
2305 return SimplifyIassignWithAddrofBaseNode(*node, static_cast<AddrofNode &>(*opnd));
2306 }
2307 case OP_iaddrof: {
2308 return SimplifyIassignWithIaddrofBaseNode(*node, static_cast<IreadNode &>(*opnd));
2309 }
2310 default:
2311 break;
2312 }
2313 return node;
2314 }
2315
SimplifyCondGoto(CondGotoNode * node)2316 StmtNode *ConstantFold::SimplifyCondGoto(CondGotoNode *node)
2317 {
2318 CHECK_NULL_FATAL(node);
2319 // optimize condgoto need to update frequency, skip here
2320 if (Options::profileUse && mirModule->CurFunction()->GetFuncProfData()) {
2321 return node;
2322 }
2323 BaseNode *returnValue = nullptr;
2324 returnValue = Fold(node->Opnd(0));
2325 returnValue = (returnValue == nullptr) ? node : returnValue;
2326 if (returnValue == node && node->Opnd(0)->GetOpCode() == OP_select) {
2327 return SimplifyCondGotoSelect(node);
2328 } else {
2329 if (returnValue != node) {
2330 node->SetOpnd(returnValue, 0);
2331 }
2332 ConstvalNode *cst = safe_cast<ConstvalNode>(node->Opnd(0));
2333 if (cst == nullptr) {
2334 return node;
2335 }
2336 MIRIntConst *intConst = safe_cast<MIRIntConst>(cst->GetConstVal());
2337 ASSERT_NOT_NULL(intConst);
2338 if ((node->GetOpCode() == OP_brtrue && !intConst->IsZero()) ||
2339 (node->GetOpCode() == OP_brfalse && intConst->IsZero())) {
2340 uint32 freq = mirModule->CurFunction()->GetFreqFromLastStmt(node->GetStmtID());
2341 GotoNode *gotoNode = mirModule->CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
2342 gotoNode->SetOffset(node->GetOffset());
2343 if (Options::profileUse && mirModule->CurFunction()->GetFuncProfData()) {
2344 gotoNode->SetStmtID(node->GetStmtID()); // reuse condnode stmtid
2345 }
2346 mirModule->CurFunction()->SetLastFreqMap(gotoNode->GetStmtID(), freq);
2347 return gotoNode;
2348 } else {
2349 return nullptr;
2350 }
2351 }
2352 return node;
2353 }
2354
SimplifyCondGotoSelect(CondGotoNode * node) const2355 StmtNode *ConstantFold::SimplifyCondGotoSelect(CondGotoNode *node) const
2356 {
2357 CHECK_NULL_FATAL(node);
2358 TernaryNode *sel = static_cast<TernaryNode *>(node->Opnd(0));
2359 if (sel == nullptr || sel->GetOpCode() != OP_select) {
2360 return node;
2361 }
2362 ConstvalNode *const1 = safe_cast<ConstvalNode>(sel->Opnd(1));
2363 ConstvalNode *const2 = safe_cast<ConstvalNode>(sel->Opnd(2));
2364 if (const1 != nullptr && const2 != nullptr) {
2365 MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1->GetConstVal());
2366 MIRIntConst *intConst2 = safe_cast<MIRIntConst>(const2->GetConstVal());
2367 ASSERT_NOT_NULL(intConst1);
2368 ASSERT_NOT_NULL(intConst2);
2369 if (intConst1->GetValue() == 1 && intConst2->GetValue() == 0) {
2370 node->SetOpnd(sel->Opnd(0), 0);
2371 } else if (intConst1->GetValue() == 0 && intConst2->GetValue() == 1) {
2372 node->SetOpCode((node->GetOpCode() == OP_brfalse) ? OP_brtrue : OP_brfalse);
2373 node->SetOpnd(sel->Opnd(0), 0);
2374 }
2375 }
2376 return node;
2377 }
2378
SimplifySwitch(SwitchNode * node)2379 StmtNode *ConstantFold::SimplifySwitch(SwitchNode *node)
2380 {
2381 CHECK_NULL_FATAL(node);
2382 BaseNode *returnValue = nullptr;
2383 returnValue = Fold(node->GetSwitchOpnd());
2384 if (returnValue != nullptr) {
2385 node->SetSwitchOpnd(returnValue);
2386 ConstvalNode *cst = safe_cast<ConstvalNode>(node->GetSwitchOpnd());
2387 if (cst == nullptr) {
2388 return node;
2389 }
2390 MIRIntConst *intConst = safe_cast<MIRIntConst>(cst->GetConstVal());
2391 ASSERT_NOT_NULL(intConst);
2392 GotoNode *gotoNode = mirModule->CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
2393 bool isdefault = true;
2394 for (unsigned i = 0; i < node->GetSwitchTable().size(); i++) {
2395 if (node->GetCasePair(i).first == intConst->GetValue()) {
2396 isdefault = false;
2397 gotoNode->SetOffset((LabelIdx)node->GetCasePair(i).second);
2398 break;
2399 }
2400 }
2401 if (isdefault) {
2402 gotoNode->SetOffset(node->GetDefaultLabel());
2403 }
2404 return gotoNode;
2405 }
2406 return node;
2407 }
2408
SimplifyUnary(UnaryStmtNode * node)2409 StmtNode *ConstantFold::SimplifyUnary(UnaryStmtNode *node)
2410 {
2411 CHECK_NULL_FATAL(node);
2412 BaseNode *returnValue = nullptr;
2413 if (node->Opnd(0) == nullptr) {
2414 return node;
2415 }
2416 returnValue = Fold(node->Opnd(0));
2417 if (returnValue != nullptr) {
2418 node->SetOpnd(returnValue, 0);
2419 }
2420 return node;
2421 }
2422
SimplifyBinary(BinaryStmtNode * node)2423 StmtNode *ConstantFold::SimplifyBinary(BinaryStmtNode *node)
2424 {
2425 CHECK_NULL_FATAL(node);
2426 BaseNode *returnValue = nullptr;
2427 returnValue = Fold(node->GetBOpnd(0));
2428 if (returnValue != nullptr) {
2429 node->SetBOpnd(returnValue, 0);
2430 }
2431 returnValue = Fold(node->GetBOpnd(1));
2432 if (returnValue != nullptr) {
2433 node->SetBOpnd(returnValue, 1);
2434 }
2435 return node;
2436 }
2437
SimplifyBlock(BlockNode * node)2438 StmtNode *ConstantFold::SimplifyBlock(BlockNode *node)
2439 {
2440 CHECK_NULL_FATAL(node);
2441 if (node->GetFirst() == nullptr) {
2442 return node;
2443 }
2444 StmtNode *s = node->GetFirst();
2445 StmtNode *prevStmt = nullptr;
2446 do {
2447 StmtNode *returnValue = Simplify(s);
2448 if (returnValue != nullptr) {
2449 if (returnValue->GetOpCode() == OP_block) {
2450 BlockNode *blk = static_cast<BlockNode *>(returnValue);
2451 node->ReplaceStmtWithBlock(*s, *blk);
2452 } else {
2453 node->ReplaceStmt1WithStmt2(s, returnValue);
2454 }
2455 prevStmt = s;
2456 s = s->GetNext();
2457 } else {
2458 // delete s from block
2459 StmtNode *nextStmt = s->GetNext();
2460 if (s == node->GetFirst()) {
2461 node->SetFirst(nextStmt);
2462 if (nextStmt != nullptr) {
2463 nextStmt->SetPrev(nullptr);
2464 }
2465 } else {
2466 CHECK_NULL_FATAL(prevStmt);
2467 prevStmt->SetNext(nextStmt);
2468 if (nextStmt != nullptr) {
2469 nextStmt->SetPrev(prevStmt);
2470 }
2471 }
2472 if (s == node->GetLast()) {
2473 node->SetLast(prevStmt);
2474 }
2475 s = nextStmt;
2476 }
2477 } while (s != nullptr);
2478 return node;
2479 }
2480
SimplifyAsm(AsmNode * node)2481 StmtNode *ConstantFold::SimplifyAsm(AsmNode *node)
2482 {
2483 CHECK_NULL_FATAL(node);
2484 /* fold constval in input */
2485 for (size_t i = 0; i < node->NumOpnds(); i++) {
2486 const std::string &str = GlobalTables::GetUStrTable().GetStringFromStrIdx(node->inputConstraints[i]);
2487 if (str == "i") {
2488 std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(i));
2489 node->SetOpnd(p.first, i);
2490 continue;
2491 }
2492 }
2493 return node;
2494 }
2495
SimplifyIf(IfStmtNode * node)2496 StmtNode *ConstantFold::SimplifyIf(IfStmtNode *node)
2497 {
2498 CHECK_NULL_FATAL(node);
2499 BaseNode *returnValue = nullptr;
2500 (void)Simplify(node->GetThenPart());
2501 if (node->GetElsePart()) {
2502 (void)Simplify(node->GetElsePart());
2503 }
2504 returnValue = Fold(node->Opnd());
2505 if (returnValue != nullptr) {
2506 node->SetOpnd(returnValue, 0);
2507 ConstvalNode *cst = safe_cast<ConstvalNode>(node->Opnd());
2508 // do not delete c/c++ dead if-body here
2509 if (cst == nullptr || (!mirModule->IsJavaModule())) {
2510 return node;
2511 }
2512 MIRIntConst *intConst = safe_cast<MIRIntConst>(cst->GetConstVal());
2513 ASSERT_NOT_NULL(intConst);
2514 if (intConst->GetValue() == 0) {
2515 return node->GetElsePart();
2516 } else {
2517 return node->GetThenPart();
2518 }
2519 }
2520 return node;
2521 }
2522
SimplifyWhile(WhileStmtNode * node)2523 StmtNode *ConstantFold::SimplifyWhile(WhileStmtNode *node)
2524 {
2525 CHECK_NULL_FATAL(node);
2526 BaseNode *returnValue = nullptr;
2527 if (node->Opnd(0) == nullptr) {
2528 return node;
2529 }
2530 if (node->GetBody()) {
2531 (void)Simplify(node->GetBody());
2532 }
2533 returnValue = Fold(node->Opnd(0));
2534 if (returnValue != nullptr) {
2535 node->SetOpnd(returnValue, 0);
2536 ConstvalNode *cst = safe_cast<ConstvalNode>(node->Opnd(0));
2537 // do not delete c/c++ dead while-body here
2538 if (cst == nullptr || (!mirModule->IsJavaModule())) {
2539 return node;
2540 }
2541 if (cst->GetConstVal()->IsZero()) {
2542 if (OP_while == node->GetOpCode()) {
2543 return nullptr;
2544 } else {
2545 return node->GetBody();
2546 }
2547 }
2548 }
2549 return node;
2550 }
2551
SimplifyNary(NaryStmtNode * node)2552 StmtNode *ConstantFold::SimplifyNary(NaryStmtNode *node)
2553 {
2554 CHECK_NULL_FATAL(node);
2555 BaseNode *returnValue = nullptr;
2556 for (size_t i = 0; i < node->NumOpnds(); i++) {
2557 returnValue = Fold(node->GetNopndAt(i));
2558 if (returnValue != nullptr) {
2559 node->SetNOpndAt(i, returnValue);
2560 }
2561 }
2562 return node;
2563 }
2564
SimplifyIcall(IcallNode * node)2565 StmtNode *ConstantFold::SimplifyIcall(IcallNode *node)
2566 {
2567 CHECK_NULL_FATAL(node);
2568 BaseNode *returnValue = nullptr;
2569 for (size_t i = 0; i < node->NumOpnds(); i++) {
2570 returnValue = Fold(node->GetNopndAt(i));
2571 if (returnValue != nullptr) {
2572 node->SetNOpndAt(i, returnValue);
2573 }
2574 }
2575 // icall node transform to call node
2576 CHECK_FATAL(node->GetNopnd().empty() == false, "container check");
2577 switch (node->GetNopndAt(0)->GetOpCode()) {
2578 case OP_addroffunc: {
2579 AddroffuncNode *addrofNode = static_cast<AddroffuncNode *>(node->GetNopndAt(0));
2580 CallNode *callNode = mirModule->CurFuncCodeMemPool()->New<CallNode>(
2581 *mirModule,
2582 (node->GetOpCode() == OP_icall || node->GetOpCode() == OP_icallproto) ? OP_call : OP_callassigned);
2583 if (node->GetOpCode() == OP_icallassigned || node->GetOpCode() == OP_icallprotoassigned) {
2584 callNode->SetReturnVec(node->GetReturnVec());
2585 }
2586 callNode->SetPUIdx(addrofNode->GetPUIdx());
2587 for (size_t i = 1; i < node->GetNopndSize(); i++) {
2588 callNode->GetNopnd().push_back(node->GetNopndAt(i));
2589 }
2590 callNode->SetNumOpnds(callNode->GetNopndSize());
2591 // reuse stmtID to skip update stmtFreqs when profileUse is on
2592 callNode->SetStmtID(node->GetStmtID());
2593 return callNode;
2594 }
2595 default:
2596 break;
2597 }
2598 return node;
2599 }
2600
ProcessFunc(MIRFunction * func)2601 void ConstantFold::ProcessFunc(MIRFunction *func)
2602 {
2603 if (func->IsEmpty()) {
2604 return;
2605 }
2606 mirModule->SetCurFunction(func);
2607 (void)Simplify(func->GetBody());
2608 }
2609
2610 } // namespace maple
2611