• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "constantfold.h"
17 #include <cmath>
18 #include <cfloat>
19 #include <climits>
20 #include <type_traits>
21 #include "mpl_logging.h"
22 #include "mir_function.h"
23 #include "mir_builder.h"
24 #include "global_tables.h"
25 #include "me_option.h"
26 #include "maple_phase_manager.h"
27 
28 namespace maple {
29 
30 namespace {
31 
32 constexpr uint64 kJsTypeNumber = 4;
33 constexpr uint64 kJsTypeNumberInHigh32Bit = kJsTypeNumber << 32;  // set high 32 bit as JSTYPE_NUMBER
34 constexpr uint32 kByteSizeOfBit64 = 8;                            // byte number for 64 bit
35 constexpr uint32 kBitSizePerByte = 8;
36 constexpr maple::int32 kMaxOffset = INT_MAX - 8;
37 
38 enum CompareRes : int64 { kLess = -1, kEqual = 0, kGreater = 1 };
39 
operator *(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)40 std::optional<IntVal> operator*(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
41 {
42     if (!v1 && !v2) {
43         return std::nullopt;
44     }
45 
46     // Perform all calculations in terms of the maximum available signed type.
47     // The value will be truncated for an appropriate type when constant is created in PariToExpr function
48     // replace with PTY_i128 when IntVal supports 128bit calculation
49     return v1 && v2 ? v1->Mul(*v2, PTY_i64) : IntVal(0, PTY_i64);
50 }
51 
52 // Perform all calculations in terms of the maximum available signed type.
53 // The value will be truncated for an appropriate type when constant is created in PariToExpr function
54 // replace with PTY_i128 when IntVal supports 128bit calculation
AddSub(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2,bool isAdd)55 std::optional<IntVal> AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd)
56 {
57     if (!v1 && !v2) {
58         return std::nullopt;
59     }
60 
61     if (v1 && v2) {
62         return isAdd ? v1->Add(*v2, PTY_i64) : v1->Sub(*v2, PTY_i64);
63     }
64 
65     if (v1) {
66         return v1->TruncOrExtend(PTY_i64);
67     }
68 
69     // !v1 && v2
70     return isAdd ? v2->TruncOrExtend(PTY_i64) : -(v2->TruncOrExtend(PTY_i64));
71 }
72 
operator +(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)73 std::optional<IntVal> operator+(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
74 {
75     return AddSub(v1, v2, true);
76 }
77 
operator -(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)78 std::optional<IntVal> operator-(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
79 {
80     return AddSub(v1, v2, false);
81 }
82 
83 }  // anonymous namespace
84 
// This phase is designed to achieve compiler optimization by
// simplifying constant expressions. Each constant expression
// is evaluated at compile time and replaced by the computed
// value, which saves time at runtime.
//
// The main procedure is as follows:
// A. Analyze the expression type
// B. Analyze the operator type
// C. Replace the expression with the result of the operation
94 
95 // true if the constant's bits are made of only one group of contiguous 1's
96 // starting at bit 0
ContiguousBitsOf1(uint64 x)97 static bool ContiguousBitsOf1(uint64 x)
98 {
99     if (x == 0) {
100         return false;
101     }
102     return (~x & (x + 1)) == (x + 1);
103 }
104 
IsPowerOf2(uint64 num)105 inline bool IsPowerOf2(uint64 num)
106 {
107     if (num == 0) {
108         return false;
109     }
110     return (~(num - 1) & num) == num;
111 }
112 
NewBinaryNode(BinaryNode * old,Opcode op,PrimType primType,BaseNode * lhs,BaseNode * rhs) const113 BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs,
114                                         BaseNode *rhs) const
115 {
116     CHECK_NULL_FATAL(old);
117     BinaryNode *result = nullptr;
118     if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == lhs && old->Opnd(1) == rhs) {
119         result = old;
120     } else {
121         result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(op, primType, lhs, rhs);
122     }
123     return result;
124 }
125 
NewUnaryNode(UnaryNode * old,Opcode op,PrimType primType,BaseNode * expr) const126 UnaryNode *ConstantFold::NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const
127 {
128     CHECK_NULL_FATAL(old);
129     UnaryNode *result = nullptr;
130     if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == expr) {
131         result = old;
132     } else {
133         result = mirModule->CurFuncCodeMemPool()->New<UnaryNode>(op, primType, expr);
134     }
135     return result;
136 }
137 
PairToExpr(PrimType resultType,const std::pair<BaseNode *,std::optional<IntVal>> & pair) const138 BaseNode *ConstantFold::PairToExpr(PrimType resultType, const std::pair<BaseNode *, std::optional<IntVal>> &pair) const
139 {
140     CHECK_NULL_FATAL(pair.first);
141     BaseNode *result = pair.first;
142     if (!pair.second || *pair.second == 0) {
143         return result;
144     }
145     if (pair.first->GetOpCode() == OP_neg && !pair.second->GetSignBit()) {
146         // -a, 5 -> 5 - a
147         ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(pair.second->GetExtValue(), resultType);
148         BaseNode *r = static_cast<UnaryNode *>(pair.first)->Opnd(0);
149         result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, val, r);
150     } else {
151         if ((!pair.second->GetSignBit() && pair.second->GetSXTValue(GetPrimTypeBitSize(resultType)) > 0) ||
152             pair.second->GetSXTValue() == INT64_MIN) {
153             // +-a, 5 -> a + 5
154             ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(pair.second->GetExtValue(), resultType);
155             result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_add, resultType, pair.first, val);
156         } else {
157             // +-a, -5 -> a + -5
158             ConstvalNode *val =
159                 mirModule->GetMIRBuilder()->CreateIntConst((-pair.second.value()).GetExtValue(), resultType);
160             result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, pair.first, val);
161         }
162     }
163     return result;
164 }
165 
FoldBase(BaseNode * node) const166 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBase(BaseNode *node) const
167 {
168     return std::make_pair(node, std::nullopt);
169 }
170 
Simplify(StmtNode * node)171 StmtNode *ConstantFold::Simplify(StmtNode *node)
172 {
173     CHECK_NULL_FATAL(node);
174     switch (node->GetOpCode()) {
175         case OP_dassign:
176         case OP_maydassign:
177             return SimplifyDassign(static_cast<DassignNode *>(node));
178         case OP_iassign:
179             return SimplifyIassign(static_cast<IassignNode *>(node));
180         case OP_block:
181             return SimplifyBlock(static_cast<BlockNode *>(node));
182         case OP_if:
183             return SimplifyIf(static_cast<IfStmtNode *>(node));
184         case OP_dowhile:
185         case OP_while:
186             return SimplifyWhile(static_cast<WhileStmtNode *>(node));
187         case OP_switch:
188             return SimplifySwitch(static_cast<SwitchNode *>(node));
189         case OP_eval:
190         case OP_throw:
191         case OP_free:
192         case OP_decref:
193         case OP_incref:
194         case OP_decrefreset:
195         case OP_regassign:
196             CASE_OP_ASSERT_NONNULL
197         case OP_igoto:
198             return SimplifyUnary(static_cast<UnaryStmtNode *>(node));
199         case OP_brfalse:
200         case OP_brtrue:
201             return SimplifyCondGoto(static_cast<CondGotoNode *>(node));
202         case OP_return:
203         case OP_syncenter:
204         case OP_syncexit:
205         case OP_call:
206         case OP_virtualcall:
207         case OP_superclasscall:
208         case OP_interfacecall:
209         case OP_customcall:
210         case OP_polymorphiccall:
211         case OP_intrinsiccall:
212         case OP_xintrinsiccall:
213         case OP_intrinsiccallwithtype:
214         case OP_callassigned:
215         case OP_virtualcallassigned:
216         case OP_superclasscallassigned:
217         case OP_interfacecallassigned:
218         case OP_customcallassigned:
219         case OP_polymorphiccallassigned:
220         case OP_intrinsiccallassigned:
221         case OP_intrinsiccallwithtypeassigned:
222         case OP_xintrinsiccallassigned:
223         case OP_callinstant:
224         case OP_callinstantassigned:
225         case OP_virtualcallinstant:
226         case OP_virtualcallinstantassigned:
227         case OP_superclasscallinstant:
228         case OP_superclasscallinstantassigned:
229         case OP_interfacecallinstant:
230         case OP_interfacecallinstantassigned:
231             CASE_OP_ASSERT_BOUNDARY
232             return SimplifyNary(static_cast<NaryStmtNode *>(node));
233         case OP_icall:
234         case OP_icallassigned:
235         case OP_icallproto:
236         case OP_icallprotoassigned:
237             return SimplifyIcall(static_cast<IcallNode *>(node));
238         case OP_asm:
239             return SimplifyAsm(static_cast<AsmNode *>(node));
240         default:
241             return node;
242     }
243 }
244 
DispatchFold(BaseNode * node)245 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::DispatchFold(BaseNode *node)
246 {
247     CHECK_NULL_FATAL(node);
248     switch (node->GetOpCode()) {
249         case OP_sizeoftype:
250             return FoldSizeoftype(static_cast<SizeoftypeNode *>(node));
251         case OP_abs:
252         case OP_bnot:
253         case OP_lnot:
254         case OP_neg:
255         case OP_recip:
256         case OP_sqrt:
257             return FoldUnary(static_cast<UnaryNode *>(node));
258         case OP_ceil:
259         case OP_floor:
260         case OP_round:
261         case OP_trunc:
262         case OP_cvt:
263             return FoldTypeCvt(static_cast<TypeCvtNode *>(node));
264         case OP_sext:
265         case OP_zext:
266         case OP_extractbits:
267             return FoldExtractbits(static_cast<ExtractbitsNode *>(node));
268         case OP_iaddrof:
269         case OP_iread:
270             return FoldIread(static_cast<IreadNode *>(node));
271         case OP_add:
272         case OP_ashr:
273         case OP_band:
274         case OP_bior:
275         case OP_bxor:
276         case OP_cand:
277         case OP_cior:
278         case OP_div:
279         case OP_land:
280         case OP_lior:
281         case OP_lshr:
282         case OP_max:
283         case OP_min:
284         case OP_mul:
285         case OP_rem:
286         case OP_shl:
287         case OP_sub:
288             return FoldBinary(static_cast<BinaryNode *>(node));
289         case OP_eq:
290         case OP_ne:
291         case OP_ge:
292         case OP_gt:
293         case OP_le:
294         case OP_lt:
295         case OP_cmp:
296             return FoldCompare(static_cast<CompareNode *>(node));
297         case OP_depositbits:
298             return FoldDepositbits(static_cast<DepositbitsNode *>(node));
299         case OP_select:
300             return FoldTernary(static_cast<TernaryNode *>(node));
301         case OP_array:
302             return FoldArray(static_cast<ArrayNode *>(node));
303         case OP_retype:
304             return FoldRetype(static_cast<RetypeNode *>(node));
305         case OP_gcmallocjarray:
306         case OP_gcpermallocjarray:
307             return FoldGcmallocjarray(static_cast<JarrayMallocNode *>(node));
308         default:
309             return FoldBase(static_cast<BaseNode *>(node));
310     }
311 }
312 
Negate(BaseNode * node) const313 BaseNode *ConstantFold::Negate(BaseNode *node) const
314 {
315     CHECK_NULL_FATAL(node);
316     return mirModule->CurFuncCodeMemPool()->New<UnaryNode>(OP_neg, PrimType(node->GetPrimType()), node);
317 }
318 
Negate(UnaryNode * node) const319 BaseNode *ConstantFold::Negate(UnaryNode *node) const
320 {
321     CHECK_NULL_FATAL(node);
322     BaseNode *result = nullptr;
323     if (node->GetOpCode() == OP_neg) {
324         result = static_cast<BaseNode *>(node->Opnd(0));
325     } else {
326         BaseNode *n = static_cast<BaseNode *>(node);
327         result = NewUnaryNode(node, OP_neg, node->GetPrimType(), n);
328     }
329     return result;
330 }
331 
Negate(const ConstvalNode * node) const332 BaseNode *ConstantFold::Negate(const ConstvalNode *node) const
333 {
334     CHECK_NULL_FATAL(node);
335     ConstvalNode *copy = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
336     CHECK_NULL_FATAL(copy);
337     copy->GetConstVal()->Neg();
338     return copy;
339 }
340 
NegateTree(BaseNode * node) const341 BaseNode *ConstantFold::NegateTree(BaseNode *node) const
342 {
343     CHECK_NULL_FATAL(node);
344     if (node->IsUnaryNode()) {
345         return Negate(static_cast<UnaryNode *>(node));
346     } else if (node->GetOpCode() == OP_constval) {
347         return Negate(static_cast<ConstvalNode *>(node));
348     } else {
349         return Negate(static_cast<BaseNode *>(node));
350     }
351 }
352 
FoldIntConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRIntConst & intConst0,const MIRIntConst & intConst1) const353 MIRIntConst *ConstantFold::FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
354                                                           const MIRIntConst &intConst0,
355                                                           const MIRIntConst &intConst1) const
356 {
357     uint64 result = 0;
358 
359     bool greater = intConst0.GetValue().Greater(intConst1.GetValue(), opndType);
360     bool equal = intConst0.GetValue().Equal(intConst1.GetValue(), opndType);
361     bool less = intConst0.GetValue().Less(intConst1.GetValue(), opndType);
362 
363     switch (opcode) {
364         case OP_eq: {
365             result = equal;
366             break;
367         }
368         case OP_ge: {
369             result = greater || equal;
370             break;
371         }
372         case OP_gt: {
373             result = greater;
374             break;
375         }
376         case OP_le: {
377             result = less || equal;
378             break;
379         }
380         case OP_lt: {
381             result = less;
382             break;
383         }
384         case OP_ne: {
385             result = !equal;
386             break;
387         }
388         case OP_cmp: {
389             if (greater) {
390                 result = kGreater;
391             } else if (equal) {
392                 result = kEqual;
393             } else if (less) {
394                 result = kLess;
395             }
396             break;
397         }
398         default:
399             DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstComparison");
400             break;
401     }
402     // determine the type
403     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
404     // form the constant
405     MIRIntConst *constValue = nullptr;
406     if (type.GetPrimType() == PTY_dyni32) {
407         constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
408         constValue->SetValue(kJsTypeNumberInHigh32Bit | result);
409     } else {
410         constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
411     }
412     return constValue;
413 }
414 
FoldIntConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const415 ConstvalNode *ConstantFold::FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
416                                                    const ConstvalNode &const0, const ConstvalNode &const1) const
417 {
418     const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
419     const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
420     CHECK_NULL_FATAL(intConst0);
421     CHECK_NULL_FATAL(intConst1);
422     MIRIntConst *constValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
423     // form the ConstvalNode
424     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
425     resultConst->SetPrimType(resultType);
426     resultConst->SetConstVal(constValue);
427     return resultConst;
428 }
429 
FoldIntConstBinaryMIRConst(Opcode opcode,PrimType resultType,const MIRIntConst * intConst0,const MIRIntConst * intConst1) const430 MIRConst *ConstantFold::FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *intConst0,
431                                                    const MIRIntConst *intConst1) const
432 {
433     IntVal intVal0 = intConst0->GetValue();
434     IntVal intVal1 = intConst1->GetValue();
435     IntVal result(0, resultType);
436 
437     switch (opcode) {
438         case OP_add: {
439             result = intVal0.Add(intVal1, resultType);
440             break;
441         }
442         case OP_sub: {
443             result = intVal0.Sub(intVal1, resultType);
444             break;
445         }
446         case OP_mul: {
447             result = intVal0.Mul(intVal1, resultType);
448             break;
449         }
450         case OP_div: {
451             result = intVal0.Div(intVal1, resultType);
452             break;
453         }
454         case OP_rem: {
455             result = intVal0.Rem(intVal1, resultType);
456             break;
457         }
458         case OP_ashr: {
459             result = intVal0.AShr(intVal1.GetZXTValue() % GetPrimTypeBitSize(resultType), resultType);
460             break;
461         }
462         case OP_lshr: {
463             result = intVal0.LShr(intVal1.GetZXTValue() % GetPrimTypeBitSize(resultType), resultType);
464             break;
465         }
466         case OP_shl: {
467             result = intVal0.Shl(intVal1.GetZXTValue() % GetPrimTypeBitSize(resultType), resultType);
468             break;
469         }
470         case OP_max: {
471             result = Max(intVal0, intVal1, resultType);
472             break;
473         }
474         case OP_min: {
475             result = Min(intVal0, intVal1, resultType);
476             break;
477         }
478         case OP_band: {
479             result = intVal0.And(intVal1, resultType);
480             break;
481         }
482         case OP_bior: {
483             result = intVal0.Or(intVal1, resultType);
484             break;
485         }
486         case OP_bxor: {
487             result = intVal0.Xor(intVal1, resultType);
488             break;
489         }
490         case OP_cand:
491         case OP_land: {
492             result = IntVal(intVal0.GetExtValue() && intVal1.GetExtValue(), resultType);
493             break;
494         }
495         case OP_cior:
496         case OP_lior: {
497             result = IntVal(intVal0.GetExtValue() || intVal1.GetExtValue(), resultType);
498             break;
499         }
500         case OP_depositbits: {
501             // handled in FoldDepositbits
502             DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstBinary");
503             break;
504         }
505         default:
506             DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstBinary");
507             break;
508     }
509     // determine the type
510     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
511     // form the constant
512     MIRIntConst *constValue = nullptr;
513     if (type.GetPrimType() == PTY_dyni32) {
514         constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
515         constValue->SetValue(kJsTypeNumberInHigh32Bit | result.GetExtValue());
516     } else {
517         constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result.GetExtValue(), type);
518     }
519     return constValue;
520 }
521 
FoldIntConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const522 ConstvalNode *ConstantFold::FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
523                                                const ConstvalNode &const1) const
524 {
525     const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
526     const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
527     CHECK_NULL_FATAL(intConst0);
528     CHECK_NULL_FATAL(intConst1);
529     MIRConst *constValue = FoldIntConstBinaryMIRConst(opcode, resultType, intConst0, intConst1);
530     // form the ConstvalNode
531     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
532     resultConst->SetPrimType(resultType);
533     resultConst->SetConstVal(constValue);
534     return resultConst;
535 }
536 
FoldFPConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const537 ConstvalNode *ConstantFold::FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
538                                               const ConstvalNode &const1) const
539 {
540     DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
541     const MIRDoubleConst *doubleConst0 = nullptr;
542     const MIRDoubleConst *doubleConst1 = nullptr;
543     const MIRFloatConst *floatConst0 = nullptr;
544     const MIRFloatConst *floatConst1 = nullptr;
545     bool useDouble = (const0.GetPrimType() == PTY_f64);
546     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
547     resultConst->SetPrimType(resultType);
548     if (useDouble) {
549         doubleConst0 = safe_cast<MIRDoubleConst>(const0.GetConstVal());
550         doubleConst1 = safe_cast<MIRDoubleConst>(const1.GetConstVal());
551         CHECK_NULL_FATAL(doubleConst0);
552         CHECK_NULL_FATAL(doubleConst1);
553     } else {
554         floatConst0 = safe_cast<MIRFloatConst>(const0.GetConstVal());
555         floatConst1 = safe_cast<MIRFloatConst>(const1.GetConstVal());
556         CHECK_NULL_FATAL(floatConst0);
557         CHECK_NULL_FATAL(floatConst1);
558     }
559     float constValueFloat = 0.0;
560     double constValueDouble = 0.0;
561     switch (opcode) {
562         case OP_add: {
563             if (useDouble) {
564                 constValueDouble = doubleConst0->GetValue() + doubleConst1->GetValue();
565             } else {
566                 constValueFloat = floatConst0->GetValue() + floatConst1->GetValue();
567             }
568             break;
569         }
570         case OP_sub: {
571             if (useDouble) {
572                 constValueDouble = doubleConst0->GetValue() - doubleConst1->GetValue();
573             } else {
574                 constValueFloat = floatConst0->GetValue() - floatConst1->GetValue();
575             }
576             break;
577         }
578         case OP_mul: {
579             if (useDouble) {
580                 constValueDouble = doubleConst0->GetValue() * doubleConst1->GetValue();
581             } else {
582                 constValueFloat = floatConst0->GetValue() * floatConst1->GetValue();
583             }
584             break;
585         }
586         case OP_div: {
587             // for floats div by 0 is well defined
588             if (useDouble) {
589                 constValueDouble = doubleConst0->GetValue() / doubleConst1->GetValue();
590             } else {
591                 constValueFloat = floatConst0->GetValue() / floatConst1->GetValue();
592             }
593             break;
594         }
595         case OP_max: {
596             if (useDouble) {
597                 constValueDouble = (doubleConst0->GetValue() >= doubleConst1->GetValue()) ? doubleConst0->GetValue()
598                                                                                           : doubleConst1->GetValue();
599             } else {
600                 constValueFloat = (floatConst0->GetValue() >= floatConst1->GetValue()) ? floatConst0->GetValue()
601                                                                                        : floatConst1->GetValue();
602             }
603             break;
604         }
605         case OP_min: {
606             if (useDouble) {
607                 constValueDouble = (doubleConst0->GetValue() <= doubleConst1->GetValue()) ? doubleConst0->GetValue()
608                                                                                           : doubleConst1->GetValue();
609             } else {
610                 constValueFloat = (floatConst0->GetValue() <= floatConst1->GetValue()) ? floatConst0->GetValue()
611                                                                                        : floatConst1->GetValue();
612             }
613             break;
614         }
615         case OP_rem:
616         case OP_ashr:
617         case OP_lshr:
618         case OP_shl:
619         case OP_band:
620         case OP_bior:
621         case OP_bxor:
622         case OP_cand:
623         case OP_land:
624         case OP_cior:
625         case OP_lior:
626         case OP_depositbits: {
627             DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstBinary");
628             break;
629         }
630         default:
631             DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstBinary");
632             break;
633     }
634     if (resultType == PTY_f64) {
635         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValueDouble));
636     } else {
637         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValueFloat));
638     }
639     return resultConst;
640 }
641 
ConstValueEqual(int64 leftValue,int64 rightValue) const642 bool ConstantFold::ConstValueEqual(int64 leftValue, int64 rightValue) const
643 {
644     return (leftValue == rightValue);
645 }
646 
ConstValueEqual(float leftValue,float rightValue) const647 bool ConstantFold::ConstValueEqual(float leftValue, float rightValue) const
648 {
649     return fabs(leftValue - rightValue) <= FLT_MIN;
650 }
651 
ConstValueEqual(double leftValue,double rightValue) const652 bool ConstantFold::ConstValueEqual(double leftValue, double rightValue) const
653 {
654     return fabs(leftValue - rightValue) <= DBL_MIN;
655 }
656 
657 template <typename T>
FullyEqual(T leftValue,T rightValue) const658 bool ConstantFold::FullyEqual(T leftValue, T rightValue) const
659 {
660     if (isinf(leftValue) && isinf(rightValue)) {
661         // (inf == inf), add the judgement here in case of the subtraction between float type inf
662         return true;
663     } else {
664         return ConstValueEqual(leftValue, rightValue);
665     }
666 }
667 
668 template <typename T>
ComparisonResult(Opcode op,T * leftConst,T * rightConst) const669 int64 ConstantFold::ComparisonResult(Opcode op, T *leftConst, T *rightConst) const
670 {
671     typename T::value_type leftValue = leftConst->GetValue();
672     typename T::value_type rightValue = rightConst->GetValue();
673     int64 result = 0;
674     ;
675     switch (op) {
676         case OP_eq: {
677             result = FullyEqual(leftValue, rightValue);
678             break;
679         }
680         case OP_ge: {
681             result = (leftValue > rightValue) || FullyEqual(leftValue, rightValue);
682             break;
683         }
684         case OP_gt: {
685             result = (leftValue > rightValue);
686             break;
687         }
688         case OP_le: {
689             result = (leftValue < rightValue) || FullyEqual(leftValue, rightValue);
690             break;
691         }
692         case OP_lt: {
693             result = (leftValue < rightValue);
694             break;
695         }
696         case OP_ne: {
697             result = !FullyEqual(leftValue, rightValue);
698             break;
699         }
700         case OP_cmpl:
701         case OP_cmpg: {
702             if (std::isnan(leftValue) || std::isnan(rightValue)) {
703                 result = (op == OP_cmpg) ? kGreater : kLess;
704                 break;
705             }
706         }
707             [[clang::fallthrough]];
708         case OP_cmp: {
709             if (leftValue > rightValue) {
710                 result = kGreater;
711             } else if (FullyEqual(leftValue, rightValue)) {
712                 result = kEqual;
713             } else {
714                 result = kLess;
715             }
716             break;
717         }
718         default:
719             DEBUG_ASSERT(false, "Unknown opcode for Comparison");
720             break;
721     }
722     return result;
723 }
724 
FoldFPConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRConst & leftConst,const MIRConst & rightConst) const725 MIRIntConst *ConstantFold::FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
726                                                          const MIRConst &leftConst, const MIRConst &rightConst) const
727 {
728     int64 result = 0;
729     bool useDouble = (opndType == PTY_f64);
730     if (useDouble) {
731         result =
732             ComparisonResult(opcode, safe_cast<MIRDoubleConst>(&leftConst), safe_cast<MIRDoubleConst>(&rightConst));
733     } else {
734         result = ComparisonResult(opcode, safe_cast<MIRFloatConst>(&leftConst), safe_cast<MIRFloatConst>(&rightConst));
735     }
736     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
737     MIRIntConst *resultConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
738     return resultConst;
739 }
740 
FoldFPConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const741 ConstvalNode *ConstantFold::FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
742                                                   const ConstvalNode &const0, const ConstvalNode &const1) const
743 {
744     DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
745     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
746     resultConst->SetPrimType(resultType);
747     resultConst->SetConstVal(
748         FoldFPConstComparisonMIRConst(opcode, resultType, opndType, *const0.GetConstVal(), *const1.GetConstVal()));
749     return resultConst;
750 }
751 
FoldConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRConst & const0,const MIRConst & const1)752 MIRConst *ConstantFold::FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
753                                                     const MIRConst &const0, const MIRConst &const1)
754 {
755     MIRConst *returnValue = nullptr;
756     if (IsPrimitiveInteger(opndType) || IsPrimitiveDynInteger(opndType)) {
757         const auto *intConst0 = safe_cast<MIRIntConst>(&const0);
758         const auto *intConst1 = safe_cast<MIRIntConst>(&const1);
759         ASSERT_NOT_NULL(intConst0);
760         ASSERT_NOT_NULL(intConst1);
761         returnValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
762     } else if (opndType == PTY_f32 || opndType == PTY_f64) {
763         returnValue = FoldFPConstComparisonMIRConst(opcode, resultType, opndType, const0, const1);
764     } else {
765         DEBUG_ASSERT(false, "Unhandled case for FoldConstComparisonMIRConst");
766     }
767     return returnValue;
768 }
769 
FoldConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const770 ConstvalNode *ConstantFold::FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
771                                                 const ConstvalNode &const0, const ConstvalNode &const1) const
772 {
773     ConstvalNode *returnValue = nullptr;
774     if (IsPrimitiveInteger(opndType) || IsPrimitiveDynInteger(opndType)) {
775         returnValue = FoldIntConstComparison(opcode, resultType, opndType, const0, const1);
776     } else if (opndType == PTY_f32 || opndType == PTY_f64) {
777         returnValue = FoldFPConstComparison(opcode, resultType, opndType, const0, const1);
778     } else {
779         DEBUG_ASSERT(false, "Unhandled case for FoldConstComparison");
780     }
781     return returnValue;
782 }
783 
FoldConstComparisonReverse(Opcode opcode,PrimType resultType,PrimType opndType,BaseNode & l,BaseNode & r)784 CompareNode *ConstantFold::FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType,
785                                                       BaseNode &l, BaseNode &r)
786 {
787     CompareNode *result = nullptr;
788     Opcode op = opcode;
789     switch (opcode) {
790         case OP_gt: {
791             op = OP_lt;
792             break;
793         }
794         case OP_lt: {
795             op = OP_gt;
796             break;
797         }
798         case OP_ge: {
799             op = OP_le;
800             break;
801         }
802         case OP_le: {
803             op = OP_ge;
804             break;
805         }
806         case OP_eq: {
807             break;
808         }
809         case OP_ne: {
810             break;
811         }
812         default:
813             DEBUG_ASSERT(false, "Unknown opcode for FoldConstComparisonReverse");
814             break;
815     }
816 
817     result =
818         mirModule->CurFuncCodeMemPool()->New<CompareNode>(Opcode(op), PrimType(resultType), PrimType(opndType), &r, &l);
819     return result;
820 }
821 
FoldConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const822 ConstvalNode *ConstantFold::FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
823                                             const ConstvalNode &const1) const
824 {
825     ConstvalNode *returnValue = nullptr;
826     if (IsPrimitiveInteger(resultType) || IsPrimitiveDynInteger(resultType)) {
827         returnValue = FoldIntConstBinary(opcode, resultType, const0, const1);
828     } else if (resultType == PTY_f32 || resultType == PTY_f64) {
829         returnValue = FoldFPConstBinary(opcode, resultType, const0, const1);
830     } else {
831         DEBUG_ASSERT(false, "Unhandled case for FoldConstBinary");
832     }
833     return returnValue;
834 }
835 
FoldIntConstUnary(Opcode opcode,PrimType resultType,const ConstvalNode * constNode) const836 ConstvalNode *ConstantFold::FoldIntConstUnary(Opcode opcode, PrimType resultType, const ConstvalNode *constNode) const
837 {
838     CHECK_NULL_FATAL(constNode);
839     const MIRIntConst *cst = safe_cast<MIRIntConst>(constNode->GetConstVal());
840     ASSERT_NOT_NULL(cst);
841     IntVal result = cst->GetValue().TruncOrExtend(resultType);
842     switch (opcode) {
843         case OP_abs: {
844             if (result.IsSigned() && result.GetSignBit()) {
845                 result = -result;
846             }
847             break;
848         }
849         case OP_bnot: {
850             result = ~result;
851             break;
852         }
853         case OP_lnot: {
854             result = {result == 0, resultType};
855             break;
856         }
857         case OP_neg: {
858             result = -result;
859             break;
860         }
861         case OP_sext:         // handled in FoldExtractbits
862         case OP_zext:         // handled in FoldExtractbits
863         case OP_extractbits:  // handled in FoldExtractbits
864         case OP_recip:
865         case OP_sqrt: {
866             DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstUnary");
867             break;
868         }
869         default:
870             DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstUnary");
871             break;
872     }
873     // determine the type
874     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
875     // form the constant
876     MIRIntConst *constValue = nullptr;
877     if (type.GetPrimType() == PTY_dyni32) {
878         constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
879         constValue->SetValue(kJsTypeNumberInHigh32Bit | result.GetExtValue());
880     } else {
881         constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result.GetExtValue(), type);
882     }
883     // form the ConstvalNode
884     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
885     resultConst->SetPrimType(type.GetPrimType());
886     resultConst->SetConstVal(constValue);
887     return resultConst;
888 }
889 
890 template <typename T>
FoldFPConstUnary(Opcode opcode,PrimType resultType,ConstvalNode * constNode) const891 ConstvalNode *ConstantFold::FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const
892 {
893     CHECK_NULL_FATAL(constNode);
894     typename T::value_type constValue = 0.0;
895     T *fpCst = static_cast<T *>(constNode->GetConstVal());
896     switch (opcode) {
897         case OP_recip:
898         case OP_neg:
899         case OP_abs: {
900             if (OP_recip == opcode) {
901                 constValue = typename T::value_type(1.0 / fpCst->GetValue());
902             } else if (OP_neg == opcode) {
903                 constValue = typename T::value_type(-fpCst->GetValue());
904             } else if (OP_abs == opcode) {
905                 constValue = typename T::value_type(fabs(fpCst->GetValue()));
906             }
907             break;
908         }
909         case OP_sqrt: {
910             constValue = typename T::value_type(sqrt(fpCst->GetValue()));
911             break;
912         }
913         case OP_bnot:
914         case OP_lnot:
915         case OP_sext:
916         case OP_zext:
917         case OP_extractbits: {
918             DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstUnary");
919             break;
920         }
921         default:
922             DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstUnary");
923             break;
924     }
925     auto *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
926     resultConst->SetPrimType(resultType);
927     if (resultType == PTY_f32) {
928         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValue));
929     } else {
930         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValue));
931     }
932     return resultConst;
933 }
934 
FoldConstUnary(Opcode opcode,PrimType resultType,ConstvalNode * constNode) const935 ConstvalNode *ConstantFold::FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const
936 {
937     ConstvalNode *returnValue = nullptr;
938     if (IsPrimitiveInteger(resultType) || IsPrimitiveDynInteger(resultType)) {
939         returnValue = FoldIntConstUnary(opcode, resultType, constNode);
940     } else if (resultType == PTY_f32) {
941         returnValue = FoldFPConstUnary<MIRFloatConst>(opcode, resultType, constNode);
942     } else if (resultType == PTY_f64) {
943         returnValue = FoldFPConstUnary<MIRDoubleConst>(opcode, resultType, constNode);
944     } else if (PTY_f128 == resultType) {
945     } else {
946         DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary");
947     }
948     return returnValue;
949 }
950 
FoldSizeoftype(SizeoftypeNode * node) const951 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldSizeoftype(SizeoftypeNode *node) const
952 {
953     CHECK_NULL_FATAL(node);
954     BaseNode *result = node;
955     MIRType *argType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx());
956     if (argType->GetKind() == kTypeScalar) {
957         MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(node->GetPrimType());
958         uint32 size = GetPrimTypeSize(argType->GetPrimType());
959         ConstvalNode *constValueNode = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
960         constValueNode->SetPrimType(node->GetPrimType());
961         constValueNode->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(size, resultType));
962         result = constValueNode;
963     }
964     return std::make_pair(result, std::nullopt);
965 }
966 
FoldRetype(RetypeNode * node)967 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldRetype(RetypeNode *node)
968 {
969     CHECK_NULL_FATAL(node);
970     BaseNode *result = node;
971     std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
972     if (node->Opnd(0) != p.first) {
973         RetypeNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
974         CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldRetype");
975         newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
976         result = newRetNode;
977     }
978     return std::make_pair(result, std::nullopt);
979 }
980 
FoldGcmallocjarray(JarrayMallocNode * node)981 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldGcmallocjarray(JarrayMallocNode *node)
982 {
983     CHECK_NULL_FATAL(node);
984     BaseNode *result = node;
985     std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
986     if (node->Opnd(0) != p.first) {
987         JarrayMallocNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
988         CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldGcmallocjarray");
989         newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
990         result = newRetNode;
991     }
992     return std::make_pair(result, std::nullopt);
993 }
994 
FoldUnary(UnaryNode * node)995 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldUnary(UnaryNode *node)
996 {
997     CHECK_NULL_FATAL(node);
998     BaseNode *result = nullptr;
999     std::optional<IntVal> sum = std::nullopt;
1000     std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1001     ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1002     if (cst != nullptr) {
1003         result = FoldConstUnary(node->GetOpCode(), node->GetPrimType(), cst);
1004     } else {
1005         bool isInt = IsPrimitiveInteger(node->GetPrimType());
1006         // The neg node will be recreated regardless of whether the folding is successful or not. And the neg node's
1007         // primType will be set to opnd type. There will be problems in some cases. For example:
1008         // before cf:
1009         //   neg i32 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))
1010         // after cf:
1011         //   neg u1 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))  # wrong!
1012         // As a workaround, we exclude u1 opnd type
1013         if (isInt && node->GetOpCode() == OP_neg && p.first->GetPrimType() != PTY_u1) {
1014             result = NegateTree(p.first);
1015             if (result->GetOpCode() == OP_neg) {
1016                 PrimType origPtyp = node->GetPrimType();
1017                 PrimType newPtyp = result->GetPrimType();
1018                 if (newPtyp == origPtyp) {
1019                     if (static_cast<UnaryNode *>(result)->Opnd(0) == node->Opnd(0)) {
1020                         // NegateTree returned an UnaryNode quivalent to `n`, so keep the
1021                         // original UnaryNode to preserve identity
1022                         result = node;
1023                     }
1024                 } else {
1025                     if (GetPrimTypeSize(newPtyp) != GetPrimTypeSize(origPtyp)) {
1026                         // do not fold explicit cvt
1027                         return std::make_pair(node, std::nullopt);
1028                     } else {
1029                         result->SetPrimType(origPtyp);
1030                     }
1031                 }
1032             }
1033             if (p.second) {
1034                 sum = -(*p.second);
1035             }
1036         } else {
1037             result =
1038                 NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(), PairToExpr(node->Opnd(0)->GetPrimType(), p));
1039         }
1040     }
1041     return std::make_pair(result, sum);
1042 }
1043 
FloatToIntOverflow(float fval,PrimType totype)1044 static bool FloatToIntOverflow(float fval, PrimType totype)
1045 {
1046     static const float safeFloatMaxToInt32 = 2147483520.0f;  // 2^31 - 128
1047     static const float safeFloatMinToInt32 = -2147483520.0f;
1048     static const float safeFloatMaxToInt64 = 9223372036854775680.0f;  // 2^63 - 128
1049     static const float safeFloatMinToInt64 = -9223372036854775680.0f;
1050     if (!std::isfinite(fval)) {
1051         return true;
1052     }
1053     if (totype == PTY_i64 || totype == PTY_u64) {
1054         if (fval < safeFloatMinToInt64 || fval > safeFloatMaxToInt64) {
1055             return true;
1056         }
1057     } else {
1058         if (fval < safeFloatMinToInt32 || fval > safeFloatMaxToInt32) {
1059             return true;
1060         }
1061     }
1062     return false;
1063 }
1064 
DoubleToIntOverflow(double dval,PrimType totype)1065 static bool DoubleToIntOverflow(double dval, PrimType totype)
1066 {
1067     static const double safeDoubleMaxToInt32 = 2147482624.0;  // 2^31 - 1024
1068     static const double safeDoubleMinToInt32 = -2147482624.0;
1069     static const double safeDoubleMaxToInt64 = 9223372036854774784.0;  // 2^63 - 1024
1070     static const double safeDoubleMinToInt64 = -9223372036854774784.0;
1071     if (!std::isfinite(dval)) {
1072         return true;
1073     }
1074     if (totype == PTY_i64 || totype == PTY_u64) {
1075         if (dval < safeDoubleMinToInt64 || dval > safeDoubleMaxToInt64) {
1076             return true;
1077         }
1078     } else {
1079         if (dval < safeDoubleMinToInt32 || dval > safeDoubleMaxToInt32) {
1080             return true;
1081         }
1082     }
1083     return false;
1084 }
1085 
// Folds `ceil` of a floating-point constant into an integer constant of type `toType`.
// `fromType` == PTY_f32 reads the operand as a float constant; any other `fromType` reads it as a
// double constant. Returns nullptr when the rounded value fails the overflow check.
ConstvalNode *ConstantFold::FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    ConstvalNode *foldedNode = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    foldedNode->SetPrimType(toType);
    MIRType &intType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const MIRFloatConst *floatCst = safe_cast<MIRFloatConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(floatCst);
        float rounded = ceil(floatCst->GetValue());
        if (FloatToIntOverflow(rounded, toType)) {
            return nullptr;
        }
        foldedNode->SetConstVal(
            GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(rounded), intType));
        return foldedNode;
    }

    const MIRDoubleConst *doubleCst = safe_cast<MIRDoubleConst>(cst.GetConstVal());
    ASSERT_NOT_NULL(doubleCst);
    double rounded = ceil(doubleCst->GetValue());
    if (DoubleToIntOverflow(rounded, toType)) {
        return nullptr;
    }
    foldedNode->SetConstVal(
        GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(rounded), intType));
    return foldedNode;
}
1112 
1113 template <class T>
CalIntValueFromFloatValue(T value,const MIRType & resultType) const1114 T ConstantFold::CalIntValueFromFloatValue(T value, const MIRType &resultType) const
1115 {
1116     DEBUG_ASSERT(kByteSizeOfBit64 >= resultType.GetSize(), "unsupported type");
1117     size_t shiftNum = (kByteSizeOfBit64 - resultType.GetSize()) * kBitSizePerByte;
1118     bool isSigned = IsSignedInteger(resultType.GetPrimType());
1119     int64 max = LLONG_MAX >> shiftNum;
1120     uint64 umax = ULLONG_MAX >> shiftNum;
1121     int64 min = isSigned ? (LLONG_MIN >> shiftNum) : 0;
1122     if (isSigned && (value > max)) {
1123         return static_cast<T>(max);
1124     } else if (!isSigned && (value > umax)) {
1125         return static_cast<T>(umax);
1126     } else if (value < min) {
1127         return static_cast<T>(min);
1128     }
1129     return value;
1130 }
1131 
// Converts a floating-point constant to an integer constant of type `toType`.
// When `isFloor` is set the value is floored first; otherwise it is used unchanged. The value is then
// checked for overflow (nullptr on failure), clamped by CalIntValueFromFloatValue and cast to int64.
// `fromType` == PTY_f32 treats `cst` as a float constant; any other `fromType` treats it as a double.
MIRConst *ConstantFold::FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor) const
{
    MIRType &intType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const auto &floatCst = static_cast<const MIRFloatConst &>(cst);
        float converted = floatCst.GetValue();
        if (isFloor) {
            converted = floor(floatCst.GetValue());
        }
        if (FloatToIntOverflow(converted, toType)) {
            return nullptr;
        }
        converted = CalIntValueFromFloatValue(converted, intType);
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(converted), intType);
    }

    const auto &doubleCst = static_cast<const MIRDoubleConst &>(cst);
    double converted = doubleCst.GetValue();
    if (isFloor) {
        converted = floor(doubleCst.GetValue());
    }
    if (DoubleToIntOverflow(converted, toType)) {
        return nullptr;
    }
    converted = CalIntValueFromFloatValue(converted, intType);
    return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(converted), intType);
}
1159 
// Folds `floor` of a floating-point constant node into an integer constant node of type `toType`.
// Returns nullptr when the value cannot be folded: FoldFloorMIRConst yields nullptr when the floored
// value fails the overflow check. Previously that null constant was wrapped into a ConstvalNode, giving
// callers a node without a value; FoldCeil and FoldTrunk already report this case with nullptr.
ConstvalNode *ConstantFold::FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *foldedValue = FoldFloorMIRConst(*cst.GetConstVal(), fromType, toType);
    if (foldedValue == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(foldedValue);
    return resultConst;
}
1167 
// Folds a `round` conversion of a constant.
//   - f32 / f64 source: the value is rounded and, if it passes the overflow check, stored as an
//     integer constant of type `toType`.
//   - integer source with f32 / f64 destination: the value is converted to floating point and only
//     folded when converting it back yields the original integer (i.e. the conversion is exact).
// Returns nullptr whenever the constant cannot be folded.
MIRConst *ConstantFold::FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        float rounded = round(static_cast<const MIRFloatConst &>(cst).GetValue());
        if (FloatToIntOverflow(rounded, toType)) {
            return nullptr;
        }
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(rounded), resultType);
    }
    if (fromType == PTY_f64) {
        double rounded = round(static_cast<const MIRDoubleConst &>(cst).GetValue());
        if (DoubleToIntOverflow(rounded, toType)) {
            return nullptr;
        }
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(rounded), resultType);
    }

    // Remaining supported case: integer -> f32 / f64.
    if (toType != PTY_f32 && toType != PTY_f64) {
        return nullptr;
    }
    if (!IsPrimitiveInteger(fromType)) {
        return nullptr;
    }
    const auto &intCst = static_cast<const MIRIntConst &>(cst);
    const bool fromSigned = IsSignedInteger(fromType);
    // NOTE(review): the round-trip casts below are out of range when the converted value equals
    // 2^63 (signed) or 2^64 (unsigned), e.g. for the largest 64-bit inputs -- confirm whether a guard
    // is needed.
    if (toType == PTY_f32) {
        if (fromSigned) {
            int64 source = intCst.GetExtValue();
            float asFloat = round(static_cast<float>(source));
            if (static_cast<int64>(asFloat) == source) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(asFloat);
            }
        } else {
            uint64 source = intCst.GetExtValue();
            float asFloat = round(static_cast<float>(source));
            if (static_cast<uint64>(asFloat) == source) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(asFloat);
            }
        }
        return nullptr;
    }

    if (fromSigned) {
        int64 source = intCst.GetExtValue();
        double asDouble = round(static_cast<double>(source));
        if (static_cast<int64>(asDouble) == source) {
            return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(asDouble);
        }
    } else {
        uint64 source = intCst.GetExtValue();
        double asDouble = round(static_cast<double>(source));
        if (static_cast<uint64>(asDouble) == source) {
            return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(asDouble);
        }
    }
    return nullptr;
}
1218 
// Folds a `round` conversion of a constant node into a constant node of type `toType`.
// Returns nullptr when the value cannot be folded: FoldRoundMIRConst yields nullptr on overflow, on an
// inexact integer -> floating-point conversion and for unsupported type combinations. Previously that
// null constant was wrapped into a ConstvalNode, giving callers a node without a value; FoldCeil and
// FoldTrunk already report failure with nullptr.
ConstvalNode *ConstantFold::FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *foldedValue = FoldRoundMIRConst(*cst.GetConstVal(), fromType, toType);
    if (foldedValue == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(foldedValue);
    return resultConst;
}
1226 
// Folds `trunc` of a floating-point constant into an integer constant of type `toType`.
// `fromType` == PTY_f32 reads the operand as a float constant; any other `fromType` reads it as a
// double constant. Returns nullptr when the truncated value fails the overflow check.
ConstvalNode *ConstantFold::FoldTrunk(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    ConstvalNode *foldedNode = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    foldedNode->SetPrimType(toType);
    MIRType &intType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const MIRFloatConst *floatCst = safe_cast<MIRFloatConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(floatCst);
        float truncated = trunc(floatCst->GetValue());
        if (FloatToIntOverflow(truncated, toType)) {
            return nullptr;
        }
        foldedNode->SetConstVal(
            GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(truncated), intType));
        return foldedNode;
    }

    const MIRDoubleConst *doubleCst = safe_cast<MIRDoubleConst>(cst.GetConstVal());
    CHECK_NULL_FATAL(doubleCst);
    double truncated = trunc(doubleCst->GetValue());
    if (DoubleToIntOverflow(truncated, toType)) {
        return nullptr;
    }
    foldedNode->SetConstVal(
        GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(truncated), intType));
    return foldedNode;
}
1253 
FoldTypeCvtMIRConst(const MIRConst & cst,PrimType fromType,PrimType toType) const1254 MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
1255 {
1256     if (IsPrimitiveDynType(fromType) || IsPrimitiveDynType(toType)) {
1257         // do not fold
1258         return nullptr;
1259     }
1260     if (IsPrimitiveInteger(fromType) && IsPrimitiveInteger(toType)) {
1261         MIRConst *toConst = nullptr;
1262         uint32 fromSize = GetPrimTypeBitSize(fromType);
1263         uint32 toSize = GetPrimTypeBitSize(toType);
1264         // GetPrimTypeBitSize(PTY_u1) will return 8, which is not expected here.
1265         if (fromType == PTY_u1) {
1266             fromSize = 1;
1267         }
1268         if (toType == PTY_u1) {
1269             toSize = 1;
1270         }
1271         if (toSize > fromSize) {
1272             Opcode op = OP_zext;
1273             if (IsSignedInteger(fromType)) {
1274                 op = OP_sext;
1275             }
1276             const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
1277             ASSERT_NOT_NULL(constVal);
1278             toConst = FoldSignExtendMIRConst(op, toType, fromSize, constVal->GetValue().TruncOrExtend(fromType));
1279         } else {
1280             const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
1281             ASSERT_NOT_NULL(constVal);
1282             MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(toType);
1283             toConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(constVal->GetExtValue(), type);
1284         }
1285         return toConst;
1286     }
1287     if (IsPrimitiveFloat(fromType) && IsPrimitiveFloat(toType)) {
1288         MIRConst *toConst = nullptr;
1289         if (GetPrimTypeBitSize(toType) < GetPrimTypeBitSize(fromType)) {
1290             DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 32, "We suppot F32 and F64"); // just support 32 or 64
1291             const MIRDoubleConst *fromValue = safe_cast<MIRDoubleConst>(cst);
1292             ASSERT_NOT_NULL(fromValue);
1293             float floatValue = static_cast<float>(fromValue->GetValue());
1294             MIRFloatConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
1295             toConst = toValue;
1296         } else {
1297             DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 64, "We suppot F32 and F64"); // just support 32 or 64
1298             const MIRFloatConst *fromValue = safe_cast<MIRFloatConst>(cst);
1299             ASSERT_NOT_NULL(fromValue);
1300             double doubleValue = static_cast<double>(fromValue->GetValue());
1301             MIRDoubleConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
1302             toConst = toValue;
1303         }
1304         return toConst;
1305     }
1306     if (IsPrimitiveFloat(fromType) && IsPrimitiveInteger(toType)) {
1307         return FoldFloorMIRConst(cst, fromType, toType, false);
1308     }
1309     if (IsPrimitiveInteger(fromType) && IsPrimitiveFloat(toType)) {
1310         return FoldRoundMIRConst(cst, fromType, toType);
1311     }
1312     CHECK_FATAL(false, "Unexpected case in ConstFoldTypeCvt");
1313     return nullptr;
1314 }
1315 
// Folds a `cvt` of a constant node. The conversion is delegated to FoldTypeCvtMIRConst; when that
// yields no constant, nullptr is returned. Otherwise the constant is wrapped into a new ConstvalNode
// whose primitive type is taken from the folded constant's own type.
ConstvalNode *ConstantFold::FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *converted = FoldTypeCvtMIRConst(*cst.GetConstVal(), fromType, toType);
    if (converted != nullptr) {
        ConstvalNode *convertedNode = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
        convertedNode->SetPrimType(converted->GetType().GetPrimType());
        convertedNode->SetConstVal(converted);
        return convertedNode;
    }
    return nullptr;
}
1327 
1328 // return a primType with bit size >= bitSize (and the nearest one),
1329 // and its signed/float type is the same as ptyp
GetNearestSizePtyp(uint8 bitSize,PrimType ptyp)1330 PrimType GetNearestSizePtyp(uint8 bitSize, PrimType ptyp)
1331 {
1332     bool isSigned = IsSignedInteger(ptyp);
1333     bool isFloat = IsPrimitiveFloat(ptyp);
1334     if (bitSize == 1) { // 1 bit
1335         return PTY_u1;
1336     }
1337     if (bitSize <= 8) { // 8 bit
1338         return isSigned ? PTY_i8 : PTY_u8;
1339     }
1340     if (bitSize <= 16) { // 16 bit
1341         return isSigned ? PTY_i16 : PTY_u16;
1342     }
1343     if (bitSize <= 32) { // 32 bit
1344         return isFloat ? PTY_f32 : (isSigned ? PTY_i32 : PTY_u32);
1345     }
1346     if (bitSize <= 64) { // 64 bit
1347         return isFloat ? PTY_f64 : (isSigned ? PTY_i64 : PTY_u64);
1348     }
1349     if (bitSize <= 128) { // 128 bit
1350         return isFloat ? PTY_f128 : (isSigned ? PTY_i128 : PTY_u128);
1351     }
1352     return ptyp;
1353 }
1354 
GetIntPrimTypeMax(PrimType ptyp)1355 size_t GetIntPrimTypeMax(PrimType ptyp)
1356 {
1357     switch (ptyp) {
1358         case PTY_u1:
1359             return 1;
1360         case PTY_u8:
1361             return UINT8_MAX;
1362         case PTY_i8:
1363             return INT8_MAX;
1364         case PTY_u16:
1365             return UINT16_MAX;
1366         case PTY_i16:
1367             return INT16_MAX;
1368         case PTY_u32:
1369             return UINT32_MAX;
1370         case PTY_i32:
1371             return INT32_MAX;
1372         case PTY_u64:
1373             return UINT64_MAX;
1374         case PTY_i64:
1375             return INT64_MAX;
1376         default:
1377             CHECK_FATAL(false, "NYI");
1378     }
1379 }
1380 
GetIntPrimTypeMin(PrimType ptyp)1381 ssize_t GetIntPrimTypeMin(PrimType ptyp)
1382 {
1383     if (IsUnsignedInteger(ptyp)) {
1384         return 0;
1385     }
1386     switch (ptyp) {
1387         case PTY_i8:
1388             return INT8_MIN;
1389         case PTY_i16:
1390             return INT16_MIN;
1391         case PTY_i32:
1392             return INT32_MIN;
1393         case PTY_i64:
1394             return INT64_MIN;
1395         default:
1396             CHECK_FATAL(false, "NYI");
1397     }
1398 }
1399 
1400 // return a primtype to represent value range of expr
GetExprValueRangePtyp(BaseNode * expr)1401 PrimType GetExprValueRangePtyp(BaseNode *expr)
1402 {
1403     PrimType ptyp = expr->GetPrimType();
1404     Opcode op = expr->GetOpCode();
1405     if (expr->IsLeaf()) {
1406         return ptyp;
1407     }
1408     if (kOpcodeInfo.IsTypeCvt(op)) {
1409         auto *node = static_cast<TypeCvtNode *>(expr);
1410         if (GetPrimTypeSize(node->FromType()) < GetPrimTypeSize(node->GetPrimType())) {
1411             return GetExprValueRangePtyp(expr->Opnd(0));
1412         }
1413         return ptyp;
1414     }
1415     if (op == OP_sext || op == OP_zext || op == OP_extractbits) {
1416         auto *node = static_cast<ExtractbitsNode *>(expr);
1417         uint8 size = node->GetBitsSize();
1418         return GetNearestSizePtyp(size, expr->GetPrimType());
1419     }
1420     // find max size primtype of opnds.
1421     size_t maxTypeSize = 1;
1422     size_t ptypSize = GetPrimTypeSize(ptyp);
1423     for (size_t i = 0; i < expr->GetNumOpnds(); ++i) {
1424         PrimType opndPtyp = GetExprValueRangePtyp(expr->Opnd(i));
1425         size_t opndSize = GetPrimTypeSize(opndPtyp);
1426         if (ptypSize <= opndSize) {
1427             return ptyp;
1428         }
1429         if (maxTypeSize < opndSize) {
1430             maxTypeSize = opndSize;
1431             constexpr size_t intMaxSize = 8;
1432             if (maxTypeSize == intMaxSize) {
1433                 break;
1434             }
1435         }
1436     }
1437     return GetNearestSizePtyp(maxTypeSize, ptyp);
1438 }
1439 
FoldTypeCvt(TypeCvtNode * node)1440 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldTypeCvt(TypeCvtNode *node)
1441 {
1442     CHECK_NULL_FATAL(node);
1443     BaseNode *result = nullptr;
1444     std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1445     ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1446     PrimType destPtyp = node->GetPrimType();
1447     PrimType fromPtyp = node->FromType();
1448     if (cst != nullptr) {
1449         switch (node->GetOpCode()) {
1450             case OP_ceil: {
1451                 result = FoldCeil(*cst, fromPtyp, destPtyp);
1452                 break;
1453             }
1454             case OP_cvt: {
1455                 result = FoldTypeCvt(*cst, fromPtyp, destPtyp);
1456                 break;
1457             }
1458             case OP_floor: {
1459                 result = FoldFloor(*cst, fromPtyp, destPtyp);
1460                 break;
1461             }
1462             case OP_round: {
1463                 result = FoldRound(*cst, fromPtyp, destPtyp);
1464                 break;
1465             }
1466             case OP_trunc: {
1467                 result = FoldTrunk(*cst, fromPtyp, destPtyp);
1468                 break;
1469             }
1470             default:
1471                 DEBUG_ASSERT(false, "Unexpected opcode in TypeCvtNodeConstFold");
1472                 break;
1473         }
1474     } else if (node->GetOpCode() == OP_cvt && GetPrimTypeSize(fromPtyp) == GetPrimTypeSize(destPtyp)) {
1475         if ((IsPossible64BitAddress(fromPtyp) && IsPossible64BitAddress(destPtyp)) ||
1476             (IsPossible32BitAddress(fromPtyp) && IsPossible32BitAddress(destPtyp)) ||
1477             (IsPrimitivePureScalar(fromPtyp) && IsPrimitivePureScalar(destPtyp) &&
1478              GetPrimTypeSize(fromPtyp) == GetPrimTypeSize(destPtyp))) {
1479             // the cvt is redundant
1480             return std::make_pair(p.first, p.second ? IntVal(*p.second, node->GetPrimType()) : p.second);
1481         }
1482     } else if (node->GetOpCode() == OP_cvt && p.second.has_value() && *p.second != 0 &&
1483                p.second->GetExtValue() < INT_MAX && p.second->GetSXTValue() > -kMaxOffset &&
1484                IsPrimitiveInteger(node->GetPrimType()) && IsSignedInteger(node->FromType()) &&
1485                GetPrimTypeSize(node->GetPrimType()) > GetPrimTypeSize(node->FromType())) {
1486         bool simplifyCvt = false;
1487         // Integer values must not be allowd to wrap, for pointer arithmetic, including array indexing
1488         // Therefore, if dest-type of cvt expr is pointer type, like : cvt ptr/ref/a64 (op (expr, constval))
1489         // we assume that result of cvt opnd will never wrap
1490         if (IsPossible64BitAddress(node->GetPrimType()) && node->GetPrimType() != PTY_u64) {
1491             simplifyCvt = true;
1492         } else {
1493             // Ensure : op(expr, constval) is not overflow
1494             // Example :
1495             // given a expr like : r = (i + constVal), and its mplIR may be:
1496             // cvt i64 i32 (add i32 (cvt i32 i16 (dread i16 i), constval i32 constVal))
1497             // expr primtype(i32) range : |I32_MIN--------------------0--------------------I32_MAX|
1498             // expr value range(i16)    :               |I16_MIN------0------I16_MAX|
1499             // value range difference   : |-------------|                           |-------------|
1500             // constVal value range     : [I32_MIN - I16MIN, I32_MAX - I16_MAX], (i + constVal) will never pos/neg
1501             // overflow.
1502             PrimType valuePtyp = GetExprValueRangePtyp(p.first);
1503             PrimType exprPtyp = p.first->GetPrimType();
1504             int64 constVal = p.second->GetExtValue();
1505             // cannot only check (min <= constval && constval <= max) here.
1506             // If constval = -1, it will be SIZE_T_MAX when converted to size_t
1507             if ((constVal < 0 && constVal >= GetIntPrimTypeMin(exprPtyp) - GetIntPrimTypeMin(valuePtyp)) ||
1508                 (constVal > 0 &&
1509                  static_cast<size_t>(constVal) <= GetIntPrimTypeMax(exprPtyp) - GetIntPrimTypeMax(valuePtyp))) {
1510                 simplifyCvt = true;
1511             }
1512         }
1513         if (simplifyCvt) {
1514             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, node->GetPrimType(), node->FromType(),
1515                                                                        p.first);
1516             return std::make_pair(result, p.second->TruncOrExtend(node->GetPrimType()));
1517         }
1518     }
1519     if (result == nullptr) {
1520         BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1521         if (e != node->Opnd(0)) {
1522             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(
1523                 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->FromType()), e);
1524         } else {
1525             result = node;
1526         }
1527     }
1528     return std::make_pair(result, std::nullopt);
1529 }
1530 
FoldSignExtendMIRConst(Opcode opcode,PrimType resultType,uint8 size,const IntVal & val) const1531 MIRConst *ConstantFold::FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const
1532 {
1533     uint64 result = opcode == OP_sext ? val.GetSXTValue(size) : val.GetZXTValue(size);
1534     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
1535     MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
1536     return constValue;
1537 }
1538 
FoldSignExtend(Opcode opcode,PrimType resultType,uint8 size,const ConstvalNode & cst) const1539 ConstvalNode *ConstantFold::FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size,
1540                                            const ConstvalNode &cst) const
1541 {
1542     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1543     const auto *intCst = safe_cast<MIRIntConst>(cst.GetConstVal());
1544     ASSERT_NOT_NULL(intCst);
1545     IntVal val = intCst->GetValue().TruncOrExtend(size, opcode == OP_sext);
1546     MIRConst *toConst = FoldSignExtendMIRConst(opcode, resultType, size, val);
1547     resultConst->SetPrimType(toConst->GetType().GetPrimType());
1548     resultConst->SetConstVal(toConst);
1549     return resultConst;
1550 }
1551 
1552 // check if truncation is redundant due to dread or iread having same effect
ExtractbitsRedundant(ExtractbitsNode * x,MIRFunction * f)1553 static bool ExtractbitsRedundant(ExtractbitsNode *x, MIRFunction *f)
1554 {
1555     if (GetPrimTypeSize(x->GetPrimType()) == k8ByteSize) {
1556         return false;  // this is trying to be conservative
1557     }
1558     BaseNode *opnd = x->Opnd(0);
1559     MIRType *mirType = nullptr;
1560     if (opnd->GetOpCode() == OP_dread) {
1561         DreadNode *dread = static_cast<DreadNode *>(opnd);
1562         MIRSymbol *sym = f->GetLocalOrGlobalSymbol(dread->GetStIdx());
1563         mirType = sym->GetType();
1564         if (dread->GetFieldID() != 0) {
1565             MIRStructType *structType = dynamic_cast<MIRStructType *>(mirType);
1566             if (structType == nullptr) {
1567                 return false;
1568             }
1569             mirType = structType->GetFieldType(dread->GetFieldID());
1570         }
1571     } else if (opnd->GetOpCode() == OP_iread) {
1572         IreadNode *iread = static_cast<IreadNode *>(opnd);
1573         MIRPtrType *ptrType =
1574             dynamic_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx()));
1575         if (ptrType == nullptr) {
1576             return false;
1577         }
1578         mirType = ptrType->GetPointedType();
1579         if (iread->GetFieldID() != 0) {
1580             MIRStructType *structType = dynamic_cast<MIRStructType *>(mirType);
1581             if (structType == nullptr) {
1582                 return false;
1583             }
1584             mirType = structType->GetFieldType(iread->GetFieldID());
1585         }
1586     } else {
1587         return false;
1588     }
1589     return IsPrimitiveInteger(mirType->GetPrimType()) && mirType->GetSize() * kBitSizePerByte == x->GetBitsSize() &&
1590            IsSignedInteger(mirType->GetPrimType()) == IsSignedInteger(x->GetPrimType());
1591 }
1592 
1593 // sext and zext also handled automatically
FoldExtractbits(ExtractbitsNode * node)1594 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldExtractbits(ExtractbitsNode *node)
1595 {
1596     CHECK_NULL_FATAL(node);
1597     BaseNode *result = nullptr;
1598     uint8 offset = node->GetBitsOffset();
1599     uint8 size = node->GetBitsSize();
1600     Opcode opcode = node->GetOpCode();
1601     std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1602     ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1603     if (cst != nullptr && (opcode == OP_sext || opcode == OP_zext)) {
1604         result = FoldSignExtend(opcode, node->GetPrimType(), size, *cst);
1605         return std::make_pair(result, std::nullopt);
1606     }
1607     BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1608     if (e != node->Opnd(0)) {
1609         result = mirModule->CurFuncCodeMemPool()->New<ExtractbitsNode>(opcode, PrimType(node->GetPrimType()), offset,
1610                                                                        size, e);
1611     } else {
1612         result = node;
1613     }
1614     // check for consecutive and redundant extraction of same bits
1615     BaseNode *opnd = result->Opnd(0);
1616     Opcode opndOp = opnd->GetOpCode();
1617     if (opndOp == OP_extractbits || opndOp == OP_sext || opndOp == OP_zext) {
1618         uint8 opndOffset = static_cast<ExtractbitsNode *>(opnd)->GetBitsOffset();
1619         uint8 opndSize = static_cast<ExtractbitsNode *>(opnd)->GetBitsSize();
1620         if (offset == opndOffset && size == opndSize) {
1621             result->SetOpnd(opnd->Opnd(0), 0);  // delete the redundant extraction
1622         }
1623     }
1624     if (offset == 0 && size >= k8ByteSize && IsPowerOf2(size)) {
1625         if (ExtractbitsRedundant(static_cast<ExtractbitsNode *>(result), mirModule->CurFunction())) {
1626             return std::make_pair(result->Opnd(0), std::nullopt);
1627         }
1628     }
1629     return std::make_pair(result, std::nullopt);
1630 }
1631 
FoldIread(IreadNode * node)1632 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldIread(IreadNode *node)
1633 {
1634     CHECK_NULL_FATAL(node);
1635     std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1636     BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1637     node->SetOpnd(e, 0);
1638     BaseNode *result = node;
1639     if (e->GetOpCode() != OP_addrof) {
1640         return std::make_pair(result, std::nullopt);
1641     }
1642 
1643     AddrofNode *addrofNode = static_cast<AddrofNode *>(e);
1644     MIRSymbol *msy = mirModule->CurFunction()->GetLocalOrGlobalSymbol(addrofNode->GetStIdx());
1645     TyIdx typeId = msy->GetTyIdx();
1646     CHECK_FATAL(GlobalTables::GetTypeTable().GetTypeTable().empty() == false, "container check");
1647     MIRType *msyType = GlobalTables::GetTypeTable().GetTypeTable()[typeId];
1648     if (addrofNode->GetFieldID() != 0 && (msyType->GetKind() == kTypeStruct || msyType->GetKind() == kTypeClass)) {
1649         msyType = static_cast<MIRStructType *>(msyType)->GetFieldType(addrofNode->GetFieldID());
1650     }
1651     MIRPtrType *ptrType = static_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx()));
1652     // If the high level type of iaddrof/iread doesn't match
1653     // the type of addrof's rhs, this optimization cannot be done.
1654     if (ptrType->GetPointedType() != msyType) {
1655         return std::make_pair(result, std::nullopt);
1656     }
1657 
1658     Opcode op = node->GetOpCode();
1659     FieldID fieldID = node->GetFieldID();
1660     if (op == OP_iaddrof) {
1661         AddrofNode *newAddrof = addrofNode->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
1662         CHECK_NULL_FATAL(newAddrof);
1663         newAddrof->SetFieldID(newAddrof->GetFieldID() + fieldID);
1664         result = newAddrof;
1665     } else if (op == OP_iread) {
1666         result = mirModule->CurFuncCodeMemPool()->New<AddrofNode>(OP_dread, node->GetPrimType(), addrofNode->GetStIdx(),
1667                                                                   node->GetFieldID() + addrofNode->GetFieldID());
1668     }
1669     return std::make_pair(result, std::nullopt);
1670 }
1671 
IntegerOpIsOverflow(Opcode op,PrimType primType,int64 cstA,int64 cstB)1672 bool ConstantFold::IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB)
1673 {
1674     switch (op) {
1675         case OP_add: {
1676             int64 res = static_cast<int64>(static_cast<uint64>(cstA) + static_cast<uint64>(cstB));
1677             if (IsUnsignedInteger(primType)) {
1678                 return static_cast<uint64>(res) < static_cast<uint64>(cstA);
1679             }
1680             auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1681             return (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1682                     static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag) &&
1683                    (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1684                     static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag);
1685         }
1686         case OP_sub: {
1687             if (IsUnsignedInteger(primType)) {
1688                 return cstA < cstB;
1689             }
1690             int64 res = static_cast<int64>(static_cast<uint64>(cstA) - static_cast<uint64>(cstB));
1691             auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1692             return (static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag !=
1693                     static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag) &&
1694                    (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1695                     static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag);
1696         }
1697         default: {
1698             return false;
1699         }
1700     }
1701 }
1702 
// Fold a binary expression.
// Both operands are folded first through DispatchFold; each yields a pair (expression, optional integer
// constant). The function returns a pair (result, sum) of the same shape: `result` is the rewritten
// expression and `sum`, when present, is an integer constant that is carried next to it instead of being
// embedded in the tree (callers in this file merge such a pair with PairToExpr).
// Four situations are distinguished:
//   1. both operands are constants            -> evaluate, except integer div/rem rejected by IsDivSafe;
//   2. only the left operand is a constant    -> algebraic identities with the constant on the left;
//   3. only the right operand is a constant   -> algebraic identities with the constant on the right;
//   4. no constant operand                    -> integer add/sub combine the operands' carried constants,
//                                                everything else is rebuilt from the materialized operands.
// Cases 2 and 3 only apply when the result type is an integer type, and wrap the result in a cvt when
// IsNoCvtNeeded reports a mismatch between the result's type and the node's type.
FoldBinary(BinaryNode * node)1703 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBinary(BinaryNode *node)
1704 {
1705     CHECK_NULL_FATAL(node);
1706     BaseNode *result = nullptr;
1707     std::optional<IntVal> sum = std::nullopt;
1708     Opcode op = node->GetOpCode();
1709     PrimType primType = node->GetPrimType();
         // Operand types are captured before folding; they are used later to materialize operand pairs.
1710     PrimType lPrimTypes = node->Opnd(0)->GetPrimType();
1711     PrimType rPrimTypes = node->Opnd(1)->GetPrimType();
1712     std::pair<BaseNode *, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1713     std::pair<BaseNode *, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1714     BaseNode *l = lp.first;
1715     BaseNode *r = rp.first;
1716     ConstvalNode *lConst = safe_cast<ConstvalNode>(l);
1717     ConstvalNode *rConst = safe_cast<ConstvalNode>(r);
1718     bool isInt = IsPrimitiveInteger(primType);
1719 
         // Case 1: both operands folded to constants.
1720     if (lConst != nullptr && rConst != nullptr) {
1721         MIRConst *lConstVal = lConst->GetConstVal();
1722         MIRConst *rConstVal = rConst->GetConstVal();
1723         ASSERT_NOT_NULL(lConstVal);
1724         ASSERT_NOT_NULL(rConstVal);
1725         // Don't fold div by 0, for floats div by 0 is well defined.
1726         if ((op == OP_div || op == OP_rem) && isInt &&
1727             !IsDivSafe(static_cast<MIRIntConst &>(*lConstVal), static_cast<MIRIntConst &>(*rConstVal), primType)) {
1728             result = NewBinaryNode(node, op, primType, lConst, rConst);
1729         } else {
1730             // 4 + 2 -> return a pair(result = ConstValNode(6), sum = 0)
1731             // Create a new ConstvalNode for 6 but keep the sum = 0. This simplify the
1732             // logic since the alternative is to return pair(result = nullptr, sum = 6).
1733             // Doing so would introduce many nullptr checks in the code. See previous
1734             // commits that implemented that logic for a comparison.
1735             result = FoldConstBinary(op, primType, *lConst, *rConst);
1736         }
         // Case 2: only the left operand is a constant (integer result type).
1737     } else if (lConst != nullptr && isInt) {
1738         MIRIntConst *mcst = safe_cast<MIRIntConst>(lConst->GetConstVal());
1739         ASSERT_NOT_NULL(mcst);
1740         PrimType cstTyp = mcst->GetType().GetPrimType();
1741         IntVal cst = mcst->GetValue();
1742         if (op == OP_add) {
             // cst + (X + k) -> the pair [X, cst + k]
1743             sum = cst + rp.second;
1744             result = r;
1745         } else if (op == OP_sub && r->GetPrimType() != PTY_u1) {
1746             // We exclude u1 type for fixing the following wrong example:
1747             // before cf:
1748             //   sub i32 (constval i32 17, eq u1 i32 (dread i32 %i, constval i32 16)))
1749             // after cf:
1750             //   add i32 (cvt i32 u1 (neg u1 (eq u1 i32 (dread i32 %i, constval i32 16))), constval i32 17))
1751             sum = cst - rp.second;
1752             result = NegateTree(r);
1753         } else if ((op == OP_mul || op == OP_div || op == OP_rem || op == OP_ashr || op == OP_lshr || op == OP_shl ||
1754                     op == OP_band || op == OP_cand || op == OP_land) &&
1755                    cst == 0) {
1756             // 0 * X -> 0
1757             // 0 / X -> 0
1758             // 0 % X -> 0
1759             // 0 >> X -> 0
1760             // 0 << X -> 0
1761             // 0 & X -> 0
1762             // 0 && X -> 0
             // NOTE(review): the zero is created with the constant's type (cstTyp); the cvt at the end of
             // this branch adjusts it to primType when IsNoCvtNeeded says so.
1763             result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1764         } else if (op == OP_mul && cst == 1) {
1765             // 1 * X --> X
1766             sum = rp.second;
1767             result = r;
1768         } else if (op == OP_bior && cst == -1) {
1769             // (-1) | X -> -1
1770             result = mirModule->GetMIRBuilder()->CreateIntConst(-1, cstTyp);
1771         } else if (op == OP_mul && rp.second.has_value() && *rp.second != 0) {
1772             // lConst * (X + konst) -> the pair [(lConst*X), (lConst*konst)]
1773             sum = cst * rp.second;
1774             if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) {
                 // NOTE(review): the inserted cvt hard-codes PTY_i32 as its from-type instead of using
                 // rp.first->GetPrimType() — confirm this is intended for operands that are not i32.
1775                 rp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, rp.first);
1776             }
1777             result = NewBinaryNode(node, OP_mul, primType, lConst, rp.first);
1778         } else if (op == OP_lior || op == OP_cior) {
1779             if (cst != 0) {
1780                 // 5 || X -> 1
1781                 result = mirModule->GetMIRBuilder()->CreateIntConst(1, cstTyp);
1782             } else {
1783                 // when cst is zero
1784                 // 0 || X -> (X != 0);
1785                 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1786                     OP_ne, primType, r->GetPrimType(), r,
1787                     mirModule->GetMIRBuilder()->CreateIntConst(0, r->GetPrimType()));
1788             }
1789         } else if ((op == OP_cand || op == OP_land) && cst != 0) {
1790             // 5 && X -> (X != 0)
1791             result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1792                 OP_ne, primType, r->GetPrimType(), r, mirModule->GetMIRBuilder()->CreateIntConst(0, r->GetPrimType()));
1793         } else if ((op == OP_bior || op == OP_bxor) && cst == 0) {
1794             // 0 | X -> X
1795             // 0 ^ X -> X
1796             sum = rp.second;
1797             result = r;
1798         } else {
             // No identity applies: rebuild the node with the right operand pair materialized.
1799             result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1800         }
1801         if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1802             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1803         }
         // Case 3: only the right operand is a constant (integer result type).
