• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "constantfold.h"
17 #include <cmath>
18 #include <cfloat>
19 #include <climits>
20 #include <type_traits>
21 #include "mpl_logging.h"
22 #include "mir_function.h"
23 #include "mir_builder.h"
24 #include "global_tables.h"
25 #include "me_option.h"
26 #include "maple_phase_manager.h"
27 #include "mir_type.h"
28 
29 namespace maple {
30 
31 namespace {
32 
33 constexpr uint64 kJsTypeNumber = 4;
34 constexpr uint64 kJsTypeNumberInHigh32Bit = kJsTypeNumber << 32;  // set high 32 bit as JSTYPE_NUMBER
35 constexpr uint32 kByteSizeOfBit64 = 8;                            // byte number for 64 bit
36 constexpr uint32 kBitSizePerByte = 8;
37 constexpr maple::int32 kMaxOffset = INT_MAX - 8;
38 
39 enum CompareRes : int64 { kLess = -1, kEqual = 0, kGreater = 1 };
40 
operator *(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)41 std::optional<IntVal> operator*(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
42 {
43     if (!v1 && !v2) {
44         return std::nullopt;
45     }
46 
47     // Perform all calculations in terms of the maximum available signed type.
48     // The value will be truncated for an appropriate type when constant is created in PairToExpr function
49     return v1 && v2 ? v1->Mul(*v2, PTY_i64) : IntVal(static_cast<uint64>(0), PTY_i64);
50 }
51 
52 // Perform all calculations in terms of the maximum available signed type.
53 // The value will be truncated for an appropriate type when constant is created in PairToExpr function
AddSub(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2,bool isAdd)54 std::optional<IntVal> AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd)
55 {
56     if (!v1 && !v2) {
57         return std::nullopt;
58     }
59 
60     if (v1 && v2) {
61         return isAdd ? v1->Add(*v2, PTY_i64) : v1->Sub(*v2, PTY_i64);
62     }
63 
64     if (v1) {
65         return v1->TruncOrExtend(PTY_i64);
66     }
67 
68     // !v1 && v2
69     return isAdd ? v2->TruncOrExtend(PTY_i64) : -(v2->TruncOrExtend(PTY_i64));
70 }
71 
operator +(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)72 std::optional<IntVal> operator+(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
73 {
74     return AddSub(v1, v2, true);
75 }
76 
operator -(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)77 std::optional<IntVal> operator-(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
78 {
79     return AddSub(v1, v2, false);
80 }
81 
82 }  // anonymous namespace
83 
84 // This phase optimizes code by simplifying constant
85 // expressions: each constant expression is evaluated at
86 // compile time and replaced by its computed value, which
87 // saves that work at runtime.
88 //
89 // The main procedure is as follows:
90 // A. Analyze the expression type
91 // B. Analyze the operator type
92 // C. Replace the expression with the result of the operation
93 
94 // true if the constant's bits are made of only one group of contiguous 1's
95 // starting at bit 0
ContiguousBitsOf1(uint64 x)96 static bool ContiguousBitsOf1(uint64 x)
97 {
98     if (x == 0) {
99         return false;
100     }
101     return (~x & (x + 1)) == (x + 1);
102 }
103 
IsPowerOf2(uint64 num)104 inline bool IsPowerOf2(uint64 num)
105 {
106     if (num == 0) {
107         return false;
108     }
109     return (~(num - 1) & num) == num;
110 }
111 
NewBinaryNode(BinaryNode * old,Opcode op,PrimType primType,BaseNode * lhs,BaseNode * rhs) const112 BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs,
113                                         BaseNode *rhs) const
114 {
115     CHECK_NULL_FATAL(old);
116     BinaryNode *result = nullptr;
117     if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == lhs && old->Opnd(1) == rhs) {
118         result = old;
119     } else {
120         result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(op, primType, lhs, rhs);
121     }
122     return result;
123 }
124 
NewUnaryNode(UnaryNode * old,Opcode op,PrimType primType,BaseNode * expr) const125 UnaryNode *ConstantFold::NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const
126 {
127     CHECK_NULL_FATAL(old);
128     UnaryNode *result = nullptr;
129     if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == expr) {
130         result = old;
131     } else {
132         result = mirModule->CurFuncCodeMemPool()->New<UnaryNode>(op, primType, expr);
133     }
134     return result;
135 }
136 
// Materializes a folded (expression, constant) pair as a single MIR expression of type resultType.
// pair.first is the non-constant part; pair.second is the accumulated constant (an IntVal that may
// be wider than resultType; it is truncated by CreateIntConst when the constant node is built).
// pair.first is returned unchanged when there is no constant, the constant is zero, or resultType
// is wider than 8 bytes. Otherwise a new add/sub node is built:
//   -a, c  (c has no sign bit)              -> c - a
//   a,  c  (c is positive at resultType)    -> a + c
//   a,  c  (otherwise, i.e. c is negative)  -> a - (-c)
PairToExpr(PrimType resultType,const std::pair<BaseNode *,std::optional<IntVal>> & pair) const137 BaseNode *ConstantFold::PairToExpr(PrimType resultType, const std::pair<BaseNode*, std::optional<IntVal>> &pair) const
138 {
139     CHECK_NULL_FATAL(pair.first);
140     BaseNode *result = pair.first;
141     if (!pair.second || *pair.second == 0 || GetPrimTypeSize(resultType) > k8ByteSize) {
142         return result;
143     }
144     if (pair.first->GetOpCode() == OP_neg && !pair.second->GetSignBit()) {
145         // -a, 5 -> 5 - a
    // The negation is absorbed: the operand of the OP_neg becomes the subtrahend.
146         ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
147             static_cast<uint64>(pair.second->GetExtValue()), resultType);
148         BaseNode *r = static_cast<UnaryNode*>(pair.first)->Opnd(0);
149         result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, val, r);
150     } else {
        // Keep the add form when the constant is positive once sign-extended at resultType's bit width,
        // or when it is the minimum value (of resultType, or INT64_MIN overall) -- presumably because
        // negating the minimum value wraps back to itself, so a subtraction form would gain nothing.
151         if ((!pair.second->GetSignBit() &&
152             pair.second->GetSXTValue(static_cast<uint8>(GetPrimTypeBitSize(resultType))) > 0) ||
153             pair.second->TruncOrExtend(resultType).IsMinValue() ||
154             pair.second->GetSXTValue() == INT64_MIN) {
155             // +-a, 5 -> a + 5
156             ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
157                 static_cast<uint64>(pair.second->GetExtValue()), resultType);
158             result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_add, resultType, pair.first, val);
159         } else {
160             // +-a, -5 -> a - 5 (the constant is negated and subtracted)
161             ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
162                 static_cast<uint64>((-pair.second.value()).GetExtValue()), resultType);
163             result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, pair.first, val);
164         }
165     }
166     return result;
167 }
168 
FoldBase(BaseNode * node) const169 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBase(BaseNode *node) const
170 {
171     return std::make_pair(node, std::nullopt);
172 }
173 
// Simplifies one statement by dispatching on its opcode to the matching Simplify* handler and
// returning whatever that handler returns. Statements whose opcode is not listed below are
// returned unchanged. Each static_cast relies on the opcode identifying the concrete node class.
// NOTE(review): CASE_OP_ASSERT_NONNULL and CASE_OP_ASSERT_BOUNDARY are macros defined elsewhere;
// they presumably expand to additional `case` labels for assert-style opcodes, which then share
// the SimplifyUnary / SimplifyNary paths respectively -- confirm against their definitions.
Simplify(StmtNode * node)174 StmtNode *ConstantFold::Simplify(StmtNode *node)
175 {
176     CHECK_NULL_FATAL(node);
177     switch (node->GetOpCode()) {
178         case OP_dassign:
179         case OP_maydassign:
180             return SimplifyDassign(static_cast<DassignNode*>(node));
181         case OP_iassign:
182             return SimplifyIassign(static_cast<IassignNode*>(node));
183         case OP_block:
184             return SimplifyBlock(static_cast<BlockNode*>(node));
185         case OP_if:
186             return SimplifyIf(static_cast<IfStmtNode*>(node));
187         case OP_dowhile:
188         case OP_while:
189             return SimplifyWhile(static_cast<WhileStmtNode*>(node));
190         case OP_switch:
191             return SimplifySwitch(static_cast<SwitchNode*>(node));
        // Statements handled through the UnaryStmtNode interface.
192         case OP_eval:
193         case OP_throw:
194         case OP_free:
195         case OP_decref:
196         case OP_incref:
197         case OP_decrefreset:
198         case OP_regassign:
199             CASE_OP_ASSERT_NONNULL
200         case OP_igoto:
201             return SimplifyUnary(static_cast<UnaryStmtNode*>(node));
202         case OP_brfalse:
203         case OP_brtrue:
204             return SimplifyCondGoto(static_cast<CondGotoNode*>(node));
        // Statements with a variable number of operands (returns, sync, calls), handled through
        // the NaryStmtNode interface.
205         case OP_return:
206         case OP_syncenter:
207         case OP_syncexit:
208         case OP_call:
209         case OP_virtualcall:
210         case OP_superclasscall:
211         case OP_interfacecall:
212         case OP_customcall:
213         case OP_polymorphiccall:
214         case OP_intrinsiccall:
215         case OP_xintrinsiccall:
216         case OP_intrinsiccallwithtype:
217         case OP_callassigned:
218         case OP_virtualcallassigned:
219         case OP_superclasscallassigned:
220         case OP_interfacecallassigned:
221         case OP_customcallassigned:
222         case OP_polymorphiccallassigned:
223         case OP_intrinsiccallassigned:
224         case OP_intrinsiccallwithtypeassigned:
225         case OP_xintrinsiccallassigned:
226         case OP_callinstant:
227         case OP_callinstantassigned:
228         case OP_virtualcallinstant:
229         case OP_virtualcallinstantassigned:
230         case OP_superclasscallinstant:
231         case OP_superclasscallinstantassigned:
232         case OP_interfacecallinstant:
233         case OP_interfacecallinstantassigned:
234             CASE_OP_ASSERT_BOUNDARY
235             return SimplifyNary(static_cast<NaryStmtNode*>(node));
        // Indirect calls get their own handler.
236         case OP_icall:
237         case OP_icallassigned:
238         case OP_icallproto:
239         case OP_icallprotoassigned:
240             return SimplifyIcall(static_cast<IcallNode*>(node));
241         case OP_asm:
242             return SimplifyAsm(static_cast<AsmNode*>(node));
        // Anything else is left untouched.
243         default:
244             return node;
245     }
246 }
247 
// Folds the expression rooted at node by dispatching on its opcode to the matching Fold* routine.
// The result pairs the (possibly rewritten) expression with an optional constant part;
// std::nullopt means there is no separate constant part. Presumably callers recombine the two
// with PairToExpr -- confirm at the call sites.
DispatchFold(BaseNode * node)248 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::DispatchFold(BaseNode *node)
249 {
250     CHECK_NULL_FATAL(node);
    // Expressions whose type is wider than 8 bytes are not folded: return the node unchanged,
    // with no constant part.
251     if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
252         return {node, std::nullopt};
253     }
254     switch (node->GetOpCode()) {
255         case OP_sizeoftype:
256             return FoldSizeoftype(static_cast<SizeoftypeNode*>(node));
257         case OP_abs:
258         case OP_bnot:
259         case OP_lnot:
260         case OP_neg:
261         case OP_recip:
262         case OP_sqrt:
263             return FoldUnary(static_cast<UnaryNode*>(node));
        // Rounding and conversion opcodes share the type-conversion folder.
264         case OP_ceil:
265         case OP_floor:
266         case OP_round:
267         case OP_trunc:
268         case OP_cvt:
269             return FoldTypeCvt(static_cast<TypeCvtNode*>(node));
270         case OP_sext:
271         case OP_zext:
272         case OP_extractbits:
273             return FoldExtractbits(static_cast<ExtractbitsNode*>(node));
        // OP_iaddrof is folded by the same routine as OP_iread.
274         case OP_iaddrof:
275         case OP_iread:
276             return FoldIread(static_cast<IreadNode*>(node));
277         case OP_add:
278         case OP_ashr:
279         case OP_band:
280         case OP_bior:
281         case OP_bxor:
282         case OP_cand:
283         case OP_cior:
284         case OP_div:
285         case OP_land:
286         case OP_lior:
287         case OP_lshr:
288         case OP_max:
289         case OP_min:
290         case OP_mul:
291         case OP_rem:
292         case OP_shl:
293         case OP_sub:
294             return FoldBinary(static_cast<BinaryNode*>(node));
295         case OP_eq:
296         case OP_ne:
297         case OP_ge:
298         case OP_gt:
299         case OP_le:
300         case OP_lt:
301         case OP_cmp:
302             return FoldCompare(static_cast<CompareNode*>(node));
303         case OP_depositbits:
304             return FoldDepositbits(static_cast<DepositbitsNode*>(node));
305         case OP_select:
306             return FoldTernary(static_cast<TernaryNode*>(node));
307         case OP_array:
308             return FoldArray(static_cast<ArrayNode*>(node));
309         case OP_retype:
310             return FoldRetype(static_cast<RetypeNode*>(node));
311         case OP_gcmallocjarray:
312         case OP_gcpermallocjarray:
313             return FoldGcmallocjarray(static_cast<JarrayMallocNode*>(node));
        // Opcodes without a dedicated folder keep the node as-is (see FoldBase).
314         default:
315             return FoldBase(static_cast<BaseNode*>(node));
316     }
317 }
318 
Negate(BaseNode * node) const319 BaseNode *ConstantFold::Negate(BaseNode *node) const
320 {
321     CHECK_NULL_FATAL(node);
322     return mirModule->CurFuncCodeMemPool()->New<UnaryNode>(OP_neg, PrimType(node->GetPrimType()), node);
323 }
324 
Negate(UnaryNode * node) const325 BaseNode *ConstantFold::Negate(UnaryNode *node) const
326 {
327     CHECK_NULL_FATAL(node);
328     BaseNode *result = nullptr;
329     if (node->GetOpCode() == OP_neg) {
330         result = static_cast<BaseNode*>(node->Opnd(0));
331     } else {
332         BaseNode *n = static_cast<BaseNode*>(node);
333         result = NewUnaryNode(node, OP_neg, node->GetPrimType(), n);
334     }
335     return result;
336 }
337 
Negate(const ConstvalNode * node) const338 BaseNode *ConstantFold::Negate(const ConstvalNode *node) const
339 {
340     CHECK_NULL_FATAL(node);
341     ConstvalNode *copy = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
342     CHECK_NULL_FATAL(copy);
343     copy->GetConstVal()->Neg();
344     return copy;
345 }
346 
NegateTree(BaseNode * node) const347 BaseNode *ConstantFold::NegateTree(BaseNode *node) const
348 {
349     CHECK_NULL_FATAL(node);
350     if (node->IsUnaryNode()) {
351         return Negate(static_cast<UnaryNode*>(node));
352     } else if (node->GetOpCode() == OP_constval) {
353         return Negate(static_cast<ConstvalNode*>(node));
354     } else {
355         return Negate(static_cast<BaseNode*>(node));
356     }
357 }
358 
FoldIntConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRIntConst & intConst0,const MIRIntConst & intConst1) const359 MIRIntConst *ConstantFold::FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
360                                                           const MIRIntConst &intConst0,
361                                                           const MIRIntConst &intConst1) const
362 {
363     uint64 result = 0;
364 
365     bool greater = intConst0.GetValue().Greater(intConst1.GetValue(), opndType);
366     bool equal = intConst0.GetValue().Equal(intConst1.GetValue(), opndType);
367     bool less = intConst0.GetValue().Less(intConst1.GetValue(), opndType);
368 
369     switch (opcode) {
370         case OP_eq: {
371             result = equal;
372             break;
373         }
374         case OP_ge: {
375             result = (greater || equal) ? 1 : 0;
376             break;
377         }
378         case OP_gt: {
379             result = greater;
380             break;
381         }
382         case OP_le: {
383             result = (less || equal) ? 1 : 0;
384             break;
385         }
386         case OP_lt: {
387             result = less;
388             break;
389         }
390         case OP_ne: {
391             result = !equal;
392             break;
393         }
394         case OP_cmp: {
395             if (greater) {
396                 result = kGreater;
397             } else if (equal) {
398                 result = kEqual;
399             } else if (less) {
400                 result = static_cast<uint64>(kLess);
401             }
402             break;
403         }
404         default:
405             DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstComparison");
406             break;
407     }
408     // determine the type
409     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
410     // form the constant
411     MIRIntConst *constValue = nullptr;
412     if (type.GetPrimType() == PTY_dyni32) {
413         constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
414         constValue->SetValue(static_cast<int64>(kJsTypeNumberInHigh32Bit | result));
415     } else {
416         constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
417     }
418     return constValue;
419 }
420 
FoldIntConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const421 ConstvalNode *ConstantFold::FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
422                                                    const ConstvalNode &const0, const ConstvalNode &const1) const
423 {
424     const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
425     const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
426     CHECK_NULL_FATAL(intConst0);
427     CHECK_NULL_FATAL(intConst1);
428     MIRIntConst *constValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
429     // form the ConstvalNode
430     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
431     resultConst->SetPrimType(resultType);
432     resultConst->SetConstVal(constValue);
433     return resultConst;
434 }
435 
FoldIntConstBinaryMIRConst(Opcode opcode,PrimType resultType,const MIRIntConst & intConst0,const MIRIntConst & intConst1)436 MIRConst *ConstantFold::FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0,
437                                                    const MIRIntConst &intConst1)
438 {
439     IntVal intVal0 = intConst0.GetValue();
440     IntVal intVal1 = intConst1.GetValue();
441     IntVal result(static_cast<uint64>(0), resultType);
442 
443     switch (opcode) {
444         case OP_add: {
445             result = intVal0.Add(intVal1, resultType);
446             break;
447         }
448         case OP_sub: {
449             result = intVal0.Sub(intVal1, resultType);
450             break;
451         }
452         case OP_mul: {
453             result = intVal0.Mul(intVal1, resultType);
454             break;
455         }
456         case OP_div: {
457             result = intVal0.Div(intVal1, resultType);
458             break;
459         }
460         case OP_rem: {
461             result = intVal0.Rem(intVal1, resultType);
462             break;
463         }
464         case OP_ashr: {
465             result = intVal0.AShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
466             break;
467         }
468         case OP_lshr: {
469             result = intVal0.LShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
470             break;
471         }
472         case OP_shl: {
473             result = intVal0.Shl(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
474             break;
475         }
476         case OP_max: {
477             result = Max(intVal0, intVal1, resultType);
478             break;
479         }
480         case OP_min: {
481             result = Min(intVal0, intVal1, resultType);
482             break;
483         }
484         case OP_band: {
485             result = intVal0.And(intVal1, resultType);
486             break;
487         }
488         case OP_bior: {
489             result = intVal0.Or(intVal1, resultType);
490             break;
491         }
492         case OP_bxor: {
493             result = intVal0.Xor(intVal1, resultType);
494             break;
495         }
496         case OP_cand:
497         case OP_land: {
498             result = IntVal(intVal0.GetExtValue() && intVal1.GetExtValue(), resultType);
499             break;
500         }
501         case OP_cior:
502         case OP_lior: {
503             result = IntVal(intVal0.GetExtValue() || intVal1.GetExtValue(), resultType);
504             break;
505         }
506         case OP_depositbits: {
507             // handled in FoldDepositbits
508             DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstBinary");
509             break;
510         }
511         default:
512             DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstBinary");
513             break;
514     }
515     // determine the type
516     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
517     // form the constant
518     MIRIntConst *constValue = nullptr;
519     if (type.GetPrimType() == PTY_dyni32) {
520         constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
521         constValue->SetValue(static_cast<int64>(kJsTypeNumberInHigh32Bit | static_cast<uint64>(result.GetExtValue())));
522     } else {
523         constValue =
524             GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
525     }
526     return constValue;
527 }
528 
FoldIntConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const529 ConstvalNode *ConstantFold::FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
530                                                const ConstvalNode &const1) const
531 {
532     const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
533     const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
534     CHECK_NULL_FATAL(intConst0);
535     CHECK_NULL_FATAL(intConst1);
536     MIRConst *constValue = FoldIntConstBinaryMIRConst(opcode, resultType, *intConst0, *intConst1);
537     // form the ConstvalNode
538     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
539     resultConst->SetPrimType(resultType);
540     resultConst->SetConstVal(constValue);
541     return resultConst;
542 }
543 
FoldFPConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const544 ConstvalNode *ConstantFold::FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
545                                               const ConstvalNode &const1) const
546 {
547     DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
548     const MIRDoubleConst *doubleConst0 = nullptr;
549     const MIRDoubleConst *doubleConst1 = nullptr;
550     const MIRFloatConst *floatConst0 = nullptr;
551     const MIRFloatConst *floatConst1 = nullptr;
552     bool useDouble = (const0.GetPrimType() == PTY_f64);
553     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
554     resultConst->SetPrimType(resultType);
555     if (useDouble) {
556         doubleConst0 = safe_cast<MIRDoubleConst>(const0.GetConstVal());
557         doubleConst1 = safe_cast<MIRDoubleConst>(const1.GetConstVal());
558         CHECK_NULL_FATAL(doubleConst0);
559         CHECK_NULL_FATAL(doubleConst1);
560     } else {
561         floatConst0 = safe_cast<MIRFloatConst>(const0.GetConstVal());
562         floatConst1 = safe_cast<MIRFloatConst>(const1.GetConstVal());
563         CHECK_NULL_FATAL(floatConst0);
564         CHECK_NULL_FATAL(floatConst1);
565     }
566     float constValueFloat = 0.0;
567     double constValueDouble = 0.0;
568     switch (opcode) {
569         case OP_add: {
570             if (useDouble) {
571                 constValueDouble = doubleConst0->GetValue() + doubleConst1->GetValue();
572             } else {
573                 constValueFloat = floatConst0->GetValue() + floatConst1->GetValue();
574             }
575             break;
576         }
577         case OP_sub: {
578             if (useDouble) {
579                 constValueDouble = doubleConst0->GetValue() - doubleConst1->GetValue();
580             } else {
581                 constValueFloat = floatConst0->GetValue() - floatConst1->GetValue();
582             }
583             break;
584         }
585         case OP_mul: {
586             if (useDouble) {
587                 constValueDouble = doubleConst0->GetValue() * doubleConst1->GetValue();
588             } else {
589                 constValueFloat = floatConst0->GetValue() * floatConst1->GetValue();
590             }
591             break;
592         }
593         case OP_div: {
594             // for floats div by 0 is well defined
595             if (useDouble) {
596                 constValueDouble = doubleConst0->GetValue() / doubleConst1->GetValue();
597             } else {
598                 constValueFloat = floatConst0->GetValue() / floatConst1->GetValue();
599             }
600             break;
601         }
602         case OP_max: {
603             if (useDouble) {
604                 constValueDouble = (doubleConst0->GetValue() >= doubleConst1->GetValue()) ? doubleConst0->GetValue()
605                                                                                         : doubleConst1->GetValue();
606             } else {
607                 constValueFloat = (floatConst0->GetValue() >= floatConst1->GetValue()) ? floatConst0->GetValue()
608                                                                                     : floatConst1->GetValue();
609             }
610             break;
611         }
612         case OP_min: {
613             if (useDouble) {
614                 constValueDouble = (doubleConst0->GetValue() <= doubleConst1->GetValue()) ? doubleConst0->GetValue()
615                                                                                         : doubleConst1->GetValue();
616             } else {
617                 constValueFloat = (floatConst0->GetValue() <= floatConst1->GetValue()) ? floatConst0->GetValue()
618                                                                                     : floatConst1->GetValue();
619             }
620             break;
621         }
622         case OP_rem:
623         case OP_ashr:
624         case OP_lshr:
625         case OP_shl:
626         case OP_band:
627         case OP_bior:
628         case OP_bxor:
629         case OP_cand:
630         case OP_land:
631         case OP_cior:
632         case OP_lior:
633         case OP_depositbits: {
634             DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstBinary");
635             break;
636         }
637         default:
638             DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstBinary");
639             break;
640     }
641     if (resultType == PTY_f64) {
642         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValueDouble));
643     } else {
644         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValueFloat));
645     }
646     return resultConst;
647 }
648 
ConstValueEqual(int64 leftValue,int64 rightValue) const649 bool ConstantFold::ConstValueEqual(int64 leftValue, int64 rightValue) const
650 {
651     return (leftValue == rightValue);
652 }
653 
ConstValueEqual(float leftValue,float rightValue) const654 bool ConstantFold::ConstValueEqual(float leftValue, float rightValue) const
655 {
656     auto result = fabs(leftValue - rightValue);
657     return leftValue <= FLT_MIN && rightValue <= FLT_MIN ? result < FLT_MIN : result <= FLT_MIN;
658 }
659 
ConstValueEqual(double leftValue,double rightValue) const660 bool ConstantFold::ConstValueEqual(double leftValue, double rightValue) const
661 {
662     auto result = fabs(leftValue - rightValue);
663     return leftValue <= DBL_MIN && rightValue <= DBL_MIN ? result < DBL_MIN : result <= DBL_MIN;
664 }
665 
666 template<typename T>
FullyEqual(T leftValue,T rightValue) const667 bool ConstantFold::FullyEqual(T leftValue, T rightValue) const
668 {
669     if (std::isinf(leftValue) && std::isinf(rightValue)) {
670         // (inf == inf), add the judgement here in case of the subtraction between float type inf
671         return true;
672     } else {
673         return ConstValueEqual(leftValue, rightValue);
674     }
675 }
676 
677 template<typename T>
ComparisonResult(Opcode op,T * leftConst,T * rightConst) const678 int64 ConstantFold::ComparisonResult(Opcode op, T *leftConst, T *rightConst) const
679 {
680     DEBUG_ASSERT(leftConst != nullptr, "leftConst should not be nullptr");
681     typename T::value_type leftValue = leftConst->GetValue();
682     DEBUG_ASSERT(rightConst != nullptr, "rightConst should not be nullptr");
683     typename T::value_type rightValue = rightConst->GetValue();
684     int64 result = 0;
685     switch (op) {
686         case OP_eq: {
687             result = FullyEqual(leftValue, rightValue);
688             break;
689         }
690         case OP_ge: {
691             result = (leftValue > rightValue) || FullyEqual(leftValue, rightValue);
692             break;
693         }
694         case OP_gt: {
695             result = (leftValue > rightValue);
696             break;
697         }
698         case OP_le: {
699             result = (leftValue < rightValue) || FullyEqual(leftValue, rightValue);
700             break;
701         }
702         case OP_lt: {
703             result = (leftValue < rightValue);
704             break;
705         }
706         case OP_ne: {
707             result = !FullyEqual(leftValue, rightValue);
708             break;
709         }
710         case OP_cmpl:
711         case OP_cmpg: {
712             if (std::isnan(leftValue) || std::isnan(rightValue)) {
713                 result = (op == OP_cmpg) ? kGreater : kLess;
714                 break;
715             }
716         }
717         [[clang::fallthrough]];
718         case OP_cmp: {
719             if (leftValue > rightValue) {
720                 result = kGreater;
721             } else if (FullyEqual(leftValue, rightValue)) {
722                 result = kEqual;
723             } else {
724                 result = kLess;
725             }
726             break;
727         }
728         default:
729             DEBUG_ASSERT(false, "Unknown opcode for Comparison");
730             break;
731     }
732     return result;
733 }
734 
FoldFPConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRConst & leftConst,const MIRConst & rightConst) const735 MIRIntConst *ConstantFold::FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
736                                                          const MIRConst &leftConst, const MIRConst &rightConst) const
737 {
738     int64 result = 0;
739     bool useDouble = (opndType == PTY_f64);
740     if (useDouble) {
741         result =
742             ComparisonResult(opcode, safe_cast<MIRDoubleConst>(&leftConst), safe_cast<MIRDoubleConst>(&rightConst));
743     } else {
744         result = ComparisonResult(opcode, safe_cast<MIRFloatConst>(&leftConst), safe_cast<MIRFloatConst>(&rightConst));
745     }
746     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
747     MIRIntConst *resultConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result), type);
748     return resultConst;
749 }
750 
FoldFPConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const751 ConstvalNode *ConstantFold::FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
752                                                   const ConstvalNode &const0, const ConstvalNode &const1) const
753 {
754     DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
755     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
756     resultConst->SetPrimType(resultType);
757     resultConst->SetConstVal(
758         FoldFPConstComparisonMIRConst(opcode, resultType, opndType, *const0.GetConstVal(), *const1.GetConstVal()));
759     return resultConst;
760 }
761 
FoldConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRConst & const0,const MIRConst & const1) const762 MIRConst *ConstantFold::FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
763                                                     const MIRConst &const0, const MIRConst &const1) const
764 {
765     MIRConst *returnValue = nullptr;
766     if (IsPrimitiveInteger(opndType) || IsPrimitiveDynInteger(opndType)) {
767         const auto *intConst0 = safe_cast<MIRIntConst>(&const0);
768         const auto *intConst1 = safe_cast<MIRIntConst>(&const1);
769         ASSERT_NOT_NULL(intConst0);
770         ASSERT_NOT_NULL(intConst1);
771         returnValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
772     } else if (opndType == PTY_f32 || opndType == PTY_f64) {
773         returnValue = FoldFPConstComparisonMIRConst(opcode, resultType, opndType, const0, const1);
774     } else {
775         DEBUG_ASSERT(false, "Unhandled case for FoldConstComparisonMIRConst");
776     }
777     return returnValue;
778 }
779 
FoldConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const780 ConstvalNode *ConstantFold::FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
781                                                 const ConstvalNode &const0, const ConstvalNode &const1) const
782 {
783     ConstvalNode *returnValue = nullptr;
784     if (IsPrimitiveInteger(opndType) || IsPrimitiveDynInteger(opndType)) {
785         returnValue = FoldIntConstComparison(opcode, resultType, opndType, const0, const1);
786     } else if (opndType == PTY_f32 || opndType == PTY_f64) {
787         returnValue = FoldFPConstComparison(opcode, resultType, opndType, const0, const1);
788     } else {
789         DEBUG_ASSERT(false, "Unhandled case for FoldConstComparison");
790     }
791     return returnValue;
792 }
793 
FoldConstComparisonReverse(Opcode opcode,PrimType resultType,PrimType opndType,BaseNode & l,BaseNode & r) const794 CompareNode *ConstantFold::FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType,
795                                                       BaseNode &l, BaseNode &r) const
796 {
797     CompareNode *result = nullptr;
798     Opcode op = opcode;
799     switch (opcode) {
800         case OP_gt: {
801             op = OP_lt;
802             break;
803         }
804         case OP_lt: {
805             op = OP_gt;
806             break;
807         }
808         case OP_ge: {
809             op = OP_le;
810             break;
811         }
812         case OP_le: {
813             op = OP_ge;
814             break;
815         }
816         case OP_eq: {
817             break;
818         }
819         case OP_ne: {
820             break;
821         }
822         default:
823             DEBUG_ASSERT(false, "Unknown opcode for FoldConstComparisonReverse");
824             break;
825     }
826 
827     result =
828         mirModule->CurFuncCodeMemPool()->New<CompareNode>(Opcode(op), PrimType(resultType), PrimType(opndType), &r, &l);
829     return result;
830 }
831 
FoldConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const832 ConstvalNode *ConstantFold::FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
833                                             const ConstvalNode &const1) const
834 {
835     ConstvalNode *returnValue = nullptr;
836     if (IsPrimitiveInteger(resultType) || IsPrimitiveDynInteger(resultType)) {
837         returnValue = FoldIntConstBinary(opcode, resultType, const0, const1);
838     } else if (resultType == PTY_f32 || resultType == PTY_f64) {
839         returnValue = FoldFPConstBinary(opcode, resultType, const0, const1);
840     } else {
841         DEBUG_ASSERT(false, "Unhandled case for FoldConstBinary");
842     }
843     return returnValue;
844 }
845 
FoldIntConstUnaryMIRConst(Opcode opcode,PrimType resultType,const MIRIntConst * constNode)846 MIRIntConst *ConstantFold::FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode)
847 {
848     CHECK_NULL_FATAL(constNode);
849     IntVal result = constNode->GetValue().TruncOrExtend(resultType);
850     switch (opcode) {
851         case OP_abs: {
852             if (IsSignedInteger(constNode->GetType().GetPrimType()) && result.GetSignBit()) {
853                 result = -result;
854             }
855             break;
856         }
857         case OP_bnot: {
858             result = ~result;
859             break;
860         }
861         case OP_lnot: {
862             uint64 resultInt = result == 0 ? 1 : 0;
863             result = {resultInt, resultType};
864             break;
865         }
866         case OP_neg: {
867             result = -result;
868             break;
869         }
870         case OP_sext:         // handled in FoldExtractbits
871         case OP_zext:         // handled in FoldExtractbits
872         case OP_extractbits:  // handled in FoldExtractbits
873         case OP_recip:
874         case OP_sqrt: {
875             DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstUnaryMIRConst");
876             break;
877         }
878         default:
879             DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstUnaryMIRConst");
880             break;
881     }
882     // determine the type
883     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
884     // form the constant
885     MIRIntConst *constValue = nullptr;
886     if (type.GetPrimType() == PTY_dyni32) {
887         constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
888         constValue->SetValue(static_cast<int64>(kJsTypeNumberInHigh32Bit | static_cast<uint64>(result.GetExtValue())));
889     } else {
890         constValue =
891             GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
892     }
893     return constValue;
894 }
895 
896 template <typename T>
FoldFPConstUnary(Opcode opcode,PrimType resultType,ConstvalNode * constNode) const897 ConstvalNode *ConstantFold::FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const
898 {
899     CHECK_NULL_FATAL(constNode);
900     double constValue = 0;
901     T *fpCst = static_cast<T*>(constNode->GetConstVal());
902     switch (opcode) {
903         case OP_recip: {
904             constValue = typename T::value_type(1.0L / fpCst->GetValue());
905             break;
906         }
907         case OP_neg: {
908             constValue = typename T::value_type(-fpCst->GetValue());
909             break;
910         }
911         case OP_abs: {
912             constValue = typename T::value_type(fabs(fpCst->GetValue()));
913             break;
914         }
915         case OP_sqrt: {
916             constValue = typename T::value_type(sqrt(fpCst->GetValue()));
917             break;
918         }
919         case OP_bnot:
920         case OP_lnot:
921         case OP_sext:
922         case OP_zext:
923         case OP_extractbits: {
924             DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstUnary");
925             break;
926         }
927         default:
928             DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstUnary");
929             break;
930     }
931     auto *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
932     resultConst->SetPrimType(resultType);
933     if (resultType == PTY_f32) {
934         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(static_cast<float>(constValue)));
935     } else if (resultType == PTY_f64) {
936         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValue));
937     } else {
938         CHECK_FATAL(false, "PrimType for MIRFloatConst / MIRDoubleConst should be PTY_f32 / PTY_f64");
939     }
940     return resultConst;
941 }
942 
FoldConstUnary(Opcode opcode,PrimType resultType,ConstvalNode & constNode) const943 ConstvalNode *ConstantFold::FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const
944 {
945     ConstvalNode *returnValue = nullptr;
946     if (IsPrimitiveInteger(resultType) || IsPrimitiveDynInteger(resultType)) {
947         const MIRIntConst *cst = safe_cast<MIRIntConst>(constNode.GetConstVal());
948         auto constValue = FoldIntConstUnaryMIRConst(opcode, resultType, cst);
949         returnValue = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
950         returnValue->SetPrimType(resultType);
951         returnValue->SetConstVal(constValue);
952     } else if (resultType == PTY_f32) {
953         returnValue = FoldFPConstUnary<MIRFloatConst>(opcode, resultType, &constNode);
954     } else if (resultType == PTY_f64) {
955         returnValue = FoldFPConstUnary<MIRDoubleConst>(opcode, resultType, &constNode);
956     } else if (resultType == PTY_f128) {
957         DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary");
958     } else {
959         DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary");
960     }
961     return returnValue;
962 }
963 
FoldSizeoftype(SizeoftypeNode * node) const964 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldSizeoftype(SizeoftypeNode *node) const
965 {
966     CHECK_NULL_FATAL(node);
967     BaseNode *result = node;
968     MIRType *argType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx());
969     if (argType->GetKind() == kTypeScalar) {
970         MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(node->GetPrimType());
971         uint32 size = GetPrimTypeSize(argType->GetPrimType());
972         ConstvalNode *constValueNode = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
973         constValueNode->SetPrimType(node->GetPrimType());
974         constValueNode->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(size, resultType));
975         result = constValueNode;
976     }
977     return std::make_pair(result, std::nullopt);
978 }
979 
FoldRetype(RetypeNode * node)980 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldRetype(RetypeNode *node)
981 {
982     CHECK_NULL_FATAL(node);
983     BaseNode *result = node;
984     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
985     if (node->Opnd(0) != p.first) {
986         RetypeNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
987         CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldRetype");
988         newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
989         result = newRetNode;
990     }
991     return std::make_pair(result, std::nullopt);
992 }
993 
FoldGcmallocjarray(JarrayMallocNode * node)994 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldGcmallocjarray(JarrayMallocNode *node)
995 {
996     CHECK_NULL_FATAL(node);
997     BaseNode *result = node;
998     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
999     if (node->Opnd(0) != p.first) {
1000         JarrayMallocNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
1001         CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldGcmallocjarray");
1002         newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
1003         result = newRetNode;
1004     }
1005     return std::make_pair(result, std::nullopt);
1006 }
1007 
FoldUnary(UnaryNode * node)1008 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldUnary(UnaryNode *node)
1009 {
1010     CHECK_NULL_FATAL(node);
1011     BaseNode *result = nullptr;
1012     std::optional<IntVal> sum = std::nullopt;
1013     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1014     ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1015     if (cst != nullptr) {
1016         result = FoldConstUnary(node->GetOpCode(), node->GetPrimType(), *cst);
1017     } else {
1018         bool isInt = IsPrimitiveInteger(node->GetPrimType());
1019         // The neg node will be recreated regardless of whether the folding is successful or not. And the neg node's
1020         // primType will be set to opnd type. There will be problems in some cases. For example:
1021         // before cf:
1022         //   neg i32 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))
1023         // after cf:
1024         //   neg u1 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))  # wrong!
1025         // As a workaround, we exclude u1 opnd type
1026         if (isInt && node->GetOpCode() == OP_neg && p.first->GetPrimType() != PTY_u1) {
1027             result = NegateTree(p.first);
1028             if (result->GetOpCode() == OP_neg) {
1029                 PrimType origPtyp = node->GetPrimType();
1030                 PrimType newPtyp = result->GetPrimType();
1031                 if (newPtyp == origPtyp) {
1032                 if (static_cast<UnaryNode*>(result)->Opnd(0) == node->Opnd(0)) {
1033                     // NegateTree returned an UnaryNode quivalent to `n`, so keep the
1034                     // original UnaryNode to preserve identity
1035                     result = node;
1036                 }
1037                 } else {
1038                     if (GetPrimTypeSize(newPtyp) != GetPrimTypeSize(origPtyp)) {
1039                         // do not fold explicit cvt
1040                         result = NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(),
1041                             PairToExpr(node->Opnd(0)->GetPrimType(), p));
1042                         return std::make_pair(result, std::nullopt);
1043                     } else {
1044                         result->SetPrimType(origPtyp);
1045                     }
1046                 }
1047             }
1048             if (p.second) {
1049                 sum = -(*p.second);
1050             }
1051         } else {
1052             result =
1053                 NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(), PairToExpr(node->Opnd(0)->GetPrimType(), p));
1054         }
1055     }
1056     return std::make_pair(result, sum);
1057 }
1058 
FloatToIntOverflow(float fval,PrimType totype)1059 static bool FloatToIntOverflow(float fval, PrimType totype)
1060 {
1061     static const float safeFloatMaxToInt32 = 2147483520.0f;  // 2^31 - 128
1062     static const float safeFloatMinToInt32 = -2147483520.0f;
1063     static const float safeFloatMaxToInt64 = 9223372036854775680.0f;  // 2^63 - 128
1064     static const float safeFloatMinToInt64 = -9223372036854775680.0f;
1065     if (!std::isfinite(fval)) {
1066         return true;
1067     }
1068     if (totype == PTY_i64 || totype == PTY_u64) {
1069         if (fval < safeFloatMinToInt64 || fval > safeFloatMaxToInt64) {
1070             return true;
1071         }
1072     } else {
1073         if (fval < safeFloatMinToInt32 || fval > safeFloatMaxToInt32) {
1074             return true;
1075         }
1076     }
1077     return false;
1078 }
1079 
DoubleToIntOverflow(double dval,PrimType totype)1080 static bool DoubleToIntOverflow(double dval, PrimType totype)
1081 {
1082     static const double safeDoubleMaxToInt32 = 2147482624.0;  // 2^31 - 1024
1083     static const double safeDoubleMinToInt32 = -2147482624.0;
1084     static const double safeDoubleMaxToInt64 = 9223372036854774784.0;  // 2^63 - 1024
1085     static const double safeDoubleMinToInt64 = -9223372036854774784.0;
1086     if (!std::isfinite(dval)) {
1087         return true;
1088     }
1089     if (totype == PTY_i64 || totype == PTY_u64) {
1090         if (dval < safeDoubleMinToInt64 || dval > safeDoubleMaxToInt64) {
1091             return true;
1092         }
1093     } else {
1094         if (dval < safeDoubleMinToInt32 || dval > safeDoubleMaxToInt32) {
1095             return true;
1096         }
1097     }
1098     return false;
1099 }
1100 
// Folds `ceil` of a floating-point constant (PTY_f32 reads MIRFloatConst, any other fromType reads MIRDoubleConst).
// Float targets receive the rounded value in the source's own float kind. Integer targets receive the rounded value
// converted to an integer constant of toType. Returns nullptr when that value overflows the target range.
ConstvalNode *ConstantFold::FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    // Converting a negative floating-point value directly to uint64 is undefined behavior. Go through int64 so
    // negative results become their two's-complement bit pattern; the overflow checks bound the magnitude.
    auto toIntBits = [](auto fpValue) -> uint64 {
        return fpValue < 0 ? static_cast<uint64>(static_cast<int64>(fpValue)) : static_cast<uint64>(fpValue);
    };
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(constValue);
        float floatValue = ceil(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
        } else if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        } else {
            resultConst->SetConstVal(
                GlobalTables::GetIntConstTable().GetOrCreateIntConst(toIntBits(floatValue), resultType));
        }
    } else {
        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(constValue);
        double doubleValue = ceil(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
        } else if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        } else {
            resultConst->SetConstVal(
                GlobalTables::GetIntConstTable().GetOrCreateIntConst(toIntBits(doubleValue), resultType));
        }
    }
    return resultConst;
}
1133 
1134 template <class T>
CalIntValueFromFloatValue(T value,const MIRType & resultType) const1135 T ConstantFold::CalIntValueFromFloatValue(T value, const MIRType &resultType) const
1136 {
1137     DEBUG_ASSERT(kByteSizeOfBit64 >= resultType.GetSize(), "unsupported type");
1138     size_t shiftNum = (kByteSizeOfBit64 - resultType.GetSize()) * kBitSizePerByte;
1139     bool isSigned = IsSignedInteger(resultType.GetPrimType());
1140     int64 max = (IntVal(std::numeric_limits<int64>::max(), PTY_i64) >> shiftNum).GetExtValue();
1141     uint64 umax = std::numeric_limits<uint64>::max() >> shiftNum;
1142     int64 min = isSigned ? (IntVal(std::numeric_limits<int64>::min(), PTY_i64) >> shiftNum).GetExtValue() : 0;
1143     if (isSigned && (value > max)) {
1144         return static_cast<T>(max);
1145     } else if (!isSigned && (value > umax)) {
1146         return static_cast<T>(umax);
1147     } else if (value < min) {
1148         return static_cast<T>(min);
1149     }
1150     return value;
1151 }
1152 
// Converts a floating-point constant (PTY_f32 reads MIRFloatConst, any other fromType reads MIRDoubleConst) to toType.
// With isFloor set, the value is first rounded toward negative infinity; otherwise the integer conversion truncates.
// Float targets receive the value in the source's own float kind. Integer targets are checked for overflow
// (nullptr is returned) and then clamped with CalIntValueFromFloatValue.
MIRConst *ConstantFold::FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor) const
{
    // float -> uint64 is undefined for negative values. Route those through int64 (in range after the
    // overflow check) so they become their two's-complement bit pattern, as the double path below already does.
    auto toIntBits = [](float fpValue) -> uint64 {
        return fpValue < 0 ? static_cast<uint64>(static_cast<int64>(fpValue)) : static_cast<uint64>(fpValue);
    };
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const auto &constValue = static_cast<const MIRFloatConst&>(cst);
        float floatValue = constValue.GetValue();
        if (isFloor) {
            floatValue = floor(constValue.GetValue());
        }
        if (IsPrimitiveFloat(toType)) {
            return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
        }
        if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        }
        floatValue = CalIntValueFromFloatValue(floatValue, resultType);
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(toIntBits(floatValue), resultType);
    } else {
        const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
        double doubleValue = constValue.GetValue();
        if (isFloor) {
            doubleValue = floor(constValue.GetValue());
        }
        if (IsPrimitiveFloat(toType)) {
            return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
        }
        if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        }
        doubleValue = CalIntValueFromFloatValue(doubleValue, resultType);
        // gcc/clang have bugs convert double to unsigned long, must convert to signed long first;
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(doubleValue), resultType);
    }
}
1187 
// Folds `floor` of a floating-point constant node into a ConstvalNode of type toType.
// Returns nullptr when FoldFloorMIRConst cannot produce a constant (e.g. integer overflow), matching
// FoldCeil/FoldTrunc/FoldTypeCvt, rather than returning a ConstvalNode that holds a null constant.
ConstvalNode *ConstantFold::FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *folded = FoldFloorMIRConst(*cst.GetConstVal(), fromType, toType);
    if (folded == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(folded);
    return resultConst;
}
1195 
FoldRoundMIRConst(const MIRConst & cst,PrimType fromType,PrimType toType) const1196 MIRConst *ConstantFold::FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
1197 {
1198     if (fromType == PTY_f128 || toType == PTY_f128) {
1199         // folding while rounding float128 is not supported yet
1200         return nullptr;
1201     }
1202 
1203     MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
1204     if (fromType == PTY_f32) {
1205         const auto &constValue = static_cast<const MIRFloatConst&>(cst);
1206         float floatValue = round(constValue.GetValue());
1207         if (FloatToIntOverflow(floatValue, toType)) {
1208             return nullptr;
1209         }
1210         return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(floatValue), resultType);
1211     } else if (fromType == PTY_f64) {
1212         const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
1213         double doubleValue = round(constValue.GetValue());
1214         if (DoubleToIntOverflow(doubleValue, toType)) {
1215             return nullptr;
1216         }
1217         return GlobalTables::GetIntConstTable().GetOrCreateIntConst(
1218             static_cast<uint64>(static_cast<int64>(doubleValue)), resultType);
1219     } else if (toType == PTY_f32 && IsPrimitiveInteger(fromType)) {
1220         const auto &constValue = static_cast<const MIRIntConst&>(cst);
1221         if (IsSignedInteger(fromType)) {
1222             int64 fromValue = constValue.GetExtValue();
1223             float floatValue = round(static_cast<float>(fromValue));
1224             if (static_cast<int64>(floatValue) == fromValue) {
1225                 return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
1226             }
1227         } else {
1228             uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
1229             float floatValue = round(static_cast<float>(fromValue));
1230             if (static_cast<uint64>(floatValue) == fromValue) {
1231                 return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
1232             }
1233         }
1234     } else if (toType == PTY_f64 && IsPrimitiveInteger(fromType)) {
1235         const auto &constValue = static_cast<const MIRIntConst&>(cst);
1236         if (IsSignedInteger(fromType)) {
1237             int64 fromValue = constValue.GetExtValue();
1238             double doubleValue = round(static_cast<double>(fromValue));
1239             if (static_cast<int64>(doubleValue) == fromValue) {
1240                 return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
1241             }
1242         } else {
1243             uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
1244             double doubleValue = round(static_cast<double>(fromValue));
1245             if (static_cast<uint64>(doubleValue) == fromValue) {
1246                 return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
1247             }
1248         }
1249     }
1250     return nullptr;
1251 }
1252 
// Folds `round` of a constant node into a ConstvalNode of type toType.
// FoldRoundMIRConst returns nullptr in several cases (f128, overflow, inexact int->float). This function then
// returns nullptr ("not folded"), as FoldCeil/FoldTrunc/FoldTypeCvt do, instead of a node with a null constant.
ConstvalNode *ConstantFold::FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *folded = FoldRoundMIRConst(*cst.GetConstVal(), fromType, toType);
    if (folded == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(folded);
    return resultConst;
}
1260 
// Folds `trunc` of a floating-point constant (PTY_f32 reads MIRFloatConst, any other fromType reads
// MIRDoubleConst). Float targets receive the truncated value in the source's own float kind. Integer targets
// receive the truncated value as an integer constant of toType. Returns nullptr when that value overflows.
ConstvalNode *ConstantFold::FoldTrunc(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    // Converting a negative floating-point value directly to uint64 is undefined behavior (e.g. trunc(-1.5) ==
    // -1.0). Route negative values through int64, whose range the overflow checks guarantee.
    auto toIntBits = [](auto fpValue) -> uint64 {
        return fpValue < 0 ? static_cast<uint64>(static_cast<int64>(fpValue)) : static_cast<uint64>(fpValue);
    };
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(constValue);
        float floatValue = trunc(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
        } else if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        } else {
            resultConst->SetConstVal(
                GlobalTables::GetIntConstTable().GetOrCreateIntConst(toIntBits(floatValue), resultType));
        }
    } else {
        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(constValue);
        double doubleValue = trunc(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
        } else if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        } else {
            resultConst->SetConstVal(
                GlobalTables::GetIntConstTable().GetOrCreateIntConst(toIntBits(doubleValue), resultType));
        }
    }
    return resultConst;
}
1293 
FoldTypeCvtMIRConst(const MIRConst & cst,PrimType fromType,PrimType toType) const1294 MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
1295 {
1296     if (IsPrimitiveDynType(fromType) || IsPrimitiveDynType(toType) ||
1297         IsPrimitiveVector(fromType) || IsPrimitiveVector(toType)) {
1298         // do not fold
1299         return nullptr;
1300     }
1301     if (fromType == PTY_f128 || toType == PTY_f128) {
1302         // folding while Cvt float128 is not supported yet
1303         return nullptr;
1304     }
1305 
1306     if (IsPrimitiveInteger(fromType) && IsPrimitiveInteger(toType)) {
1307         MIRConst *toConst = nullptr;
1308         uint32 fromSize = GetPrimTypeBitSize(fromType);
1309         uint32 toSize = GetPrimTypeBitSize(toType);
1310         // GetPrimTypeBitSize(PTY_u1) will return 8, which is not expected here.
1311         if (fromType == PTY_u1) {
1312             fromSize = 1;
1313         }
1314         if (toType == PTY_u1) {
1315             toSize = 1;
1316         }
1317         if (toSize > fromSize) {
1318             Opcode op = OP_zext;
1319             if (IsSignedInteger(fromType)) {
1320                 op = OP_sext;
1321             }
1322             const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
1323             ASSERT_NOT_NULL(constVal);
1324             toConst = FoldSignExtendMIRConst(op, toType, static_cast<uint8>(fromSize),
1325                 constVal->GetValue().TruncOrExtend(fromType));
1326         } else {
1327             const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
1328             ASSERT_NOT_NULL(constVal);
1329             MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(toType);
1330             toConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(
1331                 static_cast<uint64>(constVal->GetExtValue()), type);
1332         }
1333         return toConst;
1334     }
1335     if (IsPrimitiveFloat(fromType) && IsPrimitiveFloat(toType)) {
1336         MIRConst *toConst = nullptr;
1337         if (GetPrimTypeBitSize(toType) < GetPrimTypeBitSize(fromType)) {
1338             DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 32, "We suppot F32 and F64"); // just support 32 or 64
1339             const MIRDoubleConst *fromValue = safe_cast<MIRDoubleConst>(cst);
1340             ASSERT_NOT_NULL(fromValue);
1341             float floatValue = static_cast<float>(fromValue->GetValue());
1342             MIRFloatConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
1343             toConst = toValue;
1344         } else {
1345             DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 64, "We suppot F32 and F64"); // just support 32 or 64
1346             const MIRFloatConst *fromValue = safe_cast<MIRFloatConst>(cst);
1347             ASSERT_NOT_NULL(fromValue);
1348             double doubleValue = static_cast<double>(fromValue->GetValue());
1349             MIRDoubleConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
1350             toConst = toValue;
1351         }
1352         return toConst;
1353     }
1354     if (IsPrimitiveFloat(fromType) && IsPrimitiveInteger(toType)) {
1355         return FoldFloorMIRConst(cst, fromType, toType, false);
1356     }
1357     if (IsPrimitiveInteger(fromType) && IsPrimitiveFloat(toType)) {
1358         return FoldRoundMIRConst(cst, fromType, toType);
1359     }
1360     CHECK_FATAL(false, "Unexpected case in ConstFoldTypeCvt");
1361     return nullptr;
1362 }
1363 
// Folds a `cvt` of a constant node. Returns nullptr when FoldTypeCvtMIRConst declines to fold.
// On success, the new node's primitive type is taken from the folded constant itself.
ConstvalNode *ConstantFold::FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *converted = FoldTypeCvtMIRConst(*cst.GetConstVal(), fromType, toType);
    if (converted != nullptr) {
        auto *node = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
        node->SetPrimType(converted->GetType().GetPrimType());
        node->SetConstVal(converted);
        return node;
    }
    return nullptr;
}
1375 
1376 // return a primType with bit size >= bitSize (and the nearest one),
1377 // and its signed/float type is the same as ptyp
GetNearestSizePtyp(uint8 bitSize,PrimType ptyp)1378 PrimType GetNearestSizePtyp(uint8 bitSize, PrimType ptyp)
1379 {
1380     bool isSigned = IsSignedInteger(ptyp);
1381     bool isFloat = IsPrimitiveFloat(ptyp);
1382     if (bitSize == 1) { // 1 bit
1383         return PTY_u1;
1384     }
1385     if (bitSize <= 8) { // 8 bit
1386         return isSigned ? PTY_i8 : PTY_u8;
1387     }
1388     if (bitSize <= 16) { // 16 bit
1389         return isSigned ? PTY_i16 : PTY_u16;
1390     }
1391     if (bitSize <= 32) { // 32 bit
1392         return isFloat ? PTY_f32 : (isSigned ? PTY_i32 : PTY_u32);
1393     }
1394     if (bitSize <= 64) { // 64 bit
1395         return isFloat ? PTY_f64 : (isSigned ? PTY_i64 : PTY_u64);
1396     }
1397     if (bitSize <= 128) { // 128 bit
1398         return isFloat ? PTY_f128 : (isSigned ? PTY_i128 : PTY_u128);
1399     }
1400     return ptyp;
1401 }
1402 
GetIntPrimTypeMax(PrimType ptyp)1403 size_t GetIntPrimTypeMax(PrimType ptyp)
1404 {
1405     switch (ptyp) {
1406         case PTY_u1:
1407             return 1;
1408         case PTY_u8:
1409             return UINT8_MAX;
1410         case PTY_i8:
1411             return INT8_MAX;
1412         case PTY_u16:
1413             return UINT16_MAX;
1414         case PTY_i16:
1415             return INT16_MAX;
1416         case PTY_u32:
1417             return UINT32_MAX;
1418         case PTY_i32:
1419             return INT32_MAX;
1420         case PTY_u64:
1421             return UINT64_MAX;
1422         case PTY_i64:
1423             return INT64_MAX;
1424         default:
1425             CHECK_FATAL(false, "NYI");
1426     }
1427 }
1428 
GetIntPrimTypeMin(PrimType ptyp)1429 ssize_t GetIntPrimTypeMin(PrimType ptyp)
1430 {
1431     if (IsUnsignedInteger(ptyp)) {
1432         return 0;
1433     }
1434     switch (ptyp) {
1435         case PTY_i8:
1436             return INT8_MIN;
1437         case PTY_i16:
1438             return INT16_MIN;
1439         case PTY_i32:
1440             return INT32_MIN;
1441         case PTY_i64:
1442             return INT64_MIN;
1443         default:
1444             CHECK_FATAL(false, "NYI");
1445     }
1446 }
1447 
1448 // return a primtype to represent value range of expr
GetExprValueRangePtyp(BaseNode * expr)1449 PrimType GetExprValueRangePtyp(BaseNode *expr)
1450 {
1451     PrimType ptyp = expr->GetPrimType();
1452     Opcode op = expr->GetOpCode();
1453     if (expr->IsLeaf()) {
1454         return ptyp;
1455     }
1456     if (kOpcodeInfo.IsTypeCvt(op)) {
1457         auto *node = static_cast<TypeCvtNode *>(expr);
1458         if (GetPrimTypeSize(node->FromType()) < GetPrimTypeSize(node->GetPrimType())) {
1459             return GetExprValueRangePtyp(expr->Opnd(0));
1460         }
1461         return ptyp;
1462     }
1463     if (op == OP_sext || op == OP_zext || op == OP_extractbits) {
1464         auto *node = static_cast<ExtractbitsNode *>(expr);
1465         uint8 size = node->GetBitsSize();
1466         return GetNearestSizePtyp(size, expr->GetPrimType());
1467     }
1468     // find max size primtype of opnds.
1469     size_t maxTypeSize = 1;
1470     size_t ptypSize = GetPrimTypeSize(ptyp);
1471     for (size_t i = 0; i < expr->GetNumOpnds(); ++i) {
1472         PrimType opndPtyp = GetExprValueRangePtyp(expr->Opnd(i));
1473         size_t opndSize = GetPrimTypeSize(opndPtyp);
1474         if (ptypSize <= opndSize) {
1475             return ptyp;
1476         }
1477         if (maxTypeSize < opndSize) {
1478             maxTypeSize = opndSize;
1479             constexpr size_t intMaxSize = 8;
1480             if (maxTypeSize == intMaxSize) {
1481                 break;
1482             }
1483         }
1484     }
1485     return GetNearestSizePtyp(static_cast<uint8>(maxTypeSize), ptyp);
1486 }
1487 
IsCvtEliminatable(PrimType fromPtyp,PrimType destPtyp,Opcode op,Opcode opndOp)1488 static bool IsCvtEliminatable(PrimType fromPtyp, PrimType destPtyp, Opcode op, Opcode opndOp)
1489 {
1490     if (op != OP_cvt || (opndOp == OP_zext || opndOp == OP_sext)) {
1491         return false;
1492     }
1493     if (GetPrimTypeSize(fromPtyp) != GetPrimTypeSize(destPtyp)) {
1494         return false;
1495     }
1496     return (IsPossible64BitAddress(fromPtyp) && IsPossible64BitAddress(destPtyp)) ||
1497         (IsPossible32BitAddress(fromPtyp) && IsPossible32BitAddress(destPtyp)) ||
1498         (IsPrimitivePureScalar(fromPtyp) && IsPrimitivePureScalar(destPtyp));
1499 }
1500 
FoldTypeCvt(TypeCvtNode * node)1501 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldTypeCvt(TypeCvtNode *node)
1502 {
1503     CHECK_NULL_FATAL(node);
1504     BaseNode *result = nullptr;
1505     if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
1506         return {node, std::nullopt};
1507     }
1508     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1509     ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1510     PrimType destPtyp = node->GetPrimType();
1511     PrimType fromPtyp = node->FromType();
1512     if (cst != nullptr) {
1513         switch (node->GetOpCode()) {
1514             case OP_ceil: {
1515                 result = FoldCeil(*cst, fromPtyp, destPtyp);
1516                 break;
1517             }
1518             case OP_cvt: {
1519                 result = FoldTypeCvt(*cst, fromPtyp, destPtyp);
1520                 break;
1521             }
1522             case OP_floor: {
1523                 result = FoldFloor(*cst, fromPtyp, destPtyp);
1524                 break;
1525             }
1526             case OP_round: {
1527                 result = FoldRound(*cst, fromPtyp, destPtyp);
1528                 break;
1529             }
1530             case OP_trunc: {
1531                 result = FoldTrunc(*cst, fromPtyp, destPtyp);
1532                 break;
1533             }
1534             default:
1535                 DEBUG_ASSERT(false, "Unexpected opcode in TypeCvtNodeConstFold");
1536                 break;
1537         }
1538     } else if (IsCvtEliminatable(fromPtyp, destPtyp, node->GetOpCode(), p.first->GetOpCode())) {
1539         // the cvt is redundant
1540         return std::make_pair(p.first, p.second ? IntVal(*p.second, node->GetPrimType()) : p.second);
1541     }
1542     if (result == nullptr) {
1543         BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1544         if (e != node->Opnd(0)) {
1545             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(
1546                 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->FromType()), e);
1547         } else {
1548             result = node;
1549         }
1550     }
1551     return std::make_pair(result, std::nullopt);
1552 }
1553 
FoldSignExtendMIRConst(Opcode opcode,PrimType resultType,uint8 size,const IntVal & val) const1554 MIRConst *ConstantFold::FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const
1555 {
1556     uint64 result = opcode == OP_sext ? static_cast<uint64>(val.GetSXTValue(size)) : val.GetZXTValue(size);
1557     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
1558     MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
1559     return constValue;
1560 }
1561 
FoldSignExtend(Opcode opcode,PrimType resultType,uint8 size,const ConstvalNode & cst) const1562 ConstvalNode *ConstantFold::FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size,
1563                                            const ConstvalNode &cst) const
1564 {
1565     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1566     const auto *intCst = safe_cast<MIRIntConst>(cst.GetConstVal());
1567     ASSERT_NOT_NULL(intCst);
1568     IntVal val = intCst->GetValue().TruncOrExtend(size, opcode == OP_sext);
1569     MIRConst *toConst = FoldSignExtendMIRConst(opcode, resultType, size, val);
1570     resultConst->SetPrimType(toConst->GetType().GetPrimType());
1571     resultConst->SetConstVal(toConst);
1572     return resultConst;
1573 }
1574 
1575 // check if truncation is redundant due to dread or iread having same effect
ExtractbitsRedundant(const ExtractbitsNode & x,MIRFunction & f)1576 static bool ExtractbitsRedundant(const ExtractbitsNode &x, MIRFunction &f)
1577 {
1578     if (GetPrimTypeSize(x.GetPrimType()) == k8ByteSize) {
1579         return false;  // this is trying to be conservative
1580     }
1581     BaseNode *opnd = x.Opnd(0);
1582     MIRType *mirType = nullptr;
1583     if (opnd->GetOpCode() == OP_dread) {
1584         DreadNode *dread = static_cast<DreadNode*>(opnd);
1585         MIRSymbol *sym = f.GetLocalOrGlobalSymbol(dread->GetStIdx());
1586         ASSERT_NOT_NULL(sym);
1587         mirType = sym->GetType();
1588         if (dread->GetFieldID() != 0) {
1589             MIRStructType *structType = dynamic_cast<MIRStructType*>(mirType);
1590             if (structType == nullptr) {
1591                 return false;
1592             }
1593             mirType = structType->GetFieldType(dread->GetFieldID());
1594         }
1595     } else if (opnd->GetOpCode() == OP_iread) {
1596         IreadNode *iread = static_cast<IreadNode*>(opnd);
1597         MIRPtrType *ptrType =
1598             dynamic_cast<MIRPtrType*>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx()));
1599         if (ptrType == nullptr) {
1600             return false;
1601         }
1602         mirType = ptrType->GetPointedType();
1603         if (iread->GetFieldID() != 0) {
1604             MIRStructType *structType = dynamic_cast<MIRStructType*>(mirType);
1605             if (structType == nullptr) {
1606                 return false;
1607             }
1608             mirType = structType->GetFieldType(iread->GetFieldID());
1609         }
1610     } else if (opnd->GetOpCode() == OP_extractbits &&
1611                 x.GetBitsSize() > static_cast<ExtractbitsNode*>(opnd)->GetBitsSize()) {
1612         return (x.GetOpCode() == OP_zext && x.GetPrimType() == opnd->GetPrimType() &&
1613             IsUnsignedInteger(opnd->GetPrimType()));
1614     } else {
1615         return false;
1616     }
1617     return IsPrimitiveInteger(mirType->GetPrimType()) &&
1618             ((x.GetOpCode() == OP_zext && IsUnsignedInteger(opnd->GetPrimType())) ||
1619             (x.GetOpCode() == OP_sext && IsSignedInteger(opnd->GetPrimType()))) &&
1620             mirType->GetSize() * kBitSizePerByte == x.GetBitsSize() &&
1621             mirType->GetPrimType() == x.GetPrimType();
1622 }
1623 
1624 // sext and zext also handled automatically
FoldExtractbits(ExtractbitsNode * node)1625 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldExtractbits(ExtractbitsNode *node)
1626 {
1627     CHECK_NULL_FATAL(node);
1628     BaseNode *result = nullptr;
1629     uint8 offset = node->GetBitsOffset();
1630     uint8 size = node->GetBitsSize();
1631     Opcode opcode = node->GetOpCode();
1632     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1633     ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1634     if (cst != nullptr && (opcode == OP_sext || opcode == OP_zext)) {
1635         result = FoldSignExtend(opcode, node->GetPrimType(), size, *cst);
1636         return std::make_pair(result, std::nullopt);
1637     }
1638     BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1639     if (e != node->Opnd(0)) {
1640         result = mirModule->CurFuncCodeMemPool()->New<ExtractbitsNode>(opcode, PrimType(node->GetPrimType()), offset,
1641                                                                        size, e);
1642     } else {
1643         result = node;
1644     }
1645     // check for consecutive and redundant extraction of same bits
1646     BaseNode *opnd = result->Opnd(0);
1647     DEBUG_ASSERT(opnd != nullptr, "opnd shoule not be null");
1648     Opcode opndOp = opnd->GetOpCode();
1649     if (opndOp == OP_extractbits || opndOp == OP_sext || opndOp == OP_zext) {
1650         uint8 opndOffset = static_cast<ExtractbitsNode*>(opnd)->GetBitsOffset();
1651         uint8 opndSize = static_cast<ExtractbitsNode*>(opnd)->GetBitsSize();
1652         if (offset == opndOffset && size == opndSize) {
1653             result->SetOpnd(opnd->Opnd(0), 0);  // delete the redundant extraction
1654         }
1655     }
1656     if (offset == 0 && size >= k8ByteSize && IsPowerOf2(size)) {
1657         if (ExtractbitsRedundant(*static_cast<ExtractbitsNode*>(result), *mirModule->CurFunction())) {
1658             return std::make_pair(result->Opnd(0), std::nullopt);
1659         }
1660     }
1661     return std::make_pair(result, std::nullopt);
1662 }
1663 
FoldIread(IreadNode * node)1664 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldIread(IreadNode *node)
1665 {
1666     CHECK_NULL_FATAL(node);
1667     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1668     BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1669     node->SetOpnd(e, 0);
1670     BaseNode *result = node;
1671     if (e->GetOpCode() != OP_addrof) {
1672         return std::make_pair(result, std::nullopt);
1673     }
1674 
1675     AddrofNode *addrofNode = static_cast<AddrofNode*>(e);
1676     MIRSymbol *msy = mirModule->CurFunction()->GetLocalOrGlobalSymbol(addrofNode->GetStIdx());
1677     DEBUG_ASSERT(msy != nullptr, "nullptr check");
1678     TyIdx typeId = msy->GetTyIdx();
1679     CHECK_FATAL(!GlobalTables::GetTypeTable().GetTypeTable().empty(), "container check");
1680     MIRType *msyType = GlobalTables::GetTypeTable().GetTypeTable()[typeId];
1681     if (addrofNode->GetFieldID() != 0) {
1682         CHECK_FATAL(msyType->IsStructType(), "must be");
1683         msyType = static_cast<MIRStructType*>(msyType)->GetFieldType(addrofNode->GetFieldID());
1684     }
1685     MIRPtrType *ptrType = static_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx()));
1686     // If the high level type of iaddrof/iread doesn't match
1687     // the type of addrof's rhs, this optimization cannot be done.
1688     if (ptrType->GetPointedType() != msyType) {
1689         return std::make_pair(result, std::nullopt);
1690     }
1691 
1692     Opcode op = node->GetOpCode();
1693     FieldID fieldID = node->GetFieldID();
1694     if (op == OP_iaddrof) {
1695         AddrofNode *newAddrof = addrofNode->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
1696         CHECK_NULL_FATAL(newAddrof);
1697         newAddrof->SetFieldID(newAddrof->GetFieldID() + fieldID);
1698         result = newAddrof;
1699     } else if (op == OP_iread) {
1700         result = mirModule->CurFuncCodeMemPool()->New<AddrofNode>(OP_dread, node->GetPrimType(), addrofNode->GetStIdx(),
1701                                                                   node->GetFieldID() + addrofNode->GetFieldID());
1702     }
1703     return std::make_pair(result, std::nullopt);
1704 }
1705 
IntegerOpIsOverflow(Opcode op,PrimType primType,int64 cstA,int64 cstB)1706 bool ConstantFold::IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB)
1707 {
1708     switch (op) {
1709         case OP_add: {
1710             int64 res = static_cast<int64>(static_cast<uint64>(cstA) + static_cast<uint64>(cstB));
1711             if (IsUnsignedInteger(primType)) {
1712                 return static_cast<uint64>(res) < static_cast<uint64>(cstA);
1713             }
1714             auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1715             return (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1716                     static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag) &&
1717                    (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1718                     static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag);
1719         }
1720         case OP_sub: {
1721             if (IsUnsignedInteger(primType)) {
1722                 return cstA < cstB;
1723             }
1724             int64 res = static_cast<int64>(static_cast<uint64>(cstA) - static_cast<uint64>(cstB));
1725             auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1726             return (static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag !=
1727                     static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag) &&
1728                    (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1729                     static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag);
1730         }
1731         default: {
1732             return false;
1733         }
1734     }
1735 }
1736 
FoldBinary(BinaryNode * node)1737 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldBinary(BinaryNode *node)
1738 {
1739     CHECK_NULL_FATAL(node);
1740     BaseNode *result = nullptr;
1741     std::optional<IntVal> sum = std::nullopt;
1742     Opcode op = node->GetOpCode();
1743     PrimType primType = node->GetPrimType();
1744     PrimType lPrimTypes = node->Opnd(0)->GetPrimType();
1745     PrimType rPrimTypes = node->Opnd(1)->GetPrimType();
1746     if (lPrimTypes == PTY_f128 || rPrimTypes == PTY_f128 || node->GetPrimType() == PTY_f128) {
1747         // folding of non-unary float128 is not supported yet
1748         return std::make_pair(static_cast<BaseNode*>(node), std::nullopt);
1749     }
1750     std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1751     std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1752     BaseNode *l = lp.first;
1753     BaseNode *r = rp.first;
1754     ASSERT_NOT_NULL(r);
1755     ConstvalNode *lConst = safe_cast<ConstvalNode>(l);
1756     ConstvalNode *rConst = safe_cast<ConstvalNode>(r);
1757     bool isInt = IsPrimitiveInteger(primType);
1758 
1759     if (lConst != nullptr && rConst != nullptr) {
1760         MIRConst *lConstVal = lConst->GetConstVal();
1761         MIRConst *rConstVal = rConst->GetConstVal();
1762         ASSERT_NOT_NULL(lConstVal);
1763         ASSERT_NOT_NULL(rConstVal);
1764         // Don't fold div by 0, for floats div by 0 is well defined.
1765         if ((op == OP_div || op == OP_rem) && isInt &&
1766             !IsDivSafe(static_cast<MIRIntConst &>(*lConstVal), static_cast<MIRIntConst &>(*rConstVal), primType)) {
1767             result = NewBinaryNode(node, op, primType, lConst, rConst);
1768         } else {
1769             // 4 + 2 -> return a pair(result = ConstValNode(6), sum = 0)
1770             // Create a new ConstvalNode for 6 but keep the sum = 0. This simplify the
1771             // logic since the alternative is to return pair(result = nullptr, sum = 6).
1772             // Doing so would introduce many nullptr checks in the code. See previous
1773             // commits that implemented that logic for a comparison.
1774             result = FoldConstBinary(op, primType, *lConst, *rConst);
1775         }
1776     } else if (lConst != nullptr && isInt) {
1777         MIRIntConst *mcst = safe_cast<MIRIntConst>(lConst->GetConstVal());
1778         ASSERT_NOT_NULL(mcst);
1779         PrimType cstTyp = mcst->GetType().GetPrimType();
1780         IntVal cst = mcst->GetValue();
1781         if (op == OP_add) {
1782             if (IsSignedInteger(cstTyp) && rp.second &&
1783                 IntegerOpIsOverflow(OP_add, cstTyp, cst.GetExtValue(), rp.second->GetExtValue())) {
1784                 // do not introduce signed integer overflow
1785                 result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1786             } else {
1787                 sum = cst + rp.second;
1788                 result = r;
1789             }
1790         } else if (op == OP_sub && r->GetPrimType() != PTY_u1) {
1791             // We exclude u1 type for fixing the following wrong example:
1792             // before cf:
1793             //   sub i32 (constval i32 17, eq u1 i32 (dread i32 %i, constval i32 16)))
1794             // after cf:
1795             //   add i32 (cvt i32 u1 (neg u1 (eq u1 i32 (dread i32 %i, constval i32 16))), constval i32 17))
1796             sum = cst - rp.second;
1797             if (GetPrimTypeSize(r->GetPrimType()) < GetPrimTypeSize(primType)) {
1798                 r = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, r->GetPrimType(), r);
1799             }
1800             result = NegateTree(r);
1801         } else if ((op == OP_mul || op == OP_div || op == OP_rem || op == OP_ashr || op == OP_lshr || op == OP_shl ||
1802                     op == OP_band || op == OP_cand || op == OP_land) &&
1803                     cst == 0) {
1804             // 0 * X -> 0
1805             // 0 / X -> 0
1806             // 0 % X -> 0
1807             // 0 >> X -> 0
1808             // 0 << X -> 0
1809             // 0 & X -> 0
1810             // 0 && X -> 0
1811             result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1812         } else if (op == OP_mul && cst == 1) {
1813             // 1 * X --> X
1814             sum = rp.second;
1815             result = r;
1816         } else if (op == OP_bior && cst == -1) {
1817             // (-1) | X -> -1
1818             result = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<uint64>(-1), cstTyp);
1819         } else if (op == OP_mul && rp.second.has_value() && *rp.second != 0) {
1820             // lConst * (X + konst) -> the pair [(lConst*X), (lConst*konst)]
1821             sum = cst * rp.second;
1822             if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) {
1823                 rp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, rp.first);
1824             }
1825             result = NewBinaryNode(node, OP_mul, primType, lConst, rp.first);
1826         } else if (op == OP_lior || op == OP_cior) {
1827             if (cst != 0) {
1828                 // 5 || X -> 1
1829                 result = mirModule->GetMIRBuilder()->CreateIntConst(1, cstTyp);
1830             } else {
1831                 // when cst is zero
1832                 // 0 || X -> (X != 0);
1833                 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1834                     OP_ne, primType, r->GetPrimType(), r,
1835                     mirModule->GetMIRBuilder()->CreateIntConst(0, r->GetPrimType()));
1836             }
1837         } else if ((op == OP_cand || op == OP_land) && cst != 0) {
1838             // 5 && X -> (X != 0)
1839             result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1840                 OP_ne, primType, r->GetPrimType(), r, mirModule->GetMIRBuilder()->CreateIntConst(0, r->GetPrimType()));
1841         } else if ((op == OP_bior || op == OP_bxor) && cst == 0) {
1842             // 0 | X -> X
1843             // 0 ^ X -> X
1844             sum = rp.second;
1845             result = r;
1846         } else {
1847             result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1848         }
1849         if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1850             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1851         }
1852     } else if (rConst != nullptr && isInt) {
1853         MIRIntConst *mcst = safe_cast<MIRIntConst>(rConst->GetConstVal());
1854         ASSERT_NOT_NULL(mcst);
1855         PrimType cstTyp = mcst->GetType().GetPrimType();
1856         IntVal cst = mcst->GetValue();
1857         if (op == OP_add) {
1858             if (lp.second && IntegerOpIsOverflow(op, cstTyp, lp.second->GetExtValue(), cst.GetExtValue())) {
1859                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1860             } else {
1861                 result = l;
1862                 sum = lp.second + cst;
1863             }
1864         } else if (op == OP_sub && (!cst.IsSigned() || !cst.IsMinValue())) {
1865             result = l;
1866             sum = lp.second - cst;
1867         } else if ((op == OP_mul || op == OP_band || op == OP_cand || op == OP_land) && cst == 0) {
1868             // X * 0 -> 0
1869             // X & 0 -> 0
1870             // X && 0 -> 0
1871             result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1872         } else if ((op == OP_mul || op == OP_div) && cst == 1) {
1873             // case [X * 1 -> X]
1874             // case [X / 1 = X]
1875             sum = lp.second;
1876             result = l;
1877         } else if (op == OP_div && !lp.second.has_value() && l->GetOpCode() == OP_mul &&
1878                 IsSignedInteger(primType) && IsSignedInteger(lPrimTypes) && IsSignedInteger(rPrimTypes)) {
1879             // temporary fix for constfold of mul/div in DejaGnu
1880             // Later we need a more formal interface for pattern match
1881             // X * Y / Y -> X
1882             BaseNode *x = l->Opnd(0);
1883             BaseNode *y = l->Opnd(1);
1884             ConstvalNode *xConst = safe_cast<ConstvalNode>(x);
1885             ConstvalNode *yConst = safe_cast<ConstvalNode>(y);
1886             bool foldMulDiv = false;
1887             if (yConst != nullptr && xConst == nullptr &&
1888                 IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1889                 MIRIntConst *yCst = safe_cast<MIRIntConst>(yConst->GetConstVal());
1890                 ASSERT_NOT_NULL(yCst);
1891                 IntVal mulCst = yCst->GetValue();
1892                 if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1893                     mulCst.GetExtValue() == cst.GetExtValue()) {
1894                     foldMulDiv = true;
1895                     result = x;
1896                 }
1897             } else if (xConst != nullptr && yConst == nullptr &&
1898                         IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1899                 MIRIntConst *xCst = safe_cast<MIRIntConst>(xConst->GetConstVal());
1900                 ASSERT_NOT_NULL(xCst);
1901                 IntVal mulCst = xCst->GetValue();
1902                 if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1903                     mulCst.GetExtValue() == cst.GetExtValue()) {
1904                     foldMulDiv = true;
1905                     result = y;
1906                 }
1907             }
1908             if (!foldMulDiv) {
1909                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1910             }
1911         } else if (op == OP_mul && lp.second.has_value() && *lp.second != 0 && lp.second->GetSXTValue() > -kMaxOffset) {
1912             // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)]
1913             sum = lp.second * cst;
1914             if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) {
1915                 lp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, lp.first);
1916             }
1917             if (lp.first->GetOpCode() == OP_neg && cst == -1) {
1918                 // special case: ((-X) + konst) * (-1) -> the pair [(X), -konst]
1919                 result = lp.first->Opnd(0);
1920             } else {
1921                 result = NewBinaryNode(node, OP_mul, primType, lp.first, rConst);
1922             }
1923         } else if (op == OP_band && cst == -1) {
1924             // X & (-1) -> X
1925             sum = lp.second;
1926             result = l;
1927         } else if (op == OP_band && ContiguousBitsOf1(cst.GetZXTValue()) &&
1928                    (!lp.second.has_value() || lp.second == 0)) {
1929             bool fold2extractbits = false;
1930             if (l->GetOpCode() == OP_ashr || l->GetOpCode() == OP_lshr) {
1931                 BinaryNode *shrNode = static_cast<BinaryNode *>(l);
1932                 if (shrNode->Opnd(1)->GetOpCode() == OP_constval) {
1933                     ConstvalNode *shrOpnd = static_cast<ConstvalNode *>(shrNode->Opnd(1));
1934                     int64 shrAmt = static_cast<MIRIntConst*>(shrOpnd->GetConstVal())->GetExtValue();
1935                     uint64 ucst = cst.GetZXTValue();
1936                     uint32 bsize = 0;
1937                     do {
1938                         bsize++;
1939                         ucst >>= 1;
1940                     } while (ucst != 0);
1941                     if (shrAmt + static_cast<int64>(bsize) <=
1942                         static_cast<int64>(GetPrimTypeSize(primType) * kBitSizePerByte) &&
1943                         static_cast<uint64>(shrAmt) < GetPrimTypeSize(primType) * kBitSizePerByte) {
1944                         fold2extractbits = true;
1945                         // change to use extractbits
1946                         result = mirModule->GetMIRBuilder()->CreateExprExtractbits(OP_extractbits,
1947                             GetUnsignedPrimType(primType), static_cast<uint32>(shrAmt), bsize, shrNode->Opnd(0));
1948                         sum = std::nullopt;
1949                     }
1950                 }
1951             }
1952             if (!fold2extractbits) {
1953                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1954                 sum = std::nullopt;
1955             }
1956         } else if (op == OP_bior && cst == -1) {
1957             // X | (-1) -> -1
1958             result = mirModule->GetMIRBuilder()->CreateIntConst(-1ULL, cstTyp);
1959         } else if ((op == OP_lior || op == OP_cior)) {
1960             if (cst == 0) {
1961                 // X || 0 -> X
1962                 sum = lp.second;
1963                 result = l;
1964             } else if (!cst.GetSignBit()) {
1965                 // X || 5 -> 1
1966                 result = mirModule->GetMIRBuilder()->CreateIntConst(1, cstTyp);
1967             } else {
1968                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1969             }
1970         } else if ((op == OP_ashr || op == OP_lshr || op == OP_shl || op == OP_bior || op == OP_bxor) && cst == 0) {
1971             // X >> 0 -> X
1972             // X << 0 -> X
1973             // X | 0 -> X
1974             // X ^ 0 -> X
1975             sum = lp.second;
1976             result = l;
1977         } else if (op == OP_bxor && cst == 1 && primType != PTY_u1) {
1978             // bxor i32 (
1979             //   cvt i32 u1 (regread u1 %13),
1980             //  constValue i32 1),
1981             result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1982             if (l->GetOpCode() == OP_cvt && (!lp.second || lp.second == 0)) {
1983                 TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(l);
1984                 if (cvtNode->Opnd(0)->GetPrimType() == PTY_u1) {
1985                     BaseNode *base = cvtNode->Opnd(0);
1986                     BaseNode *constValue = mirModule->GetMIRBuilder()->CreateIntConst(1, base->GetPrimType());
1987                     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(base);
1988                     BinaryNode *temp = NewBinaryNode(node, op, PTY_u1, PairToExpr(base->GetPrimType(), p), constValue);
1989                     result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_u1, temp);
1990                 }
1991             }
1992         } else if (op == OP_rem && cst == 1) {
1993             // X % 1 -> 0
1994             result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1995         } else {
1996             result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1997         }
1998         if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1999             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
2000         }
2001     } else if (isInt && (op == OP_add || op == OP_sub)) {
2002         if (op == OP_add) {
2003             result = NewBinaryNode(node, op, primType, l, r);
2004             sum = lp.second + rp.second;
2005         } else if (r != nullptr && node->Opnd(1)->GetOpCode() == OP_sub && r->GetOpCode() == OP_neg) {
2006             // if fold is (x - (y - z))    ->     (x - neg(z)) - y
2007             // (x - neg(z)) Could cross the int limit
2008             // return node
2009             result = node;
2010         } else {
2011             result = NewBinaryNode(node, op, primType, l, r);
2012             sum = lp.second - rp.second;
2013         }
2014     } else {
2015         result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
2016     }
2017     return std::make_pair(result, sum);
2018 }
2019 
SimplifyDoubleConstvalCompare(CompareNode & node,bool isRConstval,bool isGtOrLt) const2020 BaseNode *ConstantFold::SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt) const
2021 {
2022     if (isRConstval) {
2023         ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(1));
2024         if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
2025             const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(0));
2026             return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
2027                 node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
2028         }
2029     } else {
2030         ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(0));
2031         if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
2032             const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(1));
2033             if (isGtOrLt) {
2034                 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
2035                     node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(1), compNode->Opnd(0));
2036             } else {
2037                 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
2038                     node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
2039             }
2040         }
2041     }
2042     return &node;
2043 }
2044 
// Peephole simplification of a comparison whose operand is itself a comparison, a `cmp`, or a 1-bit value.
// Patterns handled (c0 = integer constant 0):
//   eq/ne (cmp a b, c0), eq/ne (c0, cmp a b)  ->  eq/ne (a, b)          [via SimplifyDoubleConstvalCompare]
//   gt/lt (cmp a b, c0)                       ->  gt/lt (a, b)
//   gt/lt (c0, cmp a b)                       ->  gt/lt (b, a)
//   ne (x, c0)                                ->  x, where x is the left operand, or the operand reached by
//                                                 peeling `cvt`s until a u1-typed node is found
//   eq (ne/eq a b, c0)                        ->  the inverted comparison eq/ne (a, b)
// Returns the replacement node, or `&compareNode` when no pattern applies.
BaseNode *ConstantFold::SimplifyDoubleCompare(CompareNode &compareNode) const
{
    // See arm manual B.cond(P2993) and FCMP(P1091)
    CompareNode *node = &compareNode;
    BaseNode *result = node;
    BaseNode *l = node->Opnd(0);
    BaseNode *r = node->Opnd(1);
    if (node->GetOpCode() == OP_ne || node->GetOpCode() == OP_eq) {
        if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
            (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
            // The second argument tells the helper which side holds the constant.
            result = SimplifyDoubleConstvalCompare(*node, (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval));
        } else if (node->GetOpCode() == OP_ne && r->GetOpCode() == OP_constval) {
            // ne (u1 x, constValue 0)  <==> x
            ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
            if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
                // Walk down through cvt nodes looking for a u1 value.
                BaseNode *opnd = l;
                do {
                    // NOTE(review): the eq/ne test below reads `l`, not `opnd`, so it can only succeed on the
                    // first iteration (when opnd == l). An eq/ne reached after peeling a cvt is not matched
                    // unless it is u1-typed. Confirm whether `opnd` was intended.
                    if (opnd->GetPrimType() == PTY_u1 || (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
                        result = opnd;
                        break;
                    } else if (opnd->GetOpCode() == OP_cvt) {
                        TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(opnd);
                        opnd = cvtNode->Opnd(0);
                    } else {
                        opnd = nullptr;  // neither u1 nor cvt: give up, result stays `node`
                    }
                } while (opnd != nullptr);
            }
        } else if (node->GetOpCode() == OP_eq && r->GetOpCode() == OP_constval) {
            // eq (ne a b, 0) -> eq a b ; eq (eq a b, 0) -> ne a b
            ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
            if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero() &&
                (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
                auto resOp = l->GetOpCode() == OP_ne ? OP_eq : OP_ne;
                result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
                    resOp, l->GetPrimType(), static_cast<CompareNode*>(l)->GetOpndType(), l->Opnd(0), l->Opnd(1));
            }
        }
    } else if (node->GetOpCode() == OP_gt || node->GetOpCode() == OP_lt) {
        if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
            (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
            // `true` asks the helper to swap operands when the constant is on the left.
            result = SimplifyDoubleConstvalCompare(*node,
                (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval), true);
        }
    }
    return result;
}
2091 
FoldCompare(CompareNode * node)2092 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldCompare(CompareNode *node)
2093 {
2094     CHECK_NULL_FATAL(node);
2095     BaseNode *result = nullptr;
2096     std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
2097     std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
2098     ConstvalNode *lConst = safe_cast<ConstvalNode>(lp.first);
2099     ConstvalNode *rConst = safe_cast<ConstvalNode>(rp.first);
2100     if (node->GetOpndType() == PTY_f128 || node->GetPrimType() == PTY_f128) {
2101         // folding of non-unary float128 is not supported yet
2102         return std::make_pair(static_cast<BaseNode*>(node), std::nullopt);
2103     }
2104     Opcode opcode = node->GetOpCode();
2105     if (lConst != nullptr && rConst != nullptr && !IsPrimitiveDynType(node->GetOpndType())) {
2106         result = FoldConstComparison(node->GetOpCode(), node->GetPrimType(), node->GetOpndType(), *lConst, *rConst);
2107     } else if (lConst != nullptr && rConst == nullptr && opcode != OP_cmp &&
2108                lConst->GetConstVal()->GetKind() == kConstInt) {
2109         BaseNode *l = lp.first;
2110         BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
2111         result = FoldConstComparisonReverse(opcode, node->GetPrimType(), node->GetOpndType(), *l, *r);
2112     } else {
2113         BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), lp);
2114         BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
2115         if (l != node->Opnd(0) || r != node->Opnd(1)) {
2116             result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
2117                 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->GetOpndType()), l, r);
2118         } else {
2119             result = node;
2120         }
2121         auto *compareNode = static_cast<CompareNode*>(result);
2122         CHECK_NULL_FATAL(compareNode);
2123         result = SimplifyDoubleCompare(*compareNode);
2124     }
2125     return std::make_pair(result, std::nullopt);
2126 }
2127 
Fold(BaseNode * node)2128 BaseNode *ConstantFold::Fold(BaseNode *node)
2129 {
2130     if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) {
2131         return nullptr;
2132     }
2133     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node);
2134     BaseNode *result = PairToExpr(node->GetPrimType(), p);
2135     if (result == node) {
2136         result = nullptr;
2137     }
2138     return result;
2139 }
2140 
FoldDepositbits(DepositbitsNode * node)2141 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldDepositbits(DepositbitsNode *node)
2142 {
2143     CHECK_NULL_FATAL(node);
2144     BaseNode *result = nullptr;
2145     uint8 bitsOffset = node->GetBitsOffset();
2146     uint8 bitsSize = node->GetBitsSize();
2147     std::pair<BaseNode*, std::optional<IntVal>> leftPair = DispatchFold(node->Opnd(0));
2148     std::pair<BaseNode*, std::optional<IntVal>> rightPair = DispatchFold(node->Opnd(1));
2149     ConstvalNode *leftConst = safe_cast<ConstvalNode>(leftPair.first);
2150     ConstvalNode *rightConst = safe_cast<ConstvalNode>(rightPair.first);
2151     if (leftConst != nullptr && rightConst != nullptr) {
2152         MIRIntConst *intConst0 = safe_cast<MIRIntConst>(leftConst->GetConstVal());
2153         MIRIntConst *intConst1 = safe_cast<MIRIntConst>(rightConst->GetConstVal());
2154         ASSERT_NOT_NULL(intConst0);
2155         ASSERT_NOT_NULL(intConst1);
2156         ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
2157         resultConst->SetPrimType(node->GetPrimType());
2158         MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(node->GetPrimType());
2159         MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
2160         uint64 op0ExtractVal = 0;
2161         uint64 op1ExtractVal = 0;
2162         uint64 mask0 = (1LLU << (bitsSize + bitsOffset)) - 1;
2163         uint64 mask1 = (1LLU << bitsOffset) - 1;
2164         uint64 op0Mask = ~(mask0 ^ mask1);
2165         op0ExtractVal = (static_cast<uint64>(intConst0->GetExtValue()) & op0Mask);
2166         op1ExtractVal = (static_cast<uint64>(intConst1->GetExtValue()) << bitsOffset) &
2167                         ((1ULL << (bitsSize + bitsOffset)) - 1);
2168         constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(
2169             (op0ExtractVal | op1ExtractVal), constValue->GetType());
2170         resultConst->SetConstVal(constValue);
2171         result = resultConst;
2172     } else {
2173         BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), leftPair);
2174         BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rightPair);
2175         if (l != node->Opnd(0) || r != node->Opnd(1)) {
2176             result = mirModule->CurFuncCodeMemPool()->New<DepositbitsNode>(
2177                 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), bitsOffset, bitsSize, l, r);
2178         } else {
2179             result = node;
2180         }
2181     }
2182     return std::make_pair(result, std::nullopt);
2183 }
2184 
FoldArray(ArrayNode * node)2185 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldArray(ArrayNode *node)
2186 {
2187     CHECK_NULL_FATAL(node);
2188     BaseNode *result = nullptr;
2189     size_t i = 0;
2190     bool isFolded = false;
2191     ArrayNode *arrNode = mirModule->CurFuncCodeMemPool()->New<ArrayNode>(*mirModule, PrimType(node->GetPrimType()),
2192                                                                          node->GetTyIdx(), node->GetBoundsCheck());
2193     for (i = 0; i < node->GetNopndSize(); i++) {
2194         std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->GetNopndAt(i));
2195         BaseNode *tmpNode = PairToExpr(node->GetNopndAt(i)->GetPrimType(), p);
2196         if (tmpNode != node->GetNopndAt(i)) {
2197             isFolded = true;
2198         }
2199         arrNode->GetNopnd().push_back(tmpNode);
2200         arrNode->SetNumOpnds(arrNode->GetNumOpnds() + 1);
2201     }
2202     if (isFolded) {
2203         result = arrNode;
2204     } else {
2205         result = node;
2206     }
2207     return std::make_pair(result, std::nullopt);
2208 }
2209 
FoldTernary(TernaryNode * node)2210 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldTernary(TernaryNode *node)
2211 {
2212     CHECK_NULL_FATAL(node);
2213     constexpr size_t kFirst = 0;
2214     constexpr size_t kSecond = 1;
2215     constexpr size_t kThird = 2;
2216     BaseNode *result = node;
2217     std::vector<PrimType> primTypes;
2218     std::vector<std::pair<BaseNode*, std::optional<IntVal>>> p;
2219     for (size_t i = 0; i < node->NumOpnds(); i++) {
2220         BaseNode *tempNopnd = node->Opnd(i);
2221         CHECK_NULL_FATAL(tempNopnd);
2222         primTypes.push_back(tempNopnd->GetPrimType());
2223         p.push_back(DispatchFold(tempNopnd));
2224     }
2225     if (node->GetOpCode() == OP_select) {
2226         ConstvalNode *const0 = safe_cast<ConstvalNode>(p[kFirst].first);
2227         if (const0 != nullptr) {
2228             MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0->GetConstVal());
2229             ASSERT_NOT_NULL(intConst0);
2230             // Selecting the first value if not 0, selecting the second value otherwise.
2231             if (!intConst0->IsZero()) {
2232                 result = PairToExpr(primTypes[kSecond], p[kSecond]);
2233             } else {
2234                 result = PairToExpr(primTypes[kThird], p[kThird]);
2235             }
2236         } else {
2237             ConstvalNode *const1 = safe_cast<ConstvalNode>(p[kSecond].first);
2238             ConstvalNode *const2 = safe_cast<ConstvalNode>(p[kThird].first);
2239             if (const1 != nullptr && const2 != nullptr) {
2240                 MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1->GetConstVal());
2241                 MIRIntConst *intConst2 = safe_cast<MIRIntConst>(const2->GetConstVal());
2242                 double dconst1 = 0.0;
2243                 double dconst2 = 0.0;
2244                 // for fpconst
2245                 if (intConst1 == nullptr || intConst2 == nullptr) {
2246                     PrimType ptyp = const1->GetPrimType();
2247                     if (ptyp == PTY_f64) {
2248                         MIRDoubleConst *dConst1 = safe_cast<MIRDoubleConst>(const1->GetConstVal());
2249                         dconst1 = dConst1->GetValue();
2250                         MIRDoubleConst *dConst2 = safe_cast<MIRDoubleConst>(const2->GetConstVal());
2251                         dconst2 = dConst2->GetValue();
2252                     } else if (ptyp == PTY_f32) {
2253                         MIRFloatConst *fConst1 = safe_cast<MIRFloatConst>(const1->GetConstVal());
2254                         dconst1 = static_cast<double>(fConst1->GetFloatValue());
2255                         MIRFloatConst *fConst2 = safe_cast<MIRFloatConst>(const2->GetConstVal());
2256                         dconst2 = static_cast<double>(fConst2->GetFloatValue());
2257                     }
2258                 } else {
2259                     dconst1 = static_cast<double>(intConst1->GetExtValue());
2260                     dconst2 = static_cast<double>(intConst2->GetExtValue());
2261                 }
2262                 PrimType foldedPrimType = primTypes[kSecond];
2263                 if (!IsPrimitiveInteger(foldedPrimType)) {
2264                     foldedPrimType = primTypes[kThird];
2265                 }
2266                 if (dconst1 == 1.0 && dconst2 == 0.0 && GetPrimTypeActualBitSize(primTypes[0]) == 1) {
2267                     if (IsPrimitiveInteger(foldedPrimType)) {
2268                         result = PairToExpr(foldedPrimType, p[0]);
2269                     } else {
2270                         result = PairToExpr(primTypes[0], p[0]);
2271                         result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, foldedPrimType, primTypes[0],
2272                                                                                    result);
2273                     }
2274                     return std::make_pair(result, std::nullopt);
2275                 }
2276                 if (dconst1 == 0.0 && dconst2 == 1.0 && GetPrimTypeActualBitSize(primTypes[0]) == 1) {
2277                     BaseNode *lnot = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
2278                         OP_eq, primTypes[0], primTypes[0], PairToExpr(primTypes[0], p[0]),
2279                         mirModule->GetMIRBuilder()->CreateIntConst(0, primTypes[0]));
2280                     std::pair<BaseNode*, std::optional<IntVal>> pairTemp = DispatchFold(lnot);
2281                     if (IsPrimitiveInteger(foldedPrimType)) {
2282                         result = PairToExpr(foldedPrimType, pairTemp);
2283                     } else {
2284                         result = PairToExpr(primTypes[0], pairTemp);
2285                         result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, foldedPrimType, primTypes[0],
2286                                                                                    result);
2287                     }
2288                     return std::make_pair(result, std::nullopt);
2289                 }
2290             }
2291         }
2292     }
2293     BaseNode *e0 = PairToExpr(primTypes[kFirst], p[kFirst]);
2294     BaseNode *e1 = PairToExpr(primTypes[kSecond], p[kSecond]);
2295     BaseNode *e2 = PairToExpr(primTypes[kThird], p[kThird]);  // count up to 3 for ternary node
2296     if (e0 != node->Opnd(kFirst) || e1 != node->Opnd(kSecond) || e2 != node->Opnd(kThird)) {
2297         result = mirModule->CurFuncCodeMemPool()->New<TernaryNode>(Opcode(node->GetOpCode()),
2298                                                                    PrimType(node->GetPrimType()), e0, e1, e2);
2299     }
2300     return std::make_pair(result, std::nullopt);
2301 }
2302 
SimplifyDassign(DassignNode * node)2303 StmtNode *ConstantFold::SimplifyDassign(DassignNode *node)
2304 {
2305     CHECK_NULL_FATAL(node);
2306     BaseNode *returnValue = nullptr;
2307     returnValue = Fold(node->GetRHS());
2308     if (returnValue != nullptr) {
2309         node->SetRHS(returnValue);
2310     }
2311     return node;
2312 }
2313 
SimplifyIassignWithAddrofBaseNode(IassignNode & node,const AddrofNode & base) const2314 StmtNode *ConstantFold::SimplifyIassignWithAddrofBaseNode(IassignNode &node, const AddrofNode &base) const
2315 {
2316     auto *mirTypeOfIass = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node.GetTyIdx());
2317     if (!mirTypeOfIass->IsMIRPtrType()) {
2318         return &node;
2319     }
2320     auto *iassPtType = static_cast<MIRPtrType*>(mirTypeOfIass);
2321 
2322     MIRSymbol *lhsSym = mirModule->CurFunction()->GetLocalOrGlobalSymbol(base.GetStIdx());
2323     TyIdx lhsTyIdx = lhsSym->GetTyIdx();
2324     if (base.GetFieldID() != 0) {
2325         auto *mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(lhsTyIdx);
2326         if (!mirType->IsStructType()) {
2327             return &node;
2328         }
2329         lhsTyIdx = static_cast<MIRStructType*>(mirType)->GetFieldType(base.GetFieldID())->GetTypeIndex();
2330     }
2331     if (iassPtType->GetPointedTyIdx() == lhsTyIdx) {
2332         DassignNode *dassignNode = mirModule->CurFuncCodeMemPool()->New<DassignNode>();
2333         dassignNode->SetStIdx(base.GetStIdx());
2334         dassignNode->SetRHS(node.GetRHS());
2335         dassignNode->SetFieldID(base.GetFieldID() + node.GetFieldID());
2336         // reuse stmtid to maintain stmtFreqs if profileUse is on
2337         dassignNode->SetStmtID(node.GetStmtID());
2338         return dassignNode;
2339     }
2340     return &node;
2341 }
2342 
SimplifyIassignWithIaddrofBaseNode(IassignNode & node,const IaddrofNode & base)2343 StmtNode *ConstantFold::SimplifyIassignWithIaddrofBaseNode(IassignNode &node, const IaddrofNode &base)
2344 {
2345     auto *mirTypeOfIass = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node.GetTyIdx());
2346     if (!mirTypeOfIass->IsMIRPtrType()) {
2347         return &node;
2348     }
2349     auto *iassPtType = static_cast<MIRPtrType*>(mirTypeOfIass);
2350 
2351     if (base.GetFieldID() == 0) {
2352         // this iaddrof is redundant
2353         node.SetAddrExpr(base.Opnd(0));
2354         return &node;
2355     }
2356 
2357     auto *mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(base.GetTyIdx());
2358     if (!mirType->IsMIRPtrType()) {
2359         return &node;
2360     }
2361     auto *iaddrofPtType = static_cast<MIRPtrType*>(mirType);
2362 
2363     MIRStructType *lhsStructTy =
2364         static_cast<MIRStructType*>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iaddrofPtType->GetPointedTyIdx()));
2365     TyIdx lhsTyIdx = lhsStructTy->GetFieldType(base.GetFieldID())->GetTypeIndex();
2366     if (iassPtType->GetPointedTyIdx() == lhsTyIdx) {
2367         // eliminate the iaddrof by updating the iassign's fieldID and tyIdx
2368         node.SetFieldID(node.GetFieldID() + base.GetFieldID());
2369         node.SetTyIdx(base.GetTyIdx());
2370         node.SetOpnd(base.Opnd(0), 0);
2371         // recursive call for the new iassign
2372         return SimplifyIassign(&node);
2373     }
2374     return &node;
2375 }
2376 
SimplifyIassign(IassignNode * node)2377 StmtNode *ConstantFold::SimplifyIassign(IassignNode *node)
2378 {
2379     CHECK_NULL_FATAL(node);
2380     BaseNode *returnValue = nullptr;
2381     returnValue = Fold(node->Opnd(0));
2382     if (returnValue != nullptr) {
2383         node->SetOpnd(returnValue, 0);
2384     }
2385     returnValue = Fold(node->GetRHS());
2386     if (returnValue != nullptr) {
2387         node->SetRHS(returnValue);
2388     }
2389     auto *mirTypeOfIass = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx());
2390     if (!mirTypeOfIass->IsMIRPtrType()) {
2391         return node;
2392     }
2393 
2394     auto *opnd = node->Opnd(0);
2395     ASSERT_NOT_NULL(opnd);
2396     switch (opnd->GetOpCode()) {
2397         case OP_addrof: {
2398             return SimplifyIassignWithAddrofBaseNode(*node, static_cast<AddrofNode&>(*opnd));
2399         }
2400         case OP_iaddrof: {
2401             return SimplifyIassignWithIaddrofBaseNode(*node, static_cast<IreadNode&>(*opnd));
2402         }
2403         default:
2404             break;
2405     }
2406     return node;
2407 }
2408 
SimplifyCondGoto(CondGotoNode * node)2409 StmtNode *ConstantFold::SimplifyCondGoto(CondGotoNode *node)
2410 {
2411     CHECK_NULL_FATAL(node);
2412     // optimize condgoto need to update frequency, skip here
2413     if (Options::profileUse && mirModule->CurFunction()->GetFuncProfData()) {
2414         return node;
2415     }
2416     BaseNode *returnValue = nullptr;
2417     returnValue = Fold(node->Opnd(0));
2418     returnValue = (returnValue == nullptr) ? node : returnValue;
2419     if (returnValue == node && node->Opnd(0)->GetOpCode() == OP_select) {
2420         return SimplifyCondGotoSelect(node);
2421     } else {
2422         if (returnValue != node) {
2423             node->SetOpnd(returnValue, 0);
2424         }
2425         ConstvalNode *cst = safe_cast<ConstvalNode>(node->Opnd(0));
2426         if (cst == nullptr) {
2427             return node;
2428         }
2429         MIRIntConst *intConst = safe_cast<MIRIntConst>(cst->GetConstVal());
2430         ASSERT_NOT_NULL(intConst);
2431         if ((node->GetOpCode() == OP_brtrue && !intConst->IsZero()) ||
2432             (node->GetOpCode() == OP_brfalse && intConst->IsZero())) {
2433             uint32 freq = static_cast<uint32>(mirModule->CurFunction()->GetFreqFromLastStmt(node->GetStmtID()));
2434             GotoNode *gotoNode = mirModule->CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
2435             gotoNode->SetOffset(node->GetOffset());
2436             if (Options::profileUse && mirModule->CurFunction()->GetFuncProfData()) {
2437                 gotoNode->SetStmtID(node->GetStmtID());  // reuse condnode stmtid
2438             }
2439             mirModule->CurFunction()->SetLastFreqMap(gotoNode->GetStmtID(), freq);
2440             return gotoNode;
2441         } else {
2442             return nullptr;
2443         }
2444     }
2445     return node;
2446 }
2447 
SimplifyCondGotoSelect(CondGotoNode * node) const2448 StmtNode *ConstantFold::SimplifyCondGotoSelect(CondGotoNode *node) const
2449 {
2450     CHECK_NULL_FATAL(node);
2451     TernaryNode *sel = static_cast<TernaryNode*>(node->Opnd(0));
2452     if (sel == nullptr || sel->GetOpCode() != OP_select) {
2453         return node;
2454     }
2455     ConstvalNode *const1 = safe_cast<ConstvalNode>(sel->Opnd(1));
2456     ConstvalNode *const2 = safe_cast<ConstvalNode>(sel->Opnd(2));
2457     if (const1 != nullptr && const2 != nullptr) {
2458         MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1->GetConstVal());
2459         MIRIntConst *intConst2 = safe_cast<MIRIntConst>(const2->GetConstVal());
2460         ASSERT_NOT_NULL(intConst1);
2461         ASSERT_NOT_NULL(intConst2);
2462         if (intConst1->GetValue() == 1 && intConst2->GetValue() == 0) {
2463             node->SetOpnd(sel->Opnd(0), 0);
2464         } else if (intConst1->GetValue() == 0 && intConst2->GetValue() == 1) {
2465             node->SetOpCode((node->GetOpCode() == OP_brfalse) ? OP_brtrue : OP_brfalse);
2466             node->SetOpnd(sel->Opnd(0), 0);
2467         }
2468     }
2469     return node;
2470 }
2471 
SimplifySwitch(SwitchNode * node)2472 StmtNode *ConstantFold::SimplifySwitch(SwitchNode *node)
2473 {
2474     CHECK_NULL_FATAL(node);
2475     BaseNode *returnValue = nullptr;
2476     returnValue = Fold(node->GetSwitchOpnd());
2477     if (returnValue != nullptr) {
2478         node->SetSwitchOpnd(returnValue);
2479         ConstvalNode *cst = safe_cast<ConstvalNode>(node->GetSwitchOpnd());
2480         if (cst == nullptr) {
2481             return node;
2482         }
2483         MIRIntConst *intConst = safe_cast<MIRIntConst>(cst->GetConstVal());
2484         ASSERT_NOT_NULL(intConst);
2485         GotoNode *gotoNode = mirModule->CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
2486         bool isdefault = true;
2487         for (unsigned i = 0; i < node->GetSwitchTable().size(); i++) {
2488             if (node->GetCasePair(i).first == intConst->GetValue()) {
2489                 isdefault = false;
2490                 gotoNode->SetOffset(static_cast<LabelIdx>(node->GetCasePair(i).second));
2491                 break;
2492             }
2493         }
2494         if (isdefault) {
2495             gotoNode->SetOffset(node->GetDefaultLabel());
2496         }
2497         return gotoNode;
2498     }
2499     return node;
2500 }
2501 
SimplifyUnary(UnaryStmtNode * node)2502 StmtNode *ConstantFold::SimplifyUnary(UnaryStmtNode *node)
2503 {
2504     CHECK_NULL_FATAL(node);
2505     BaseNode *returnValue = nullptr;
2506     if (node->Opnd(0) == nullptr) {
2507         return node;
2508     }
2509     returnValue = Fold(node->Opnd(0));
2510     if (returnValue != nullptr) {
2511         node->SetOpnd(returnValue, 0);
2512     }
2513     return node;
2514 }
2515 
SimplifyBinary(BinaryStmtNode * node)2516 StmtNode *ConstantFold::SimplifyBinary(BinaryStmtNode *node)
2517 {
2518     CHECK_NULL_FATAL(node);
2519     BaseNode *returnValue = nullptr;
2520     returnValue = Fold(node->GetBOpnd(0));
2521     if (returnValue != nullptr) {
2522         node->SetBOpnd(returnValue, 0);
2523     }
2524     returnValue = Fold(node->GetBOpnd(1));
2525     if (returnValue != nullptr) {
2526         node->SetBOpnd(returnValue, 1);
2527     }
2528     return node;
2529 }
2530 
SimplifyBlock(BlockNode * node)2531 StmtNode *ConstantFold::SimplifyBlock(BlockNode *node)
2532 {
2533     CHECK_NULL_FATAL(node);
2534     if (node->GetFirst() == nullptr) {
2535         return node;
2536     }
2537     StmtNode *s = node->GetFirst();
2538     StmtNode *prevStmt = nullptr;
2539     do {
2540         StmtNode *returnValue = Simplify(s);
2541         if (returnValue != nullptr) {
2542             if (returnValue->GetOpCode() == OP_block) {
2543                 BlockNode *blk = static_cast<BlockNode*>(returnValue);
2544                 if (blk->IsEmpty()) {
2545                     node->RemoveStmt(s);
2546                 } else {
2547                     node->ReplaceStmtWithBlock(*s, *blk);
2548                     prevStmt = s;
2549                 }
2550             } else {
2551                 node->ReplaceStmt1WithStmt2(s, returnValue);
2552                 prevStmt = s;
2553             }
2554             s = s->GetNext();
2555         } else {
2556             // delete s from block
2557             StmtNode *nextStmt = s->GetNext();
2558             if (s == node->GetFirst()) {
2559                 node->SetFirst(nextStmt);
2560                 if (nextStmt != nullptr) {
2561                     nextStmt->SetPrev(nullptr);
2562                 }
2563             } else {
2564                 CHECK_NULL_FATAL(prevStmt);
2565                 prevStmt->SetNext(nextStmt);
2566                 if (nextStmt != nullptr) {
2567                     nextStmt->SetPrev(prevStmt);
2568                 }
2569             }
2570             if (s == node->GetLast()) {
2571                 node->SetLast(prevStmt);
2572             }
2573             s = nextStmt;
2574         }
2575     } while (s != nullptr);
2576     return node;
2577 }
2578 
SimplifyAsm(AsmNode * node)2579 StmtNode *ConstantFold::SimplifyAsm(AsmNode *node)
2580 {
2581     CHECK_NULL_FATAL(node);
2582     /* fold constval in input */
2583     for (size_t i = 0; i < node->NumOpnds(); i++) {
2584         const std::string &str = GlobalTables::GetUStrTable().GetStringFromStrIdx(node->inputConstraints[i]);
2585         if (str == "i") {
2586             std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(i));
2587             node->SetOpnd(p.first, i);
2588             continue;
2589         }
2590     }
2591     return node;
2592 }
2593 
SimplifyIf(IfStmtNode * node)2594 StmtNode *ConstantFold::SimplifyIf(IfStmtNode *node)
2595 {
2596     CHECK_NULL_FATAL(node);
2597     BaseNode *returnValue = nullptr;
2598     (void)Simplify(node->GetThenPart());
2599     if (node->GetElsePart()) {
2600         (void)Simplify(node->GetElsePart());
2601     }
2602     returnValue = Fold(node->Opnd());
2603     if (returnValue != nullptr) {
2604         node->SetOpnd(returnValue, 0);
2605         // do not delete c/c++ dead if-body here
2606         return node;
2607     }
2608     return node;
2609 }
2610 
SimplifyWhile(WhileStmtNode * node)2611 StmtNode *ConstantFold::SimplifyWhile(WhileStmtNode *node)
2612 {
2613     CHECK_NULL_FATAL(node);
2614     BaseNode *returnValue = nullptr;
2615     if (node->Opnd(0) == nullptr) {
2616         return node;
2617     }
2618     if (node->GetBody()) {
2619         (void)Simplify(node->GetBody());
2620     }
2621     returnValue = Fold(node->Opnd(0));
2622     if (returnValue != nullptr) {
2623         node->SetOpnd(returnValue, 0);
2624         // do not delete c/c++ dead while-body here
2625         return node;
2626     }
2627     return node;
2628 }
2629 
SimplifyNary(NaryStmtNode * node)2630 StmtNode *ConstantFold::SimplifyNary(NaryStmtNode *node)
2631 {
2632     CHECK_NULL_FATAL(node);
2633     BaseNode *returnValue = nullptr;
2634     for (size_t i = 0; i < node->NumOpnds(); i++) {
2635         returnValue = Fold(node->GetNopndAt(i));
2636         if (returnValue != nullptr) {
2637             node->SetNOpndAt(i, returnValue);
2638         }
2639     }
2640     return node;
2641 }
2642 
SimplifyIcall(IcallNode * node)2643 StmtNode *ConstantFold::SimplifyIcall(IcallNode *node)
2644 {
2645     CHECK_NULL_FATAL(node);
2646     BaseNode *returnValue = nullptr;
2647     for (size_t i = 0; i < node->NumOpnds(); i++) {
2648         returnValue = Fold(node->GetNopndAt(i));
2649         if (returnValue != nullptr) {
2650             node->SetNOpndAt(i, returnValue);
2651         }
2652     }
2653     // icall node transform to call node
2654     CHECK_FATAL(!node->GetNopnd().empty(), "container check");
2655     switch (node->GetNopndAt(0)->GetOpCode()) {
2656         case OP_addroffunc: {
2657             AddroffuncNode *addrofNode = static_cast<AddroffuncNode*>(node->GetNopndAt(0));
2658             CallNode *callNode = mirModule->CurFuncCodeMemPool()->New<CallNode>(
2659                 *mirModule,
2660                 (node->GetOpCode() == OP_icall || node->GetOpCode() == OP_icallproto) ? OP_call : OP_callassigned);
2661             if (node->GetOpCode() == OP_icallassigned || node->GetOpCode() == OP_icallprotoassigned) {
2662                 callNode->SetReturnVec(node->GetReturnVec());
2663             }
2664             callNode->SetPUIdx(addrofNode->GetPUIdx());
2665             for (size_t i = 1; i < node->GetNopndSize(); i++) {
2666                 callNode->GetNopnd().push_back(node->GetNopndAt(i));
2667             }
2668             callNode->SetNumOpnds(callNode->GetNopndSize());
2669             // reuse stmtID to skip update stmtFreqs when profileUse is on
2670             callNode->SetStmtID(node->GetStmtID());
2671             return callNode;
2672         }
2673         default:
2674             break;
2675     }
2676     return node;
2677 }
2678 
ProcessFunc(MIRFunction * func)2679 void ConstantFold::ProcessFunc(MIRFunction *func)
2680 {
2681     if (func->IsEmpty()) {
2682         return;
2683     }
2684     mirModule->SetCurFunction(func);
2685     (void)Simplify(func->GetBody());
2686 }
2687 
2688 }  // namespace maple
2689