1804     } else if (rConst != nullptr && isInt) {
1805         MIRIntConst *mcst = safe_cast<MIRIntConst>(rConst->GetConstVal());
1806         ASSERT_NOT_NULL(mcst);
1807         PrimType cstTyp = mcst->GetType().GetPrimType();
1808         IntVal cst = mcst->GetValue();
1809         if (op == OP_add) {
             // (X + k) + cst -> the pair [X, k + cst], unless k + cst is reported as overflowing, in which
             // case both operands are materialized and the addition is kept.
1810             if (lp.second && IntegerOpIsOverflow(op, cstTyp, lp.second->GetExtValue(), cst.GetExtValue())) {
1811                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1812             } else {
1813                 result = l;
1814                 sum = lp.second + cst;
1815             }
             // (X + k) - cst -> the pair [X, k - cst]; skipped when cst is a signed minimum value.
1816         } else if (op == OP_sub && (!cst.IsSigned() || !cst.IsMinValue())) {
1817             {
1818                 result = l;
1819                 sum = lp.second - cst;
1820             }
1821         } else if ((op == OP_mul || op == OP_band || op == OP_cand || op == OP_land) && cst == 0) {
1822             // X * 0 -> 0
1823             // X & 0 -> 0
1824             // X && 0 -> 0
1825             result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1826         } else if ((op == OP_mul || op == OP_div) && cst == 1) {
1827             // case [X * 1 -> X]
1828             // case [X / 1 = X]
1829             sum = lp.second;
1830             result = l;
1831         } else if (op == OP_mul && lp.second.has_value() && *lp.second != 0 && lp.second->GetSXTValue() > -kMaxOffset) {
1832             // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)]
1833             sum = lp.second * cst;
1834             if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) {
                 // NOTE(review): same hard-coded PTY_i32 from-type as in the left-constant branch — confirm.
1835                 lp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, lp.first);
1836             }
1837             result = NewBinaryNode(node, OP_mul, primType, lp.first, rConst);
1838         } else if (op == OP_band && cst == -1) {
1839             // X & (-1) -> X
1840             sum = lp.second;
1841             result = l;
1842         } else if (op == OP_band && ContiguousBitsOf1(cst.GetZXTValue()) &&
1843                    (!lp.second.has_value() || lp.second == 0)) {
             // (X >> shrAmt) & mask -> extractbits(shrAmt, bsize, X) when the selected field fits in primType.
1844             bool fold2extractbits = false;
1845             if (l->GetOpCode() == OP_ashr || l->GetOpCode() == OP_lshr) {
1846                 BinaryNode *shrNode = static_cast<BinaryNode *>(l);
1847                 if (shrNode->Opnd(1)->GetOpCode() == OP_constval) {
1848                     ConstvalNode *shrOpnd = static_cast<ConstvalNode *>(shrNode->Opnd(1));
1849                     int64 shrAmt = static_cast<MIRIntConst *>(shrOpnd->GetConstVal())->GetExtValue();
1850                     uint64 ucst = cst.GetZXTValue();
1851                     uint64 bsize = 0;
                     // bsize becomes the position of the highest set bit of the mask, plus one.
1852                     do {
1853                         bsize++;
1854                         ucst >>= 1;
1855                     } while (ucst != 0);
1856                     if (shrAmt + bsize <= GetPrimTypeSize(primType) * kBitSizePerByte) {
1857                         fold2extractbits = true;
1858                         // change to use extractbits
1859                         result = mirModule->GetMIRBuilder()->CreateExprExtractbits(
1860                             OP_extractbits, GetUnsignedPrimType(primType), shrAmt, bsize, shrNode->Opnd(0));
1861                         sum = std::nullopt;
1862                     }
1863                 }
1864             }
1865             if (!fold2extractbits) {
1866                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1867                 sum = std::nullopt;
1868             }
1869         } else if (op == OP_bior && cst == -1) {
1870             // X | (-1) -> -1
1871             result = mirModule->GetMIRBuilder()->CreateIntConst(-1, cstTyp);
1872         } else if ((op == OP_lior || op == OP_cior)) {
1873             if (cst == 0) {
1874                 // X || 0 -> X
1875                 sum = lp.second;
1876                 result = l;
1877             } else if (!cst.GetSignBit()) {
1878                 // X || 5 -> 1
                 // NOTE(review): a non-zero constant with its sign bit set is not folded here and takes
                 // the generic path below — confirm that this asymmetry is intended.
1879                 result = mirModule->GetMIRBuilder()->CreateIntConst(1, cstTyp);
1880             } else {
1881                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1882             }
1883         } else if ((op == OP_ashr || op == OP_lshr || op == OP_shl || op == OP_bior || op == OP_bxor) && cst == 0) {
1884             // X >> 0 -> X
1885             // X << 0 -> X
1886             // X | 0 -> X
1887             // X ^ 0 -> X
1888             sum = lp.second;
1889             result = l;
1890         } else if (op == OP_bxor && cst == 1 && primType != PTY_u1) {
1891             // bxor i32 (
1892             //   cvt i32 u1 (regread u1 %13),
1893             //  constValue i32 1),
             // When the left operand is a cvt from u1, the xor is performed in u1 and the cvt is re-applied
             // on top; otherwise the node is only rebuilt from the materialized operands.
1894             result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1895             if (l->GetOpCode() == OP_cvt) {
1896                 TypeCvtNode *cvtNode = static_cast<TypeCvtNode *>(l);
1897                 if (cvtNode->Opnd(0)->GetPrimType() == PTY_u1) {
1898                     BaseNode *base = cvtNode->Opnd(0);
1899                     BaseNode *constValue = mirModule->GetMIRBuilder()->CreateIntConst(1, base->GetPrimType());
1900                     std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(base);
1901                     BinaryNode *temp = NewBinaryNode(node, op, PTY_u1, PairToExpr(base->GetPrimType(), p), constValue);
1902                     result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_u1, temp);
1903                 }
1904             }
1905         } else if (op == OP_rem && cst == 1) {
1906             // X % 1 -> 0
1907             result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1908         } else {
1909             result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1910         }
1911         if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1912             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1913         }
         // Case 4a: integer add/sub without a constant operand — combine the carried constants.
1914     } else if (isInt && (op == OP_add || op == OP_sub)) {
1915         if (op == OP_add) {
1916             result = NewBinaryNode(node, op, primType, l, r);
1917             sum = lp.second + rp.second;
1918         } else if (node->Opnd(1)->GetOpCode() == OP_sub && r->GetOpCode() == OP_neg) {
1919             // if fold is (x - (y - z))    ->     (x - neg(z)) - y
1920             // (x - neg(z)) Could cross the int limit
1921             // return node
1922             result = node;
1923         } else {
1924             result = NewBinaryNode(node, op, primType, l, r);
1925             sum = lp.second - rp.second;
1926         }
         // Case 4b: anything else — materialize both operand pairs and rebuild the node.
1927     } else {
1928         result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1929     }
1930     return std::make_pair(result, sum);
1931 }
1932 
SimplifyDoubleCompare(CompareNode & compareNode) const1933 BaseNode *ConstantFold::SimplifyDoubleCompare(CompareNode &compareNode) const
1934 {
1935     // See arm manual B.cond(P2993) and FCMP(P1091)
1936     CompareNode *node = &compareNode;
1937     BaseNode *result = node;
1938     BaseNode *l = node->Opnd(0);
1939     BaseNode *r = node->Opnd(1);
1940     if (node->GetOpCode() == OP_ne || node->GetOpCode() == OP_eq) {
1941         if (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) {
1942             ConstvalNode *constNode = static_cast<ConstvalNode *>(r);
1943             if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1944                 const CompareNode *compNode = static_cast<CompareNode *>(l);
1945                 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1946                     Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), compNode->GetOpndType(),
1947                     compNode->Opnd(0), compNode->Opnd(1));
1948             }
1949         } else if (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval) {
1950             ConstvalNode *constNode = static_cast<ConstvalNode *>(l);
1951             if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1952                 const CompareNode *compNode = static_cast<CompareNode *>(r);
1953                 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1954                     Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), compNode->GetOpndType(),
1955                     compNode->Opnd(0), compNode->Opnd(1));
1956             }
1957         } else if (node->GetOpCode() == OP_ne && r->GetOpCode() == OP_constval) {
1958             // ne (u1 x, constValue 0)  <==> x
1959             ConstvalNode *constNode = static_cast<ConstvalNode *>(r);
1960             if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1961                 BaseNode *opnd = l;
1962                 do {
1963                     if (opnd->GetPrimType() == PTY_u1 || (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1964                         result = opnd;
1965                         break;
1966                     } else if (opnd->GetOpCode() == OP_cvt) {
1967                         TypeCvtNode *cvtNode = static_cast<TypeCvtNode *>(opnd);
1968                         opnd = cvtNode->Opnd(0);
1969                     } else {
1970                         opnd = nullptr;
1971                     }
1972                 } while (opnd != nullptr);
1973             }
1974         } else if (node->GetOpCode() == OP_eq && r->GetOpCode() == OP_constval) {
1975             ConstvalNode *constNode = static_cast<ConstvalNode *>(r);
1976             if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero() &&
1977                 (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1978                 result = l;
1979                 result->SetOpCode(l->GetOpCode() == OP_ne ? OP_eq : OP_ne);
1980             }
1981         }
1982     } else if (node->GetOpCode() == OP_gt || node->GetOpCode() == OP_lt) {
1983         if (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) {
1984             ConstvalNode *constNode = static_cast<ConstvalNode *>(r);
1985             if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1986                 const CompareNode *compNode = static_cast<CompareNode *>(l);
1987                 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1988                     Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), compNode->GetOpndType(),
1989                     compNode->Opnd(0), compNode->Opnd(1));
1990             }
1991         } else if (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval) {
1992             ConstvalNode *constNode = static_cast<ConstvalNode *>(l);
1993             if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1994                 const CompareNode *compNode = static_cast<CompareNode *>(r);
1995                 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1996                     Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), compNode->GetOpndType(),
1997                     compNode->Opnd(1), compNode->Opnd(0));
1998             }
1999         }
2000     }
2001     return result;
2002 }
2003 
FoldCompare(CompareNode * node)2004 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldCompare(CompareNode *node)
2005 {
2006     CHECK_NULL_FATAL(node);
2007     BaseNode *result = nullptr;
2008     std::pair<BaseNode *, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
2009     std::pair<BaseNode *, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
2010     ConstvalNode *lConst = safe_cast<ConstvalNode>(lp.first);
2011     ConstvalNode *rConst = safe_cast<ConstvalNode>(rp.first);
2012     Opcode opcode = node->GetOpCode();
2013     if (lConst != nullptr && rConst != nullptr && !IsPrimitiveDynType(node->GetOpndType())) {
2014         result = FoldConstComparison(node->GetOpCode(), node->GetPrimType(), node->GetOpndType(), *lConst, *rConst);
2015     } else if (lConst != nullptr && rConst == nullptr && opcode != OP_cmp &&
2016                lConst->GetConstVal()->GetKind() == kConstInt) {
2017         BaseNode *l = lp.first;
2018         BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
2019         result = FoldConstComparisonReverse(opcode, node->GetPrimType(), node->GetOpndType(), *l, *r);
2020     } else {
2021         BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), lp);
2022         BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
2023         if (l != node->Opnd(0) || r != node->Opnd(1)) {
2024             result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
2025                 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->GetOpndType()), l, r);
2026         } else {
2027             result = node;
2028         }
2029         auto *compareNode = static_cast<CompareNode *>(result);
2030         CHECK_NULL_FATAL(compareNode);
2031         result = SimplifyDoubleCompare(*compareNode);
2032     }
2033     return std::make_pair(result, std::nullopt);
2034 }
2035 
Fold(BaseNode * node)2036 BaseNode *ConstantFold::Fold(BaseNode *node)
2037 {
2038     if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) {
2039         return nullptr;
2040     }
2041     std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node);
2042     BaseNode *result = PairToExpr(node->GetPrimType(), p);
2043     if (result == node) {
2044         result = nullptr;
2045     }
2046     return result;
2047 }
2048 
FoldDepositbits(DepositbitsNode * node)2049 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldDepositbits(DepositbitsNode *node)
2050 {
2051     CHECK_NULL_FATAL(node);
2052     BaseNode *result = nullptr;
2053     uint8 bitsOffset = node->GetBitsOffset();
2054     uint8 bitsSize = node->GetBitsSize();
2055     std::pair<BaseNode *, std::optional<IntVal>> leftPair = DispatchFold(node->Opnd(0));
2056     std::pair<BaseNode *, std::optional<IntVal>> rightPair = DispatchFold(node->Opnd(1));
2057     ConstvalNode *leftConst = safe_cast<ConstvalNode>(leftPair.first);
2058     ConstvalNode *rightConst = safe_cast<ConstvalNode>(rightPair.first);
2059     if (leftConst != nullptr && rightConst != nullptr) {
2060         MIRIntConst *intConst0 = safe_cast<MIRIntConst>(leftConst->GetConstVal());
2061         MIRIntConst *intConst1 = safe_cast<MIRIntConst>(rightConst->GetConstVal());
2062         ASSERT_NOT_NULL(intConst0);
2063         ASSERT_NOT_NULL(intConst1);
2064         ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
2065         resultConst->SetPrimType(node->GetPrimType());
2066         MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(node->GetPrimType());
2067         MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
2068         uint64 op0ExtractVal = 0;
2069         uint64 op1ExtractVal = 0;
2070         uint64 mask0 = (1LLU << (bitsSize + bitsOffset)) - 1;
2071         uint64 mask1 = (1LLU << bitsOffset) - 1;
2072         uint64 op0Mask = ~(mask0 ^ mask1);
2073         op0ExtractVal = (intConst0->GetExtValue() & op0Mask);
2074         op1ExtractVal = (intConst1->GetExtValue() << bitsOffset) & ((1ULL << (bitsSize + bitsOffset)) - 1);
2075         constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(
2076             static_cast<int64>(op0ExtractVal | op1ExtractVal), constValue->GetType());
2077         resultConst->SetConstVal(constValue);
2078         result = resultConst;
2079     } else {
2080         BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), leftPair);
2081         BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rightPair);
2082         if (l != node->Opnd(0) || r != node->Opnd(1)) {
2083             result = mirModule->CurFuncCodeMemPool()->New<DepositbitsNode>(
2084                 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), bitsOffset, bitsSize, l, r);
2085         } else {
2086             result = node;
2087         }
2088     }
2089     return std::make_pair(result, std::nullopt);
2090 }
2091 
FoldArray(ArrayNode * node)2092 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldArray(ArrayNode *node)
2093 {
2094     CHECK_NULL_FATAL(node);
2095     BaseNode *result = nullptr;
2096     size_t i = 0;
2097     bool isFolded = false;
2098     ArrayNode *arrNode = mirModule->CurFuncCodeMemPool()->New<ArrayNode>(*mirModule, PrimType(node->GetPrimType()),
2099                                                                          node->GetTyIdx(), node->GetBoundsCheck());
2100     for (i = 0; i < node->GetNopndSize(); i++) {
2101         std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->GetNopndAt(i));
2102         BaseNode *tmpNode = PairToExpr(node->GetNopndAt(i)->GetPrimType(), p);
2103         if (tmpNode != node->GetNopndAt(i)) {
2104             isFolded = true;
2105         }
2106         arrNode->GetNopnd().push_back(tmpNode);
2107         arrNode->SetNumOpnds(arrNode->GetNumOpnds() + 1);
2108     }
2109     if (isFolded) {
2110         result = arrNode;
2111     } else {
2112         result = node;
2113     }
2114     return std::make_pair(result, std::nullopt);
2115 }
2116 
FoldTernary(TernaryNode * node)2117 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldTernary(TernaryNode *node)
2118 {
2119     CHECK_NULL_FATAL(node);
2120     constexpr size_t kFirst = 0;
2121     constexpr size_t kSecond = 1;
2122     constexpr size_t kThird = 2;
2123     BaseNode *result = node;
2124     std::vector<PrimType> primTypes;
2125     std::vector<std::pair<BaseNode *, std::optional<IntVal>>> p;
2126     for (size_t i = 0; i < node->NumOpnds(); i++) {
2127         BaseNode *tempNopnd = node->Opnd(i);
2128         CHECK_NULL_FATAL(tempNopnd);
2129         primTypes.push_back(tempNopnd->GetPrimType());
2130         p.push_back(DispatchFold(tempNopnd));
2131     }
2132     if (node->GetOpCode() == OP_select) {
2133         ConstvalNode *const0 = safe_cast<ConstvalNode>(p[kFirst].first);
2134         if (const0 != nullptr) {
2135             MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0->GetConstVal());
2136             ASSERT_NOT_NULL(intConst0);
2137             // Selecting the first value if not 0, selecting the second value otherwise.
2138             if (!intConst0->IsZero()) {
2139                 result = PairToExpr(primTypes[kSecond], p[kSecond]);
2140             } else {
2141                 result = PairToExpr(primTypes[kThird], p[kThird]);
2142             }
2143         } else {
2144             ConstvalNode *const1 = safe_cast<ConstvalNode>(p[kSecond].first);
2145             ConstvalNode *const2 = safe_cast<ConstvalNode>(p[kThird].first);
2146             if (const1 != nullptr && const2 != nullptr) {
2147                 MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1->GetConstVal());
2148                 MIRIntConst *intConst2 = safe_cast<MIRIntConst>(const2->GetConstVal());
2149                 double dconst1 = 0.0;
2150                 double dconst2 = 0.0;
2151                 // for fpconst
2152                 if (intConst1 == nullptr || intConst2 == nullptr) {
2153                     PrimType ptyp = const1->GetPrimType();
2154                     if (ptyp == PTY_f64) {
2155                         MIRDoubleConst *dConst1 = safe_cast<MIRDoubleConst>(const1->GetConstVal());
2156                         dconst1 = dConst1->GetValue();
2157                         MIRDoubleConst *dConst2 = safe_cast<MIRDoubleConst>(const2->GetConstVal());
2158                         dconst2 = dConst2->GetValue();
2159                     } else if (ptyp == PTY_f32) {
2160                         MIRFloatConst *fConst1 = safe_cast<MIRFloatConst>(const1->GetConstVal());
2161                         dconst1 = static_cast<double>(fConst1->GetFloatValue());
2162                         MIRFloatConst *fConst2 = safe_cast<MIRFloatConst>(const2->GetConstVal());
2163                         dconst2 = static_cast<double>(fConst2->GetFloatValue());
2164                     }
2165                 } else {
2166                     dconst1 = static_cast<double>(intConst1->GetExtValue());
2167                     dconst2 = static_cast<double>(intConst2->GetExtValue());
2168                 }
2169                 PrimType foldedPrimType = primTypes[kSecond];
2170                 if (!IsPrimitiveInteger(foldedPrimType)) {
2171                     foldedPrimType = primTypes[kThird];
2172                 }
2173                 if (dconst1 == 1.0 && dconst2 == 0.0) {
2174                     if (IsPrimitiveInteger(foldedPrimType)) {
2175                         result = PairToExpr(foldedPrimType, p[0]);
2176                     } else {
2177                         result = PairToExpr(primTypes[0], p[0]);
2178                         result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, foldedPrimType, primTypes[0],
2179                                                                                    result);
2180                     }
2181                     return std::make_pair(result, std::nullopt);
2182                 }
2183                 if (dconst1 == 0.0 && dconst2 == 1.0) {
2184                     BaseNode *lnot = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
2185                         OP_eq, primTypes[0], primTypes[0], PairToExpr(primTypes[0], p[0]),
2186                         mirModule->GetMIRBuilder()->CreateIntConst(0, primTypes[0]));
2187                     std::pair<BaseNode *, std::optional<IntVal>> pairTemp = DispatchFold(lnot);
2188                     if (IsPrimitiveInteger(foldedPrimType)) {
2189                         result = PairToExpr(foldedPrimType, pairTemp);
2190                     } else {
2191                         result = PairToExpr(primTypes[0], pairTemp);
2192                         result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, foldedPrimType, primTypes[0],
2193                                                                                    result);
2194                     }
2195                     return std::make_pair(result, std::nullopt);
2196                 }
2197             }
2198         }
2199     }
2200     BaseNode *e0 = PairToExpr(primTypes[kFirst], p[kFirst]);
2201     BaseNode *e1 = PairToExpr(primTypes[kSecond], p[kSecond]);
2202     BaseNode *e2 = PairToExpr(primTypes[kThird], p[kThird]);  // count up to 3 for ternary node
2203     if (e0 != node->Opnd(kFirst) || e1 != node->Opnd(kSecond) || e2 != node->Opnd(kThird)) {
2204         result = mirModule->CurFuncCodeMemPool()->New<TernaryNode>(Opcode(node->GetOpCode()),
2205                                                                    PrimType(node->GetPrimType()), e0, e1, e2);
2206     }
2207     return std::make_pair(result, std::nullopt);
2208 }
2209 
SimplifyDassign(DassignNode * node)2210 StmtNode *ConstantFold::SimplifyDassign(DassignNode *node)
2211 {
2212     CHECK_NULL_FATAL(node);
2213     BaseNode *returnValue = nullptr;
2214     returnValue = Fold(node->GetRHS());
2215     if (returnValue != nullptr) {
2216         node->SetRHS(returnValue);
2217     }
2218     return node;
2219 }
2220 
SimplifyIassignWithAddrofBaseNode(IassignNode & node,const AddrofNode & base)2221 StmtNode *ConstantFold::SimplifyIassignWithAddrofBaseNode(IassignNode &node, const AddrofNode &base)
2222 {
2223     auto *mirTypeOfIass = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node.GetTyIdx());
2224     if (!mirTypeOfIass->IsMIRPtrType()) {
2225         return &node;
2226     }
2227     auto *iassPtType = static_cast<MIRPtrType *>(mirTypeOfIass);
2228 
2229     MIRSymbol *lhsSym = mirModule->CurFunction()->GetLocalOrGlobalSymbol(base.GetStIdx());
2230     TyIdx lhsTyIdx = lhsSym->GetTyIdx();
2231     if (base.GetFieldID() != 0) {
2232         auto *mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(lhsTyIdx);
2233         if (!mirType->IsStructType()) {
2234             return &node;
2235         }
2236         lhsTyIdx = static_cast<MIRStructType *>(mirType)->GetFieldType(base.GetFieldID())->GetTypeIndex();
2237     }
2238     if (iassPtType->GetPointedTyIdx() == lhsTyIdx) {
2239         DassignNode *dassignNode = mirModule->CurFuncCodeMemPool()->New<DassignNode>();
2240         dassignNode->SetStIdx(base.GetStIdx());
2241         dassignNode->SetRHS(node.GetRHS());
2242         dassignNode->SetFieldID(base.GetFieldID() + node.GetFieldID());
2243         // reuse stmtid to maintain stmtFreqs if profileUse is on
2244         dassignNode->SetStmtID(node.GetStmtID());
2245         return dassignNode;
2246     }
2247     return &node;
2248 }
2249 
SimplifyIassignWithIaddrofBaseNode(IassignNode & node,const IaddrofNode & base)2250 StmtNode *ConstantFold::SimplifyIassignWithIaddrofBaseNode(IassignNode &node, const IaddrofNode &base)
2251 {
2252     auto *mirTypeOfIass = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node.GetTyIdx());
2253     if (!mirTypeOfIass->IsMIRPtrType()) {
2254         return &node;
2255     }
2256     auto *iassPtType = static_cast<MIRPtrType *>(mirTypeOfIass);
2257 
2258     if (base.GetFieldID() == 0) {
2259         // this iaddrof is redundant
2260         node.SetAddrExpr(base.Opnd(0));
2261         return &node;
2262     }
2263 
2264     auto *mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(base.GetTyIdx());
2265     if (!mirType->IsMIRPtrType()) {
2266         return &node;
2267     }
2268     auto *iaddrofPtType = static_cast<MIRPtrType *>(mirType);
2269 
2270     MIRStructType *lhsStructTy =
2271         static_cast<MIRStructType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iaddrofPtType->GetPointedTyIdx()));
2272     TyIdx lhsTyIdx = lhsStructTy->GetFieldType(base.GetFieldID())->GetTypeIndex();
2273     if (iassPtType->GetPointedTyIdx() == lhsTyIdx) {
2274         // eliminate the iaddrof by updating the iassign's fieldID and tyIdx
2275         node.SetFieldID(node.GetFieldID() + base.GetFieldID());
2276         node.SetTyIdx(base.GetTyIdx());
2277         node.SetOpnd(base.Opnd(0), 0);
2278         // recursive call for the new iassign
2279         return SimplifyIassign(&node);
2280     }
2281     return &node;
2282 }
2283 
SimplifyIassign(IassignNode * node)2284 StmtNode *ConstantFold::SimplifyIassign(IassignNode *node)
2285 {
2286     CHECK_NULL_FATAL(node);
2287     BaseNode *returnValue = nullptr;
2288     returnValue = Fold(node->Opnd(0));
2289     if (returnValue != nullptr) {
2290         node->SetOpnd(returnValue, 0);
2291     }
2292     returnValue = Fold(node->GetRHS());
2293     if (returnValue != nullptr) {
2294         node->SetRHS(returnValue);
2295     }
2296     auto *mirTypeOfIass = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx());
2297     if (!mirTypeOfIass->IsMIRPtrType()) {
2298         return node;
2299     }
2300 
2301     auto *opnd = node->Opnd(0);
2302     ASSERT_NOT_NULL(opnd);
2303     switch (opnd->GetOpCode()) {
2304         case OP_addrof: {
2305             return SimplifyIassignWithAddrofBaseNode(*node, static_cast<AddrofNode &>(*opnd));
2306         }
2307         case OP_iaddrof: {
2308             return SimplifyIassignWithIaddrofBaseNode(*node, static_cast<IreadNode &>(*opnd));
2309         }
2310         default:
2311             break;
2312     }
2313     return node;
2314 }
2315 
SimplifyCondGoto(CondGotoNode * node)2316 StmtNode *ConstantFold::SimplifyCondGoto(CondGotoNode *node)
2317 {
2318     CHECK_NULL_FATAL(node);
2319     // optimize condgoto need to update frequency, skip here
2320     if (Options::profileUse && mirModule->CurFunction()->GetFuncProfData()) {
2321         return node;
2322     }
2323     BaseNode *returnValue = nullptr;
2324     returnValue = Fold(node->Opnd(0));
2325     returnValue = (returnValue == nullptr) ? node : returnValue;
2326     if (returnValue == node && node->Opnd(0)->GetOpCode() == OP_select) {
2327         return SimplifyCondGotoSelect(node);
2328     } else {
2329         if (returnValue != node) {
2330             node->SetOpnd(returnValue, 0);
2331         }
2332         ConstvalNode *cst = safe_cast<ConstvalNode>(node->Opnd(0));
2333         if (cst == nullptr) {
2334             return node;
2335         }
2336         MIRIntConst *intConst = safe_cast<MIRIntConst>(cst->GetConstVal());
2337         ASSERT_NOT_NULL(intConst);
2338         if ((node->GetOpCode() == OP_brtrue && !intConst->IsZero()) ||
2339             (node->GetOpCode() == OP_brfalse && intConst->IsZero())) {
2340             uint32 freq = mirModule->CurFunction()->GetFreqFromLastStmt(node->GetStmtID());
2341             GotoNode *gotoNode = mirModule->CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
2342             gotoNode->SetOffset(node->GetOffset());
2343             if (Options::profileUse && mirModule->CurFunction()->GetFuncProfData()) {
2344                 gotoNode->SetStmtID(node->GetStmtID());  // reuse condnode stmtid
2345             }
2346             mirModule->CurFunction()->SetLastFreqMap(gotoNode->GetStmtID(), freq);
2347             return gotoNode;
2348         } else {
2349             return nullptr;
2350         }
2351     }
2352     return node;
2353 }
2354 
SimplifyCondGotoSelect(CondGotoNode * node) const2355 StmtNode *ConstantFold::SimplifyCondGotoSelect(CondGotoNode *node) const
2356 {
2357     CHECK_NULL_FATAL(node);
2358     TernaryNode *sel = static_cast<TernaryNode *>(node->Opnd(0));
2359     if (sel == nullptr || sel->GetOpCode() != OP_select) {
2360         return node;
2361     }
2362     ConstvalNode *const1 = safe_cast<ConstvalNode>(sel->Opnd(1));
2363     ConstvalNode *const2 = safe_cast<ConstvalNode>(sel->Opnd(2));
2364     if (const1 != nullptr && const2 != nullptr) {
2365         MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1->GetConstVal());
2366         MIRIntConst *intConst2 = safe_cast<MIRIntConst>(const2->GetConstVal());
2367         ASSERT_NOT_NULL(intConst1);
2368         ASSERT_NOT_NULL(intConst2);
2369         if (intConst1->GetValue() == 1 && intConst2->GetValue() == 0) {
2370             node->SetOpnd(sel->Opnd(0), 0);
2371         } else if (intConst1->GetValue() == 0 && intConst2->GetValue() == 1) {
2372             node->SetOpCode((node->GetOpCode() == OP_brfalse) ? OP_brtrue : OP_brfalse);
2373             node->SetOpnd(sel->Opnd(0), 0);
2374         }
2375     }
2376     return node;
2377 }
2378 
SimplifySwitch(SwitchNode * node)2379 StmtNode *ConstantFold::SimplifySwitch(SwitchNode *node)
2380 {
2381     CHECK_NULL_FATAL(node);
2382     BaseNode *returnValue = nullptr;
2383     returnValue = Fold(node->GetSwitchOpnd());
2384     if (returnValue != nullptr) {
2385         node->SetSwitchOpnd(returnValue);
2386         ConstvalNode *cst = safe_cast<ConstvalNode>(node->GetSwitchOpnd());
2387         if (cst == nullptr) {
2388             return node;
2389         }
2390         MIRIntConst *intConst = safe_cast<MIRIntConst>(cst->GetConstVal());
2391         ASSERT_NOT_NULL(intConst);
2392         GotoNode *gotoNode = mirModule->CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
2393         bool isdefault = true;
2394         for (unsigned i = 0; i < node->GetSwitchTable().size(); i++) {
2395             if (node->GetCasePair(i).first == intConst->GetValue()) {
2396                 isdefault = false;
2397                 gotoNode->SetOffset((LabelIdx)node->GetCasePair(i).second);
2398                 break;
2399             }
2400         }
2401         if (isdefault) {
2402             gotoNode->SetOffset(node->GetDefaultLabel());
2403         }
2404         return gotoNode;
2405     }
2406     return node;
2407 }
2408 
SimplifyUnary(UnaryStmtNode * node)2409 StmtNode *ConstantFold::SimplifyUnary(UnaryStmtNode *node)
2410 {
2411     CHECK_NULL_FATAL(node);
2412     BaseNode *returnValue = nullptr;
2413     if (node->Opnd(0) == nullptr) {
2414         return node;
2415     }
2416     returnValue = Fold(node->Opnd(0));
2417     if (returnValue != nullptr) {
2418         node->SetOpnd(returnValue, 0);
2419     }
2420     return node;
2421 }
2422 
SimplifyBinary(BinaryStmtNode * node)2423 StmtNode *ConstantFold::SimplifyBinary(BinaryStmtNode *node)
2424 {
2425     CHECK_NULL_FATAL(node);
2426     BaseNode *returnValue = nullptr;
2427     returnValue = Fold(node->GetBOpnd(0));
2428     if (returnValue != nullptr) {
2429         node->SetBOpnd(returnValue, 0);
2430     }
2431     returnValue = Fold(node->GetBOpnd(1));
2432     if (returnValue != nullptr) {
2433         node->SetBOpnd(returnValue, 1);
2434     }
2435     return node;
2436 }
2437 
SimplifyBlock(BlockNode * node)2438 StmtNode *ConstantFold::SimplifyBlock(BlockNode *node)
2439 {
2440     CHECK_NULL_FATAL(node);
2441     if (node->GetFirst() == nullptr) {
2442         return node;
2443     }
2444     StmtNode *s = node->GetFirst();
2445     StmtNode *prevStmt = nullptr;
2446     do {
2447         StmtNode *returnValue = Simplify(s);
2448         if (returnValue != nullptr) {
2449             if (returnValue->GetOpCode() == OP_block) {
2450                 BlockNode *blk = static_cast<BlockNode *>(returnValue);
2451                 node->ReplaceStmtWithBlock(*s, *blk);
2452             } else {
2453                 node->ReplaceStmt1WithStmt2(s, returnValue);
2454             }
2455             prevStmt = s;
2456             s = s->GetNext();
2457         } else {
2458             // delete s from block
2459             StmtNode *nextStmt = s->GetNext();
2460             if (s == node->GetFirst()) {
2461                 node->SetFirst(nextStmt);
2462                 if (nextStmt != nullptr) {
2463                     nextStmt->SetPrev(nullptr);
2464                 }
2465             } else {
2466                 CHECK_NULL_FATAL(prevStmt);
2467                 prevStmt->SetNext(nextStmt);
2468                 if (nextStmt != nullptr) {
2469                     nextStmt->SetPrev(prevStmt);
2470                 }
2471             }
2472             if (s == node->GetLast()) {
2473                 node->SetLast(prevStmt);
2474             }
2475             s = nextStmt;
2476         }
2477     } while (s != nullptr);
2478     return node;
2479 }
2480 
SimplifyAsm(AsmNode * node)2481 StmtNode *ConstantFold::SimplifyAsm(AsmNode *node)
2482 {
2483     CHECK_NULL_FATAL(node);
2484     /* fold constval in input */
2485     for (size_t i = 0; i < node->NumOpnds(); i++) {
2486         const std::string &str = GlobalTables::GetUStrTable().GetStringFromStrIdx(node->inputConstraints[i]);
2487         if (str == "i") {
2488             std::pair<BaseNode *, std::optional<IntVal>> p = DispatchFold(node->Opnd(i));
2489             node->SetOpnd(p.first, i);
2490             continue;
2491         }
2492     }
2493     return node;
2494 }
2495 
SimplifyIf(IfStmtNode * node)2496 StmtNode *ConstantFold::SimplifyIf(IfStmtNode *node)
2497 {
2498     CHECK_NULL_FATAL(node);
2499     BaseNode *returnValue = nullptr;
2500     (void)Simplify(node->GetThenPart());
2501     if (node->GetElsePart()) {
2502         (void)Simplify(node->GetElsePart());
2503     }
2504     returnValue = Fold(node->Opnd());
2505     if (returnValue != nullptr) {
2506         node->SetOpnd(returnValue, 0);
2507         ConstvalNode *cst = safe_cast<ConstvalNode>(node->Opnd());
2508         // do not delete c/c++ dead if-body here
2509         if (cst == nullptr || (!mirModule->IsJavaModule())) {
2510             return node;
2511         }
2512         MIRIntConst *intConst = safe_cast<MIRIntConst>(cst->GetConstVal());
2513         ASSERT_NOT_NULL(intConst);
2514         if (intConst->GetValue() == 0) {
2515             return node->GetElsePart();
2516         } else {
2517             return node->GetThenPart();
2518         }
2519     }
2520     return node;
2521 }
2522 
SimplifyWhile(WhileStmtNode * node)2523 StmtNode *ConstantFold::SimplifyWhile(WhileStmtNode *node)
2524 {
2525     CHECK_NULL_FATAL(node);
2526     BaseNode *returnValue = nullptr;
2527     if (node->Opnd(0) == nullptr) {
2528         return node;
2529     }
2530     if (node->GetBody()) {
2531         (void)Simplify(node->GetBody());
2532     }
2533     returnValue = Fold(node->Opnd(0));
2534     if (returnValue != nullptr) {
2535         node->SetOpnd(returnValue, 0);
2536         ConstvalNode *cst = safe_cast<ConstvalNode>(node->Opnd(0));
2537         // do not delete c/c++ dead while-body here
2538         if (cst == nullptr || (!mirModule->IsJavaModule())) {
2539             return node;
2540         }
2541         if (cst->GetConstVal()->IsZero()) {
2542             if (OP_while == node->GetOpCode()) {
2543                 return nullptr;
2544             } else {
2545                 return node->GetBody();
2546             }
2547         }
2548     }
2549     return node;
2550 }
2551 
SimplifyNary(NaryStmtNode * node)2552 StmtNode *ConstantFold::SimplifyNary(NaryStmtNode *node)
2553 {
2554     CHECK_NULL_FATAL(node);
2555     BaseNode *returnValue = nullptr;
2556     for (size_t i = 0; i < node->NumOpnds(); i++) {
2557         returnValue = Fold(node->GetNopndAt(i));
2558         if (returnValue != nullptr) {
2559             node->SetNOpndAt(i, returnValue);
2560         }
2561     }
2562     return node;
2563 }
2564 
SimplifyIcall(IcallNode * node)2565 StmtNode *ConstantFold::SimplifyIcall(IcallNode *node)
2566 {
2567     CHECK_NULL_FATAL(node);
2568     BaseNode *returnValue = nullptr;
2569     for (size_t i = 0; i < node->NumOpnds(); i++) {
2570         returnValue = Fold(node->GetNopndAt(i));
2571         if (returnValue != nullptr) {
2572             node->SetNOpndAt(i, returnValue);
2573         }
2574     }
2575     // icall node transform to call node
2576     CHECK_FATAL(node->GetNopnd().empty() == false, "container check");
2577     switch (node->GetNopndAt(0)->GetOpCode()) {
2578         case OP_addroffunc: {
2579             AddroffuncNode *addrofNode = static_cast<AddroffuncNode *>(node->GetNopndAt(0));
2580             CallNode *callNode = mirModule->CurFuncCodeMemPool()->New<CallNode>(
2581                 *mirModule,
2582                 (node->GetOpCode() == OP_icall || node->GetOpCode() == OP_icallproto) ? OP_call : OP_callassigned);
2583             if (node->GetOpCode() == OP_icallassigned || node->GetOpCode() == OP_icallprotoassigned) {
2584                 callNode->SetReturnVec(node->GetReturnVec());
2585             }
2586             callNode->SetPUIdx(addrofNode->GetPUIdx());
2587             for (size_t i = 1; i < node->GetNopndSize(); i++) {
2588                 callNode->GetNopnd().push_back(node->GetNopndAt(i));
2589             }
2590             callNode->SetNumOpnds(callNode->GetNopndSize());
2591             // reuse stmtID to skip update stmtFreqs when profileUse is on
2592             callNode->SetStmtID(node->GetStmtID());
2593             return callNode;
2594         }
2595         default:
2596             break;
2597     }
2598     return node;
2599 }
2600 
ProcessFunc(MIRFunction * func)2601 void ConstantFold::ProcessFunc(MIRFunction *func)
2602 {
2603     if (func->IsEmpty()) {
2604         return;
2605     }
2606     mirModule->SetCurFunction(func);
2607     (void)Simplify(func->GetBody());
2608 }
2609 
2610 }  // namespace maple
2611