• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "constantfold.h"
17 #include <cfloat>
18 
19 namespace maple {
20 
21 namespace {
22 constexpr uint32 kByteSizeOfBit64 = 8;                            // byte number for 64 bit
23 constexpr uint32 kBitSizePerByte = 8;
24 constexpr maple::int32 kMaxOffset = INT_MAX - 8;
25 
26 enum CompareRes : int64 { kLess = -1, kEqual = 0, kGreater = 1 };
27 
operator *(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)28 std::optional<IntVal> operator*(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
29 {
30     if (!v1 && !v2) {
31         return std::nullopt;
32     }
33 
34     // Perform all calculations in terms of the maximum available signed type.
35     // The value will be truncated for an appropriate type when constant is created in PairToExpr function
36     return v1 && v2 ? v1->Mul(*v2, PTY_i64) : IntVal(static_cast<uint64>(0), PTY_i64);
37 }
38 
39 // Perform all calculations in terms of the maximum available signed type.
40 // The value will be truncated for an appropriate type when constant is created in PairToExpr function
AddSub(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2,bool isAdd)41 std::optional<IntVal> AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd)
42 {
43     if (!v1 && !v2) {
44         return std::nullopt;
45     }
46 
47     if (v1 && v2) {
48         return isAdd ? v1->Add(*v2, PTY_i64) : v1->Sub(*v2, PTY_i64);
49     }
50 
51     if (v1) {
52         return v1->TruncOrExtend(PTY_i64);
53     }
54 
55     // !v1 && v2
56     return isAdd ? v2->TruncOrExtend(PTY_i64) : -(v2->TruncOrExtend(PTY_i64));
57 }
58 
operator +(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)59 std::optional<IntVal> operator+(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
60 {
61     return AddSub(v1, v2, true);
62 }
63 
operator -(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)64 std::optional<IntVal> operator-(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
65 {
66     return AddSub(v1, v2, false);
67 }
68 
69 }  // anonymous namespace
70 
// This phase optimizes code by simplifying constant expressions: each
// constant expression is evaluated at compile time and replaced by its
// value, so the work does not have to be done at runtime.
//
// The main procedure is as follows:
// A. Analyze the expression type
// B. Analyze the operator type
// C. Replace the expression with the result of the operation
80 
81 // true if the constant's bits are made of only one group of contiguous 1's
82 // starting at bit 0
ContiguousBitsOf1(uint64 x)83 static bool ContiguousBitsOf1(uint64 x)
84 {
85     if (x == 0) {
86         return false;
87     }
88     return (~x & (x + 1)) == (x + 1);
89 }
90 
IsPowerOf2(uint64 num)91 inline bool IsPowerOf2(uint64 num)
92 {
93     if (num == 0) {
94         return false;
95     }
96     return (~(num - 1) & num) == num;
97 }
98 
NewBinaryNode(BinaryNode * old,Opcode op,PrimType primType,BaseNode * lhs,BaseNode * rhs) const99 BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs,
100                                         BaseNode *rhs) const
101 {
102     CHECK_NULL_FATAL(old);
103     BinaryNode *result = nullptr;
104     if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == lhs && old->Opnd(1) == rhs) {
105         result = old;
106     } else {
107         result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(op, primType, lhs, rhs);
108     }
109     return result;
110 }
111 
NewUnaryNode(UnaryNode * old,Opcode op,PrimType primType,BaseNode * expr) const112 UnaryNode *ConstantFold::NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const
113 {
114     CHECK_NULL_FATAL(old);
115     UnaryNode *result = nullptr;
116     if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == expr) {
117         result = old;
118     } else {
119         result = mirModule->CurFuncCodeMemPool()->New<UnaryNode>(op, primType, expr);
120     }
121     return result;
122 }
123 
PairToExpr(PrimType resultType,const std::pair<BaseNode *,std::optional<IntVal>> & pair) const124 BaseNode *ConstantFold::PairToExpr(PrimType resultType, const std::pair<BaseNode*, std::optional<IntVal>> &pair) const
125 {
126     CHECK_NULL_FATAL(pair.first);
127     BaseNode *result = pair.first;
128     if (!pair.second || *pair.second == 0 || GetPrimTypeSize(resultType) > k8ByteSize) {
129         return result;
130     }
131     if (pair.first->GetOpCode() == OP_neg && !pair.second->GetSignBit()) {
132         // -a, 5 -> 5 - a
133         ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
134             static_cast<uint64>(pair.second->GetExtValue()), resultType);
135         BaseNode *r = static_cast<UnaryNode*>(pair.first)->Opnd(0);
136         result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, val, r);
137     } else {
138         if ((!pair.second->GetSignBit() &&
139             pair.second->GetSXTValue(static_cast<uint8>(GetPrimTypeBitSize(resultType))) > 0) ||
140             pair.second->TruncOrExtend(resultType).IsMinValue() ||
141             pair.second->GetSXTValue() == INT64_MIN) {
142             // +-a, 5 -> a + 5
143             ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
144                 static_cast<uint64>(pair.second->GetExtValue()), resultType);
145             result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_add, resultType, pair.first, val);
146         } else {
147             // +-a, -5 -> a + -5
148             ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
149                 static_cast<uint64>((-pair.second.value()).GetExtValue()), resultType);
150             result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, pair.first, val);
151         }
152     }
153     return result;
154 }
155 
FoldBase(BaseNode * node) const156 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBase(BaseNode *node) const
157 {
158     return std::make_pair(node, std::nullopt);
159 }
160 
DispatchFold(BaseNode * node)161 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::DispatchFold(BaseNode *node)
162 {
163     CHECK_NULL_FATAL(node);
164     if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
165         return {node, std::nullopt};
166     }
167     switch (node->GetOpCode()) {
168         case OP_abs:
169         case OP_bnot:
170         case OP_lnot:
171         case OP_neg:
172         case OP_sqrt:
173             return FoldUnary(static_cast<UnaryNode*>(node));
174         case OP_ceil:
175         case OP_floor:
176         case OP_trunc:
177         case OP_cvt:
178             return FoldTypeCvt(static_cast<TypeCvtNode*>(node));
179         case OP_sext:
180         case OP_zext:
181         case OP_extractbits:
182             return FoldExtractbits(static_cast<ExtractbitsNode*>(node));
183         case OP_iread:
184             return FoldIread(static_cast<IreadNode*>(node));
185         case OP_add:
186         case OP_ashr:
187         case OP_band:
188         case OP_bior:
189         case OP_bxor:
190         case OP_div:
191         case OP_lshr:
192         case OP_max:
193         case OP_min:
194         case OP_mul:
195         case OP_rem:
196         case OP_shl:
197         case OP_sub:
198             return FoldBinary(static_cast<BinaryNode*>(node));
199         case OP_eq:
200         case OP_ne:
201         case OP_ge:
202         case OP_gt:
203         case OP_le:
204         case OP_lt:
205         case OP_cmp:
206             return FoldCompare(static_cast<CompareNode*>(node));
207         case OP_retype:
208             return FoldRetype(static_cast<RetypeNode*>(node));
209         default:
210             return FoldBase(static_cast<BaseNode*>(node));
211     }
212 }
213 
Negate(BaseNode * node) const214 BaseNode *ConstantFold::Negate(BaseNode *node) const
215 {
216     CHECK_NULL_FATAL(node);
217     return mirModule->CurFuncCodeMemPool()->New<UnaryNode>(OP_neg, PrimType(node->GetPrimType()), node);
218 }
219 
Negate(UnaryNode * node) const220 BaseNode *ConstantFold::Negate(UnaryNode *node) const
221 {
222     CHECK_NULL_FATAL(node);
223     BaseNode *result = nullptr;
224     if (node->GetOpCode() == OP_neg) {
225         result = static_cast<BaseNode*>(node->Opnd(0));
226     } else {
227         BaseNode *n = static_cast<BaseNode*>(node);
228         result = NewUnaryNode(node, OP_neg, node->GetPrimType(), n);
229     }
230     return result;
231 }
232 
Negate(const ConstvalNode * node) const233 BaseNode *ConstantFold::Negate(const ConstvalNode *node) const
234 {
235     CHECK_NULL_FATAL(node);
236     ConstvalNode *copy = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
237     CHECK_NULL_FATAL(copy);
238     copy->GetConstVal()->Neg();
239     return copy;
240 }
241 
NegateTree(BaseNode * node) const242 BaseNode *ConstantFold::NegateTree(BaseNode *node) const
243 {
244     CHECK_NULL_FATAL(node);
245     if (node->IsUnaryNode()) {
246         return Negate(static_cast<UnaryNode*>(node));
247     } else if (node->GetOpCode() == OP_constval) {
248         return Negate(static_cast<ConstvalNode*>(node));
249     } else {
250         return Negate(static_cast<BaseNode*>(node));
251     }
252 }
253 
FoldIntConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRIntConst & intConst0,const MIRIntConst & intConst1) const254 MIRIntConst *ConstantFold::FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
255                                                           const MIRIntConst &intConst0,
256                                                           const MIRIntConst &intConst1) const
257 {
258     uint64 result = 0;
259 
260     bool greater = intConst0.GetValue().Greater(intConst1.GetValue(), opndType);
261     bool equal = intConst0.GetValue().Equal(intConst1.GetValue(), opndType);
262     bool less = intConst0.GetValue().Less(intConst1.GetValue(), opndType);
263 
264     switch (opcode) {
265         case OP_eq: {
266             result = equal;
267             break;
268         }
269         case OP_ge: {
270             result = (greater || equal) ? 1 : 0;
271             break;
272         }
273         case OP_gt: {
274             result = greater;
275             break;
276         }
277         case OP_le: {
278             result = (less || equal) ? 1 : 0;
279             break;
280         }
281         case OP_lt: {
282             result = less;
283             break;
284         }
285         case OP_ne: {
286             result = !equal;
287             break;
288         }
289         case OP_cmp: {
290             if (greater) {
291                 result = kGreater;
292             } else if (equal) {
293                 result = kEqual;
294             } else if (less) {
295                 result = static_cast<uint64>(kLess);
296             }
297             break;
298         }
299         default:
300             DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstComparison");
301             break;
302     }
303     // determine the type
304     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
305     // form the constant
306     MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
307     return constValue;
308 }
309 
FoldIntConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const310 ConstvalNode *ConstantFold::FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
311                                                    const ConstvalNode &const0, const ConstvalNode &const1) const
312 {
313     const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
314     const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
315     CHECK_NULL_FATAL(intConst0);
316     CHECK_NULL_FATAL(intConst1);
317     MIRIntConst *constValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
318     // form the ConstvalNode
319     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
320     resultConst->SetPrimType(resultType);
321     resultConst->SetConstVal(constValue);
322     return resultConst;
323 }
324 
FoldIntConstBinaryMIRConst(Opcode opcode,PrimType resultType,const MIRIntConst & intConst0,const MIRIntConst & intConst1)325 MIRConst *ConstantFold::FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0,
326                                                    const MIRIntConst &intConst1)
327 {
328     IntVal intVal0 = intConst0.GetValue();
329     IntVal intVal1 = intConst1.GetValue();
330     IntVal result(static_cast<uint64>(0), resultType);
331 
332     switch (opcode) {
333         case OP_add: {
334             result = intVal0.Add(intVal1, resultType);
335             break;
336         }
337         case OP_sub: {
338             result = intVal0.Sub(intVal1, resultType);
339             break;
340         }
341         case OP_mul: {
342             result = intVal0.Mul(intVal1, resultType);
343             break;
344         }
345         case OP_div: {
346             result = intVal0.Div(intVal1, resultType);
347             break;
348         }
349         case OP_rem: {
350             result = intVal0.Rem(intVal1, resultType);
351             break;
352         }
353         case OP_ashr: {
354             result = intVal0.AShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
355             break;
356         }
357         case OP_lshr: {
358             result = intVal0.LShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
359             break;
360         }
361         case OP_shl: {
362             result = intVal0.Shl(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
363             break;
364         }
365         case OP_max: {
366             result = Max(intVal0, intVal1, resultType);
367             break;
368         }
369         case OP_min: {
370             result = Min(intVal0, intVal1, resultType);
371             break;
372         }
373         case OP_band: {
374             result = intVal0.And(intVal1, resultType);
375             break;
376         }
377         case OP_bior: {
378             result = intVal0.Or(intVal1, resultType);
379             break;
380         }
381         case OP_bxor: {
382             result = intVal0.Xor(intVal1, resultType);
383             break;
384         }
385         default:
386             DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstBinary");
387             break;
388     }
389     // determine the type
390     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
391     // form the constant
392     MIRIntConst *constValue =
393         GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
394     return constValue;
395 }
396 
FoldIntConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const397 ConstvalNode *ConstantFold::FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
398                                                const ConstvalNode &const1) const
399 {
400     const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
401     const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
402     CHECK_NULL_FATAL(intConst0);
403     CHECK_NULL_FATAL(intConst1);
404     MIRConst *constValue = FoldIntConstBinaryMIRConst(opcode, resultType, *intConst0, *intConst1);
405     // form the ConstvalNode
406     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
407     resultConst->SetPrimType(resultType);
408     resultConst->SetConstVal(constValue);
409     return resultConst;
410 }
411 
FoldFPConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const412 ConstvalNode *ConstantFold::FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
413                                               const ConstvalNode &const1) const
414 {
415     DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
416     const MIRDoubleConst *doubleConst0 = nullptr;
417     const MIRDoubleConst *doubleConst1 = nullptr;
418     const MIRFloatConst *floatConst0 = nullptr;
419     const MIRFloatConst *floatConst1 = nullptr;
420     bool useDouble = (const0.GetPrimType() == PTY_f64);
421     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
422     resultConst->SetPrimType(resultType);
423     if (useDouble) {
424         doubleConst0 = safe_cast<MIRDoubleConst>(const0.GetConstVal());
425         doubleConst1 = safe_cast<MIRDoubleConst>(const1.GetConstVal());
426         CHECK_NULL_FATAL(doubleConst0);
427         CHECK_NULL_FATAL(doubleConst1);
428     } else {
429         floatConst0 = safe_cast<MIRFloatConst>(const0.GetConstVal());
430         floatConst1 = safe_cast<MIRFloatConst>(const1.GetConstVal());
431         CHECK_NULL_FATAL(floatConst0);
432         CHECK_NULL_FATAL(floatConst1);
433     }
434     float constValueFloat = 0.0;
435     double constValueDouble = 0.0;
436     switch (opcode) {
437         case OP_add: {
438             if (useDouble) {
439                 constValueDouble = doubleConst0->GetValue() + doubleConst1->GetValue();
440             } else {
441                 constValueFloat = floatConst0->GetValue() + floatConst1->GetValue();
442             }
443             break;
444         }
445         case OP_sub: {
446             if (useDouble) {
447                 constValueDouble = doubleConst0->GetValue() - doubleConst1->GetValue();
448             } else {
449                 constValueFloat = floatConst0->GetValue() - floatConst1->GetValue();
450             }
451             break;
452         }
453         case OP_mul: {
454             if (useDouble) {
455                 constValueDouble = doubleConst0->GetValue() * doubleConst1->GetValue();
456             } else {
457                 constValueFloat = floatConst0->GetValue() * floatConst1->GetValue();
458             }
459             break;
460         }
461         case OP_div: {
462             // for floats div by 0 is well defined
463             if (useDouble) {
464                 constValueDouble = doubleConst0->GetValue() / doubleConst1->GetValue();
465             } else {
466                 constValueFloat = floatConst0->GetValue() / floatConst1->GetValue();
467             }
468             break;
469         }
470         case OP_max: {
471             if (useDouble) {
472                 constValueDouble = (doubleConst0->GetValue() >= doubleConst1->GetValue()) ? doubleConst0->GetValue()
473                                                                                         : doubleConst1->GetValue();
474             } else {
475                 constValueFloat = (floatConst0->GetValue() >= floatConst1->GetValue()) ? floatConst0->GetValue()
476                                                                                     : floatConst1->GetValue();
477             }
478             break;
479         }
480         case OP_min: {
481             if (useDouble) {
482                 constValueDouble = (doubleConst0->GetValue() <= doubleConst1->GetValue()) ? doubleConst0->GetValue()
483                                                                                         : doubleConst1->GetValue();
484             } else {
485                 constValueFloat = (floatConst0->GetValue() <= floatConst1->GetValue()) ? floatConst0->GetValue()
486                                                                                     : floatConst1->GetValue();
487             }
488             break;
489         }
490         case OP_rem:
491         case OP_ashr:
492         case OP_lshr:
493         case OP_shl:
494         case OP_band:
495         case OP_bior:
496         case OP_bxor: {
497             DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstBinary");
498             break;
499         }
500         default:
501             DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstBinary");
502             break;
503     }
504     if (resultType == PTY_f64) {
505         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValueDouble));
506     } else {
507         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValueFloat));
508     }
509     return resultConst;
510 }
511 
ConstValueEqual(int64 leftValue,int64 rightValue) const512 bool ConstantFold::ConstValueEqual(int64 leftValue, int64 rightValue) const
513 {
514     return (leftValue == rightValue);
515 }
516 
ConstValueEqual(float leftValue,float rightValue) const517 bool ConstantFold::ConstValueEqual(float leftValue, float rightValue) const
518 {
519     auto result = fabs(leftValue - rightValue);
520     return leftValue <= FLT_MIN && rightValue <= FLT_MIN ? result < FLT_MIN : result <= FLT_MIN;
521 }
522 
ConstValueEqual(double leftValue,double rightValue) const523 bool ConstantFold::ConstValueEqual(double leftValue, double rightValue) const
524 {
525     auto result = fabs(leftValue - rightValue);
526     return leftValue <= DBL_MIN && rightValue <= DBL_MIN ? result < DBL_MIN : result <= DBL_MIN;
527 }
528 
529 template<typename T>
FullyEqual(T leftValue,T rightValue) const530 bool ConstantFold::FullyEqual(T leftValue, T rightValue) const
531 {
532     if (std::isinf(leftValue) && std::isinf(rightValue)) {
533         // (inf == inf), add the judgement here in case of the subtraction between float type inf
534         return true;
535     } else {
536         return ConstValueEqual(leftValue, rightValue);
537     }
538 }
539 
540 template<typename T>
ComparisonResult(Opcode op,T * leftConst,T * rightConst) const541 int64 ConstantFold::ComparisonResult(Opcode op, T *leftConst, T *rightConst) const
542 {
543     DEBUG_ASSERT(leftConst != nullptr, "leftConst should not be nullptr");
544     typename T::value_type leftValue = leftConst->GetValue();
545     DEBUG_ASSERT(rightConst != nullptr, "rightConst should not be nullptr");
546     typename T::value_type rightValue = rightConst->GetValue();
547     int64 result = 0;
548     switch (op) {
549         case OP_eq: {
550             result = FullyEqual(leftValue, rightValue);
551             break;
552         }
553         case OP_ge: {
554             result = (leftValue > rightValue) || FullyEqual(leftValue, rightValue);
555             break;
556         }
557         case OP_gt: {
558             result = (leftValue > rightValue);
559             break;
560         }
561         case OP_le: {
562             result = (leftValue < rightValue) || FullyEqual(leftValue, rightValue);
563             break;
564         }
565         case OP_lt: {
566             result = (leftValue < rightValue);
567             break;
568         }
569         case OP_ne: {
570             result = !FullyEqual(leftValue, rightValue);
571             break;
572         }
573         [[clang::fallthrough]];
574         case OP_cmp: {
575             if (leftValue > rightValue) {
576                 result = kGreater;
577             } else if (FullyEqual(leftValue, rightValue)) {
578                 result = kEqual;
579             } else {
580                 result = kLess;
581             }
582             break;
583         }
584         default:
585             DEBUG_ASSERT(false, "Unknown opcode for Comparison");
586             break;
587     }
588     return result;
589 }
590 
FoldFPConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRConst & leftConst,const MIRConst & rightConst) const591 MIRIntConst *ConstantFold::FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
592                                                          const MIRConst &leftConst, const MIRConst &rightConst) const
593 {
594     int64 result = 0;
595     bool useDouble = (opndType == PTY_f64);
596     if (useDouble) {
597         result =
598             ComparisonResult(opcode, safe_cast<MIRDoubleConst>(&leftConst), safe_cast<MIRDoubleConst>(&rightConst));
599     } else {
600         result = ComparisonResult(opcode, safe_cast<MIRFloatConst>(&leftConst), safe_cast<MIRFloatConst>(&rightConst));
601     }
602     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
603     MIRIntConst *resultConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result), type);
604     return resultConst;
605 }
606 
FoldFPConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const607 ConstvalNode *ConstantFold::FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
608                                                   const ConstvalNode &const0, const ConstvalNode &const1) const
609 {
610     DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
611     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
612     resultConst->SetPrimType(resultType);
613     resultConst->SetConstVal(
614         FoldFPConstComparisonMIRConst(opcode, resultType, opndType, *const0.GetConstVal(), *const1.GetConstVal()));
615     return resultConst;
616 }
617 
FoldConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRConst & const0,const MIRConst & const1) const618 MIRConst *ConstantFold::FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
619                                                     const MIRConst &const0, const MIRConst &const1) const
620 {
621     MIRConst *returnValue = nullptr;
622     if (IsPrimitiveInteger(opndType)) {
623         const auto *intConst0 = safe_cast<MIRIntConst>(&const0);
624         const auto *intConst1 = safe_cast<MIRIntConst>(&const1);
625         ASSERT_NOT_NULL(intConst0);
626         ASSERT_NOT_NULL(intConst1);
627         returnValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
628     } else if (opndType == PTY_f32 || opndType == PTY_f64) {
629         returnValue = FoldFPConstComparisonMIRConst(opcode, resultType, opndType, const0, const1);
630     } else {
631         DEBUG_ASSERT(false, "Unhandled case for FoldConstComparisonMIRConst");
632     }
633     return returnValue;
634 }
635 
FoldConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const636 ConstvalNode *ConstantFold::FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
637                                                 const ConstvalNode &const0, const ConstvalNode &const1) const
638 {
639     ConstvalNode *returnValue = nullptr;
640     if (IsPrimitiveInteger(opndType)) {
641         returnValue = FoldIntConstComparison(opcode, resultType, opndType, const0, const1);
642     } else if (opndType == PTY_f32 || opndType == PTY_f64) {
643         returnValue = FoldFPConstComparison(opcode, resultType, opndType, const0, const1);
644     } else {
645         DEBUG_ASSERT(false, "Unhandled case for FoldConstComparison");
646     }
647     return returnValue;
648 }
649 
FoldConstComparisonReverse(Opcode opcode,PrimType resultType,PrimType opndType,BaseNode & l,BaseNode & r) const650 CompareNode *ConstantFold::FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType,
651                                                       BaseNode &l, BaseNode &r) const
652 {
653     CompareNode *result = nullptr;
654     Opcode op = opcode;
655     switch (opcode) {
656         case OP_gt: {
657             op = OP_lt;
658             break;
659         }
660         case OP_lt: {
661             op = OP_gt;
662             break;
663         }
664         case OP_ge: {
665             op = OP_le;
666             break;
667         }
668         case OP_le: {
669             op = OP_ge;
670             break;
671         }
672         case OP_eq: {
673             break;
674         }
675         case OP_ne: {
676             break;
677         }
678         default:
679             DEBUG_ASSERT(false, "Unknown opcode for FoldConstComparisonReverse");
680             break;
681     }
682 
683     result =
684         mirModule->CurFuncCodeMemPool()->New<CompareNode>(Opcode(op), PrimType(resultType), PrimType(opndType), &r, &l);
685     return result;
686 }
687 
FoldConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const688 ConstvalNode *ConstantFold::FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
689                                             const ConstvalNode &const1) const
690 {
691     ConstvalNode *returnValue = nullptr;
692     if (IsPrimitiveInteger(resultType)) {
693         returnValue = FoldIntConstBinary(opcode, resultType, const0, const1);
694     } else if (resultType == PTY_f32 || resultType == PTY_f64) {
695         returnValue = FoldFPConstBinary(opcode, resultType, const0, const1);
696     } else {
697         DEBUG_ASSERT(false, "Unhandled case for FoldConstBinary");
698     }
699     return returnValue;
700 }
701 
FoldIntConstUnaryMIRConst(Opcode opcode,PrimType resultType,const MIRIntConst * constNode)702 MIRIntConst *ConstantFold::FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode)
703 {
704     CHECK_NULL_FATAL(constNode);
705     IntVal result = constNode->GetValue().TruncOrExtend(resultType);
706     switch (opcode) {
707         case OP_abs: {
708             if (IsSignedInteger(constNode->GetType().GetPrimType()) && result.GetSignBit()) {
709                 result = -result;
710             }
711             break;
712         }
713         case OP_bnot: {
714             result = ~result;
715             break;
716         }
717         case OP_lnot: {
718             uint64 resultInt = result == 0 ? 1 : 0;
719             result = {resultInt, resultType};
720             break;
721         }
722         case OP_neg: {
723             result = -result;
724             break;
725         }
726         case OP_sext:         // handled in FoldExtractbits
727         case OP_zext:         // handled in FoldExtractbits
728         case OP_extractbits:  // handled in FoldExtractbits
729         case OP_sqrt: {
730             DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstUnaryMIRConst");
731             break;
732         }
733         default:
734             DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstUnaryMIRConst");
735             break;
736     }
737     // determine the type
738     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
739     // form the constant
740     MIRIntConst *constValue =
741         GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
742     return constValue;
743 }
744 
745 template <typename T>
FoldFPConstUnary(Opcode opcode,PrimType resultType,ConstvalNode * constNode) const746 ConstvalNode *ConstantFold::FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const
747 {
748     CHECK_NULL_FATAL(constNode);
749     double constValue = 0;
750     T *fpCst = static_cast<T*>(constNode->GetConstVal());
751     switch (opcode) {
752         case OP_neg: {
753             constValue = typename T::value_type(-fpCst->GetValue());
754             break;
755         }
756         case OP_abs: {
757             constValue = typename T::value_type(fabs(fpCst->GetValue()));
758             break;
759         }
760         case OP_sqrt: {
761             constValue = typename T::value_type(sqrt(fpCst->GetValue()));
762             break;
763         }
764         case OP_bnot:
765         case OP_lnot:
766         case OP_sext:
767         case OP_zext:
768         case OP_extractbits: {
769             DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstUnary");
770             break;
771         }
772         default:
773             DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstUnary");
774             break;
775     }
776     auto *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
777     resultConst->SetPrimType(resultType);
778     if (resultType == PTY_f32) {
779         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(static_cast<float>(constValue)));
780     } else if (resultType == PTY_f64) {
781         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValue));
782     } else {
783         CHECK_FATAL(false, "PrimType for MIRFloatConst / MIRDoubleConst should be PTY_f32 / PTY_f64");
784     }
785     return resultConst;
786 }
787 
FoldConstUnary(Opcode opcode,PrimType resultType,ConstvalNode & constNode) const788 ConstvalNode *ConstantFold::FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const
789 {
790     ConstvalNode *returnValue = nullptr;
791     if (IsPrimitiveInteger(resultType)) {
792         const MIRIntConst *cst = safe_cast<MIRIntConst>(constNode.GetConstVal());
793         auto constValue = FoldIntConstUnaryMIRConst(opcode, resultType, cst);
794         returnValue = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
795         returnValue->SetPrimType(resultType);
796         returnValue->SetConstVal(constValue);
797     } else if (resultType == PTY_f32) {
798         returnValue = FoldFPConstUnary<MIRFloatConst>(opcode, resultType, &constNode);
799     } else if (resultType == PTY_f64) {
800         returnValue = FoldFPConstUnary<MIRDoubleConst>(opcode, resultType, &constNode);
801     } else {
802         DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary");
803     }
804     return returnValue;
805 }
806 
FoldRetype(RetypeNode * node)807 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldRetype(RetypeNode *node)
808 {
809     CHECK_NULL_FATAL(node);
810     BaseNode *result = node;
811     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
812     if (node->Opnd(0) != p.first) {
813         RetypeNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
814         CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldRetype");
815         newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
816         result = newRetNode;
817     }
818     return std::make_pair(result, std::nullopt);
819 }
820 
FoldUnary(UnaryNode * node)821 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldUnary(UnaryNode *node)
822 {
823     CHECK_NULL_FATAL(node);
824     BaseNode *result = nullptr;
825     std::optional<IntVal> sum = std::nullopt;
826     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
827     ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
828     if (cst != nullptr) {
829         result = FoldConstUnary(node->GetOpCode(), node->GetPrimType(), *cst);
830     } else {
831         bool isInt = IsPrimitiveInteger(node->GetPrimType());
832         // The neg node will be recreated regardless of whether the folding is successful or not. And the neg node's
833         // primType will be set to opnd type. There will be problems in some cases. For example:
834         // before cf:
835         //   neg i32 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))
836         // after cf:
837         //   neg u1 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))  # wrong!
838         // As a workaround, we exclude u1 opnd type
839         if (isInt && node->GetOpCode() == OP_neg && p.first->GetPrimType() != PTY_u1) {
840             result = NegateTree(p.first);
841             if (result->GetOpCode() == OP_neg) {
842                 PrimType origPtyp = node->GetPrimType();
843                 PrimType newPtyp = result->GetPrimType();
844                 if (newPtyp == origPtyp) {
845                 if (static_cast<UnaryNode*>(result)->Opnd(0) == node->Opnd(0)) {
846                     // NegateTree returned an UnaryNode quivalent to `n`, so keep the
847                     // original UnaryNode to preserve identity
848                     result = node;
849                 }
850                 } else {
851                     if (GetPrimTypeSize(newPtyp) != GetPrimTypeSize(origPtyp)) {
852                         // do not fold explicit cvt
853                         result = NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(),
854                             PairToExpr(node->Opnd(0)->GetPrimType(), p));
855                         return std::make_pair(result, std::nullopt);
856                     } else {
857                         result->SetPrimType(origPtyp);
858                     }
859                 }
860             }
861             if (p.second) {
862                 sum = -(*p.second);
863             }
864         } else {
865             result =
866                 NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(), PairToExpr(node->Opnd(0)->GetPrimType(), p));
867         }
868     }
869     return std::make_pair(result, sum);
870 }
871 
FloatToIntOverflow(float fval,PrimType totype)872 static bool FloatToIntOverflow(float fval, PrimType totype)
873 {
874     static const float safeFloatMaxToInt32 = 2147483520.0f;  // 2^31 - 128
875     static const float safeFloatMinToInt32 = -2147483520.0f;
876     static const float safeFloatMaxToInt64 = 9223372036854775680.0f;  // 2^63 - 128
877     static const float safeFloatMinToInt64 = -9223372036854775680.0f;
878     if (!std::isfinite(fval)) {
879         return true;
880     }
881     if (totype == PTY_i64 || totype == PTY_u64) {
882         if (fval < safeFloatMinToInt64 || fval > safeFloatMaxToInt64) {
883             return true;
884         }
885     } else {
886         if (fval < safeFloatMinToInt32 || fval > safeFloatMaxToInt32) {
887             return true;
888         }
889     }
890     return false;
891 }
892 
DoubleToIntOverflow(double dval,PrimType totype)893 static bool DoubleToIntOverflow(double dval, PrimType totype)
894 {
895     static const double safeDoubleMaxToInt32 = 2147482624.0;  // 2^31 - 1024
896     static const double safeDoubleMinToInt32 = -2147482624.0;
897     static const double safeDoubleMaxToInt64 = 9223372036854774784.0;  // 2^63 - 1024
898     static const double safeDoubleMinToInt64 = -9223372036854774784.0;
899     if (!std::isfinite(dval)) {
900         return true;
901     }
902     if (totype == PTY_i64 || totype == PTY_u64) {
903         if (dval < safeDoubleMinToInt64 || dval > safeDoubleMaxToInt64) {
904             return true;
905         }
906     } else {
907         if (dval < safeDoubleMinToInt32 || dval > safeDoubleMaxToInt32) {
908             return true;
909         }
910     }
911     return false;
912 }
913 
// Folds OP_ceil applied to a floating-point constant. If toType is a floating-point type the rounded value
// is returned as an f32/f64 constant, otherwise as an integer constant of type toType. Returns nullptr
// (leave the node unfolded) when FloatToIntOverflow/DoubleToIntOverflow reject the rounded value.
ConstvalNode *ConstantFold::FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    // Converting a negative floating-point value straight to uint64 is undefined behaviour (e.g. AArch64
    // FCVTZU saturates it to 0), so negatives are converted through int64 and then reinterpreted as raw bits.
    auto toRawBits = [](auto value) -> uint64 {
        return value < 0 ? static_cast<uint64>(static_cast<int64>(value)) : static_cast<uint64>(value);
    };
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    MIRConst *folded = nullptr;
    if (fromType == PTY_f32) {
        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(constValue);
        float floatValue = ceil(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            folded = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
        } else if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        } else {
            folded = GlobalTables::GetIntConstTable().GetOrCreateIntConst(toRawBits(floatValue), resultType);
        }
    } else {
        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(constValue);
        double doubleValue = ceil(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            folded = GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
        } else if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        } else {
            folded = GlobalTables::GetIntConstTable().GetOrCreateIntConst(toRawBits(doubleValue), resultType);
        }
    }
    // Allocate the node only once folding has succeeded.
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(folded);
    return resultConst;
}
946 
947 template <class T>
CalIntValueFromFloatValue(T value,const MIRType & resultType) const948 T ConstantFold::CalIntValueFromFloatValue(T value, const MIRType &resultType) const
949 {
950     DEBUG_ASSERT(kByteSizeOfBit64 >= resultType.GetSize(), "unsupported type");
951     size_t shiftNum = (kByteSizeOfBit64 - resultType.GetSize()) * kBitSizePerByte;
952     bool isSigned = IsSignedInteger(resultType.GetPrimType());
953     int64 max = (IntVal(std::numeric_limits<int64>::max(), PTY_i64) >> shiftNum).GetExtValue();
954     uint64 umax = std::numeric_limits<uint64>::max() >> shiftNum;
955     int64 min = isSigned ? (IntVal(std::numeric_limits<int64>::min(), PTY_i64) >> shiftNum).GetExtValue() : 0;
956     if (isSigned && (value > max)) {
957         return static_cast<T>(max);
958     } else if (!isSigned && (value > umax)) {
959         return static_cast<T>(umax);
960     } else if (value < min) {
961         return static_cast<T>(min);
962     }
963     return value;
964 }
965 
// Converts a floating-point constant to toType, applying floor() first when isFloor is set. A floating-point
// toType yields an f32/f64 constant. An integer toType yields an integer constant, with the value clamped
// to the type's range. Returns nullptr when the value is rejected by FloatToIntOverflow/DoubleToIntOverflow.
MIRConst *ConstantFold::FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor) const
{
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const auto &constValue = static_cast<const MIRFloatConst&>(cst);
        float floatValue = constValue.GetValue();
        if (isFloor) {
            floatValue = floor(constValue.GetValue());
        }
        if (IsPrimitiveFloat(toType)) {
            return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
        }
        if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        }
        floatValue = CalIntValueFromFloatValue(floatValue, resultType);
        // Converting a negative float directly to uint64 is undefined behaviour, so negatives go through int64
        // first. Non-negative values are at most 2^63 after the overflow check, so they convert directly.
        uint64 bits = floatValue < 0 ? static_cast<uint64>(static_cast<int64>(floatValue))
                                     : static_cast<uint64>(floatValue);
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(bits, resultType);
    } else {
        const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
        double doubleValue = constValue.GetValue();
        if (isFloor) {
            doubleValue = floor(constValue.GetValue());
        }
        if (IsPrimitiveFloat(toType)) {
            return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
        }
        if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        }
        doubleValue = CalIntValueFromFloatValue(doubleValue, resultType);
        // Converting a negative double directly to an unsigned type is undefined behaviour; convert to signed
        // long first (DoubleToIntOverflow keeps the value inside the int64 range).
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(doubleValue), resultType);
    }
}
1000 
// Folds OP_floor of a constant into a new ConstvalNode of type toType. Returns nullptr when
// FoldFloorMIRConst declines to fold (out-of-range value). Previously a node holding a null constant was
// returned in that case, which FoldTypeCvt's nullptr fallback could not catch.
ConstvalNode *ConstantFold::FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *folded = FoldFloorMIRConst(*cst.GetConstVal(), fromType, toType);
    if (folded == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(folded);
    return resultConst;
}
1008 
// Folds a rounding conversion of a constant:
//  - f32/f64 -> integer: round() the value and build an integer constant (nullptr if out of range);
//  - integer -> f32/f64: build the floating-point constant only when the conversion is exact, i.e. the
//    converted value maps back to the original integer.
// Returns nullptr whenever the conversion is not folded.
MIRConst *ConstantFold::FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
    // 2^63 and 2^64 are exactly representable as float and double. A converted value at or above them
    // cannot be cast back to int64/uint64 without undefined behaviour, and it cannot equal the source
    // integer anyway, so those values are rejected before the round-trip comparison.
    constexpr double kTwoPow63 = 9223372036854775808.0;
    constexpr double kTwoPow64 = 18446744073709551616.0;
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const auto &constValue = static_cast<const MIRFloatConst&>(cst);
        float floatValue = round(constValue.GetValue());
        if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        }
        // Negatives go through int64; non-negative values convert directly, which stays defined for
        // every magnitude FloatToIntOverflow accepts (at most 2^63).
        uint64 bits = floatValue < 0 ? static_cast<uint64>(static_cast<int64>(floatValue))
                                     : static_cast<uint64>(floatValue);
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(bits, resultType);
    } else if (fromType == PTY_f64) {
        const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
        double doubleValue = round(constValue.GetValue());
        if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        }
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(
            static_cast<uint64>(static_cast<int64>(doubleValue)), resultType);
    } else if (toType == PTY_f32 && IsPrimitiveInteger(fromType)) {
        const auto &constValue = static_cast<const MIRIntConst&>(cst);
        if (IsSignedInteger(fromType)) {
            int64 fromValue = constValue.GetExtValue();
            float floatValue = round(static_cast<float>(fromValue));
            if (floatValue < kTwoPow63 && static_cast<int64>(floatValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
            }
        } else {
            uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
            float floatValue = round(static_cast<float>(fromValue));
            if (floatValue < kTwoPow64 && static_cast<uint64>(floatValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
            }
        }
    } else if (toType == PTY_f64 && IsPrimitiveInteger(fromType)) {
        const auto &constValue = static_cast<const MIRIntConst&>(cst);
        if (IsSignedInteger(fromType)) {
            int64 fromValue = constValue.GetExtValue();
            double doubleValue = round(static_cast<double>(fromValue));
            if (doubleValue < kTwoPow63 && static_cast<int64>(doubleValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
            }
        } else {
            uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
            double doubleValue = round(static_cast<double>(fromValue));
            if (doubleValue < kTwoPow64 && static_cast<uint64>(doubleValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
            }
        }
    }
    return nullptr;
}
1060 
// Folds a rounding conversion of a constant into a new ConstvalNode of type toType. Returns nullptr when
// FoldRoundMIRConst declines to fold (out of range, or inexact int->fp conversion), instead of returning a
// node that wraps a null constant.
ConstvalNode *ConstantFold::FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *folded = FoldRoundMIRConst(*cst.GetConstVal(), fromType, toType);
    if (folded == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(folded);
    return resultConst;
}
1068 
// Folds OP_trunc applied to a floating-point constant. If toType is a floating-point type the truncated value
// is returned as an f32/f64 constant, otherwise as an integer constant of type toType. Returns nullptr
// (leave the node unfolded) when FloatToIntOverflow/DoubleToIntOverflow reject the truncated value.
ConstvalNode *ConstantFold::FoldTrunc(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    // Converting a negative floating-point value straight to uint64 is undefined behaviour (e.g. AArch64
    // FCVTZU saturates it to 0), so negatives are converted through int64 and then reinterpreted as raw bits.
    auto toRawBits = [](auto value) -> uint64 {
        return value < 0 ? static_cast<uint64>(static_cast<int64>(value)) : static_cast<uint64>(value);
    };
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    MIRConst *folded = nullptr;
    if (fromType == PTY_f32) {
        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(constValue);
        float floatValue = trunc(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            folded = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
        } else if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        } else {
            folded = GlobalTables::GetIntConstTable().GetOrCreateIntConst(toRawBits(floatValue), resultType);
        }
    } else {
        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(constValue);
        double doubleValue = trunc(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            folded = GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
        } else if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        } else {
            folded = GlobalTables::GetIntConstTable().GetOrCreateIntConst(toRawBits(doubleValue), resultType);
        }
    }
    // Allocate the node only once folding has succeeded.
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(folded);
    return resultConst;
}
1101 
FoldTypeCvtMIRConst(const MIRConst & cst,PrimType fromType,PrimType toType) const1102 MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
1103 {
1104     if (IsPrimitiveInteger(fromType) && IsPrimitiveInteger(toType)) {
1105         MIRConst *toConst = nullptr;
1106         uint32 fromSize = GetPrimTypeBitSize(fromType);
1107         uint32 toSize = GetPrimTypeBitSize(toType);
1108         // GetPrimTypeBitSize(PTY_u1) will return 8, which is not expected here.
1109         if (fromType == PTY_u1) {
1110             fromSize = 1;
1111         }
1112         if (toType == PTY_u1) {
1113             toSize = 1;
1114         }
1115         if (toSize > fromSize) {
1116             Opcode op = OP_zext;
1117             if (IsSignedInteger(fromType)) {
1118                 op = OP_sext;
1119             }
1120             const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
1121             ASSERT_NOT_NULL(constVal);
1122             toConst = FoldSignExtendMIRConst(op, toType, static_cast<uint8>(fromSize),
1123                 constVal->GetValue().TruncOrExtend(fromType));
1124         } else {
1125             const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
1126             ASSERT_NOT_NULL(constVal);
1127             MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(toType);
1128             toConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(
1129                 static_cast<uint64>(constVal->GetExtValue()), type);
1130         }
1131         return toConst;
1132     }
1133     if (IsPrimitiveFloat(fromType) && IsPrimitiveFloat(toType)) {
1134         MIRConst *toConst = nullptr;
1135         if (GetPrimTypeBitSize(toType) < GetPrimTypeBitSize(fromType)) {
1136             DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 32, "We suppot F32 and F64"); // just support 32 or 64
1137             const MIRDoubleConst *fromValue = safe_cast<MIRDoubleConst>(cst);
1138             ASSERT_NOT_NULL(fromValue);
1139             float floatValue = static_cast<float>(fromValue->GetValue());
1140             MIRFloatConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
1141             toConst = toValue;
1142         } else {
1143             DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 64, "We suppot F32 and F64"); // just support 32 or 64
1144             const MIRFloatConst *fromValue = safe_cast<MIRFloatConst>(cst);
1145             ASSERT_NOT_NULL(fromValue);
1146             double doubleValue = static_cast<double>(fromValue->GetValue());
1147             MIRDoubleConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
1148             toConst = toValue;
1149         }
1150         return toConst;
1151     }
1152     if (IsPrimitiveFloat(fromType) && IsPrimitiveInteger(toType)) {
1153         return FoldFloorMIRConst(cst, fromType, toType, false);
1154     }
1155     if (IsPrimitiveInteger(fromType) && IsPrimitiveFloat(toType)) {
1156         return FoldRoundMIRConst(cst, fromType, toType);
1157     }
1158     CHECK_FATAL(false, "Unexpected case in ConstFoldTypeCvt");
1159     return nullptr;
1160 }
1161 
// Folds a cvt of a constant node. The new ConstvalNode takes its primitive type from the converted
// constant. Returns nullptr when FoldTypeCvtMIRConst does not fold the conversion.
ConstvalNode *ConstantFold::FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *converted = FoldTypeCvtMIRConst(*cst.GetConstVal(), fromType, toType);
    if (!converted) {
        return nullptr;
    }
    auto *resultNode = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    const PrimType convertedType = converted->GetType().GetPrimType();
    resultNode->SetPrimType(convertedType);
    resultNode->SetConstVal(converted);
    return resultNode;
}
1173 
1174 // return a primType with bit size >= bitSize (and the nearest one),
1175 // and its signed/float type is the same as ptyp
GetNearestSizePtyp(uint8 bitSize,PrimType ptyp)1176 PrimType GetNearestSizePtyp(uint8 bitSize, PrimType ptyp)
1177 {
1178     bool isSigned = IsSignedInteger(ptyp);
1179     bool isFloat = IsPrimitiveFloat(ptyp);
1180     if (bitSize == 1) { // 1 bit
1181         return PTY_u1;
1182     }
1183     if (bitSize <= 8) { // 8 bit
1184         return isSigned ? PTY_i8 : PTY_u8;
1185     }
1186     if (bitSize <= 16) { // 16 bit
1187         return isSigned ? PTY_i16 : PTY_u16;
1188     }
1189     if (bitSize <= 32) { // 32 bit
1190         return isFloat ? PTY_f32 : (isSigned ? PTY_i32 : PTY_u32);
1191     }
1192     if (bitSize <= 64) { // 64 bit
1193         return isFloat ? PTY_f64 : (isSigned ? PTY_i64 : PTY_u64);
1194     }
1195     return ptyp;
1196 }
1197 
GetIntPrimTypeMax(PrimType ptyp)1198 size_t GetIntPrimTypeMax(PrimType ptyp)
1199 {
1200     switch (ptyp) {
1201         case PTY_u1:
1202             return 1;
1203         case PTY_u8:
1204             return UINT8_MAX;
1205         case PTY_i8:
1206             return INT8_MAX;
1207         case PTY_u16:
1208             return UINT16_MAX;
1209         case PTY_i16:
1210             return INT16_MAX;
1211         case PTY_u32:
1212             return UINT32_MAX;
1213         case PTY_i32:
1214             return INT32_MAX;
1215         case PTY_u64:
1216             return UINT64_MAX;
1217         case PTY_i64:
1218             return INT64_MAX;
1219         default:
1220             CHECK_FATAL(false, "NYI");
1221     }
1222 }
1223 
GetIntPrimTypeMin(PrimType ptyp)1224 ssize_t GetIntPrimTypeMin(PrimType ptyp)
1225 {
1226     if (IsUnsignedInteger(ptyp)) {
1227         return 0;
1228     }
1229     switch (ptyp) {
1230         case PTY_i8:
1231             return INT8_MIN;
1232         case PTY_i16:
1233             return INT16_MIN;
1234         case PTY_i32:
1235             return INT32_MIN;
1236         case PTY_i64:
1237             return INT64_MIN;
1238         default:
1239             CHECK_FATAL(false, "NYI");
1240     }
1241 }
1242 
IsCvtEliminatable(PrimType fromPtyp,PrimType destPtyp,Opcode op,Opcode opndOp)1243 static bool IsCvtEliminatable(PrimType fromPtyp, PrimType destPtyp, Opcode op, Opcode opndOp)
1244 {
1245     if (op != OP_cvt || (opndOp == OP_zext || opndOp == OP_sext)) {
1246         return false;
1247     }
1248     if (GetPrimTypeSize(fromPtyp) != GetPrimTypeSize(destPtyp)) {
1249         return false;
1250     }
1251     return (IsPossible64BitAddress(fromPtyp) && IsPossible64BitAddress(destPtyp)) ||
1252         (IsPossible32BitAddress(fromPtyp) && IsPossible32BitAddress(destPtyp)) ||
1253         (IsPrimitivePureScalar(fromPtyp) && IsPrimitivePureScalar(destPtyp));
1254 }
1255 
FoldTypeCvt(TypeCvtNode * node)1256 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldTypeCvt(TypeCvtNode *node)
1257 {
1258     CHECK_NULL_FATAL(node);
1259     BaseNode *result = nullptr;
1260     if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
1261         return {node, std::nullopt};
1262     }
1263     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1264     ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1265     PrimType destPtyp = node->GetPrimType();
1266     PrimType fromPtyp = node->FromType();
1267     if (cst != nullptr) {
1268         switch (node->GetOpCode()) {
1269             case OP_ceil: {
1270                 result = FoldCeil(*cst, fromPtyp, destPtyp);
1271                 break;
1272             }
1273             case OP_cvt: {
1274                 result = FoldTypeCvt(*cst, fromPtyp, destPtyp);
1275                 break;
1276             }
1277             case OP_floor: {
1278                 result = FoldFloor(*cst, fromPtyp, destPtyp);
1279                 break;
1280             }
1281             case OP_trunc: {
1282                 result = FoldTrunc(*cst, fromPtyp, destPtyp);
1283                 break;
1284             }
1285             default:
1286                 DEBUG_ASSERT(false, "Unexpected opcode in TypeCvtNodeConstFold");
1287                 break;
1288         }
1289     } else if (IsCvtEliminatable(fromPtyp, destPtyp, node->GetOpCode(), p.first->GetOpCode())) {
1290         // the cvt is redundant
1291         return std::make_pair(p.first, p.second ? IntVal(*p.second, node->GetPrimType()) : p.second);
1292     }
1293     if (result == nullptr) {
1294         BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1295         if (e != node->Opnd(0)) {
1296             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(
1297                 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->FromType()), e);
1298         } else {
1299             result = node;
1300         }
1301     }
1302     return std::make_pair(result, std::nullopt);
1303 }
1304 
FoldSignExtendMIRConst(Opcode opcode,PrimType resultType,uint8 size,const IntVal & val) const1305 MIRConst *ConstantFold::FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const
1306 {
1307     uint64 result = opcode == OP_sext ? static_cast<uint64>(val.GetSXTValue(size)) : val.GetZXTValue(size);
1308     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
1309     MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
1310     return constValue;
1311 }
1312 
FoldSignExtend(Opcode opcode,PrimType resultType,uint8 size,const ConstvalNode & cst) const1313 ConstvalNode *ConstantFold::FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size,
1314                                            const ConstvalNode &cst) const
1315 {
1316     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1317     const auto *intCst = safe_cast<MIRIntConst>(cst.GetConstVal());
1318     ASSERT_NOT_NULL(intCst);
1319     IntVal val = intCst->GetValue().TruncOrExtend(size, opcode == OP_sext);
1320     MIRConst *toConst = FoldSignExtendMIRConst(opcode, resultType, size, val);
1321     resultConst->SetPrimType(toConst->GetType().GetPrimType());
1322     resultConst->SetConstVal(toConst);
1323     return resultConst;
1324 }
1325 
1326 // check if truncation is redundant due to dread or iread having same effect
ExtractbitsRedundant(const ExtractbitsNode & x,MIRFunction & f)1327 static bool ExtractbitsRedundant(const ExtractbitsNode &x, MIRFunction &f)
1328 {
1329     if (GetPrimTypeSize(x.GetPrimType()) == k8ByteSize) {
1330         return false;  // this is trying to be conservative
1331     }
1332     BaseNode *opnd = x.Opnd(0);
1333     MIRType *mirType = nullptr;
1334     if (opnd->GetOpCode() == OP_dread) {
1335         DreadNode *dread = static_cast<DreadNode*>(opnd);
1336         MIRSymbol *sym = f.GetLocalOrGlobalSymbol(dread->GetStIdx());
1337         ASSERT_NOT_NULL(sym);
1338         mirType = sym->GetType();
1339     } else if (opnd->GetOpCode() == OP_iread) {
1340         IreadNode *iread = static_cast<IreadNode*>(opnd);
1341         MIRPtrType *ptrType =
1342             dynamic_cast<MIRPtrType*>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx()));
1343         if (ptrType == nullptr) {
1344             return false;
1345         }
1346         mirType = ptrType->GetPointedType();
1347     } else if (opnd->GetOpCode() == OP_extractbits &&
1348                 x.GetBitsSize() > static_cast<ExtractbitsNode*>(opnd)->GetBitsSize()) {
1349         return (x.GetOpCode() == OP_zext && x.GetPrimType() == opnd->GetPrimType() &&
1350             IsUnsignedInteger(opnd->GetPrimType()));
1351     } else {
1352         return false;
1353     }
1354     return IsPrimitiveInteger(mirType->GetPrimType()) &&
1355             ((x.GetOpCode() == OP_zext && IsUnsignedInteger(opnd->GetPrimType())) ||
1356             (x.GetOpCode() == OP_sext && IsSignedInteger(opnd->GetPrimType()))) &&
1357             mirType->GetSize() * kBitSizePerByte == x.GetBitsSize() &&
1358             mirType->GetPrimType() == x.GetPrimType();
1359 }
1360 
1361 // sext and zext also handled automatically
FoldExtractbits(ExtractbitsNode * node)1362 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldExtractbits(ExtractbitsNode *node)
1363 {
1364     CHECK_NULL_FATAL(node);
1365     BaseNode *result = nullptr;
1366     uint8 offset = node->GetBitsOffset();
1367     uint8 size = node->GetBitsSize();
1368     Opcode opcode = node->GetOpCode();
1369     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1370     ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1371     if (cst != nullptr && (opcode == OP_sext || opcode == OP_zext)) {
1372         result = FoldSignExtend(opcode, node->GetPrimType(), size, *cst);
1373         return std::make_pair(result, std::nullopt);
1374     }
1375     BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1376     if (e != node->Opnd(0)) {
1377         result = mirModule->CurFuncCodeMemPool()->New<ExtractbitsNode>(opcode, PrimType(node->GetPrimType()), offset,
1378                                                                        size, e);
1379     } else {
1380         result = node;
1381     }
1382     // check for consecutive and redundant extraction of same bits
1383     BaseNode *opnd = result->Opnd(0);
1384     DEBUG_ASSERT(opnd != nullptr, "opnd shoule not be null");
1385     Opcode opndOp = opnd->GetOpCode();
1386     if (opndOp == OP_extractbits || opndOp == OP_sext || opndOp == OP_zext) {
1387         uint8 opndOffset = static_cast<ExtractbitsNode*>(opnd)->GetBitsOffset();
1388         uint8 opndSize = static_cast<ExtractbitsNode*>(opnd)->GetBitsSize();
1389         if (offset == opndOffset && size == opndSize) {
1390             result->SetOpnd(opnd->Opnd(0), 0);  // delete the redundant extraction
1391         }
1392     }
1393     if (offset == 0 && size >= k8ByteSize && IsPowerOf2(size)) {
1394         if (ExtractbitsRedundant(*static_cast<ExtractbitsNode*>(result), *mirModule->CurFunction())) {
1395             return std::make_pair(result->Opnd(0), std::nullopt);
1396         }
1397     }
1398     return std::make_pair(result, std::nullopt);
1399 }
1400 
FoldIread(IreadNode * node)1401 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldIread(IreadNode *node)
1402 {
1403     CHECK_NULL_FATAL(node);
1404     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1405     BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1406     node->SetOpnd(e, 0);
1407     BaseNode *result = node;
1408     if (e->GetOpCode() != OP_addrof) {
1409         return std::make_pair(result, std::nullopt);
1410     }
1411 
1412     AddrofNode *addrofNode = static_cast<AddrofNode*>(e);
1413     MIRSymbol *msy = mirModule->CurFunction()->GetLocalOrGlobalSymbol(addrofNode->GetStIdx());
1414     DEBUG_ASSERT(msy != nullptr, "nullptr check");
1415     TyIdx typeId = msy->GetTyIdx();
1416     CHECK_FATAL(!GlobalTables::GetTypeTable().GetTypeTable().empty(), "container check");
1417     MIRType *msyType = GlobalTables::GetTypeTable().GetTypeTable()[typeId];
1418     MIRPtrType *ptrType = static_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx()));
1419     // If the high level type of iaddrof/iread doesn't match
1420     // the type of addrof's rhs, this optimization cannot be done.
1421     if (ptrType->GetPointedType() != msyType) {
1422         return std::make_pair(result, std::nullopt);
1423     }
1424 
1425     Opcode op = node->GetOpCode();
1426     if (op == OP_iread) {
1427         result = mirModule->CurFuncCodeMemPool()->New<AddrofNode>(OP_dread, node->GetPrimType(), addrofNode->GetStIdx(),
1428                                                                   node->GetFieldID() + addrofNode->GetFieldID());
1429     }
1430     return std::make_pair(result, std::nullopt);
1431 }
1432 
IntegerOpIsOverflow(Opcode op,PrimType primType,int64 cstA,int64 cstB)1433 bool ConstantFold::IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB)
1434 {
1435     switch (op) {
1436         case OP_add: {
1437             int64 res = static_cast<int64>(static_cast<uint64>(cstA) + static_cast<uint64>(cstB));
1438             if (IsUnsignedInteger(primType)) {
1439                 return static_cast<uint64>(res) < static_cast<uint64>(cstA);
1440             }
1441             auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1442             return (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1443                     static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag) &&
1444                    (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1445                     static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag);
1446         }
1447         case OP_sub: {
1448             if (IsUnsignedInteger(primType)) {
1449                 return cstA < cstB;
1450             }
1451             int64 res = static_cast<int64>(static_cast<uint64>(cstA) - static_cast<uint64>(cstB));
1452             auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1453             return (static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag !=
1454                     static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag) &&
1455                    (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1456                     static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag);
1457         }
1458         default: {
1459             return false;
1460         }
1461     }
1462 }
1463 
// Fold a binary expression node.
// Both operands are folded first with DispatchFold. The returned pair is
// (expression, optional constant addend): `sum` carries a constant that is still to be
// combined with `first` (PairToExpr does that and truncates the constant to the
// expression's type). Cases, in order:
//   1. Both operands constant: evaluate with FoldConstBinary. The exception is an
//      integer div/rem rejected by IsDivSafe, which is rebuilt instead of folded.
//   2. Integer op with a constant left operand: identities for 0, 1 and -1. A constant
//      multiplier is distributed over the right operand's pending sum.
//   3. Integer op with a constant right operand: the mirrored identities, plus the
//      X * Y / Y -> X and (X >> s) & mask -> extractbits patterns.
//   4. Integer add/sub of two non-constant operands: combine their pending sums.
//   5. Anything else: rebuild the node from the re-materialized operands.
FoldBinary(BinaryNode * node)1464 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldBinary(BinaryNode *node)
1465 {
1466     CHECK_NULL_FATAL(node);
1467     BaseNode *result = nullptr;
1468     std::optional<IntVal> sum = std::nullopt;
1469     Opcode op = node->GetOpCode();
1470     PrimType primType = node->GetPrimType();
1471     PrimType lPrimTypes = node->Opnd(0)->GetPrimType();
1472     PrimType rPrimTypes = node->Opnd(1)->GetPrimType();
1473     std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1474     std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1475     BaseNode *l = lp.first;
1476     BaseNode *r = rp.first;
1477     ASSERT_NOT_NULL(r);
1478     ConstvalNode *lConst = safe_cast<ConstvalNode>(l);
1479     ConstvalNode *rConst = safe_cast<ConstvalNode>(r);
1480     bool isInt = IsPrimitiveInteger(primType);
1481
    // Case 1: both operands folded to constants.
1482     if (lConst != nullptr && rConst != nullptr) {
1483         MIRConst *lConstVal = lConst->GetConstVal();
1484         MIRConst *rConstVal = rConst->GetConstVal();
1485         ASSERT_NOT_NULL(lConstVal);
1486         ASSERT_NOT_NULL(rConstVal);
1487         // Don't fold div by 0, for floats div by 0 is well defined.
1488         if ((op == OP_div || op == OP_rem) && isInt &&
1489             !IsDivSafe(static_cast<MIRIntConst &>(*lConstVal), static_cast<MIRIntConst &>(*rConstVal), primType)) {
1490             result = NewBinaryNode(node, op, primType, lConst, rConst);
1491         } else {
1492             // 4 + 2 -> return a pair(result = ConstValNode(6), sum = 0)
1493             // Create a new ConstvalNode for 6 but keep the sum = 0. This simplify the
1494             // logic since the alternative is to return pair(result = nullptr, sum = 6).
1495             // Doing so would introduce many nullptr checks in the code. See previous
1496             // commits that implemented that logic for a comparison.
1497             result = FoldConstBinary(op, primType, *lConst, *rConst);
1498         }
1499     } else if (lConst != nullptr && isInt) {
        // Case 2: integer op with a constant left operand `cst`. cstTyp is the constant's
        // own type, which may differ from primType. A cvt is appended at the end of this
        // branch when IsNoCvtNeeded(result type, primType) is false.
1500         MIRIntConst *mcst = safe_cast<MIRIntConst>(lConst->GetConstVal());
1501         ASSERT_NOT_NULL(mcst);
1502         PrimType cstTyp = mcst->GetType().GetPrimType();
1503         IntVal cst = mcst->GetValue();
1504         if (op == OP_add) {
            // cst + (X + k): fold cst into the pending sum, unless adding cst to k would
            // overflow a signed type.
1505             if (IsSignedInteger(cstTyp) && rp.second &&
1506                 IntegerOpIsOverflow(OP_add, cstTyp, cst.GetExtValue(), rp.second->GetExtValue())) {
1507                 // do not introduce signed integer overflow
1508                 result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1509             } else {
1510                 sum = cst + rp.second;
1511                 result = r;
1512             }
1513         } else if (op == OP_sub && r->GetPrimType() != PTY_u1) {
1514             // We exclude u1 type for fixing the following wrong example:
1515             // before cf:
1516             //   sub i32 (constval i32 17, eq u1 i32 (dread i32 %i, constval i32 16)))
1517             // after cf:
1518             //   add i32 (cvt i32 u1 (neg u1 (eq u1 i32 (dread i32 %i, constval i32 16))), constval i32 17))
1519             sum = cst - rp.second;
1520             if (GetPrimTypeSize(r->GetPrimType()) < GetPrimTypeSize(primType)) {
1521                 r = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, r->GetPrimType(), r);
1522             }
1523             result = NegateTree(r);
1524         } else if ((op == OP_mul || op == OP_div || op == OP_rem || op == OP_ashr || op == OP_lshr || op == OP_shl ||
1525                     op == OP_band) &&
1526                     cst == 0) {
1527             // 0 * X -> 0
1528             // 0 / X -> 0
1529             // 0 % X -> 0
1530             // 0 >> X -> 0
1531             // 0 << X -> 0
1532             // 0 & X -> 0
1533             // 0 && X -> 0
1534             result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1535         } else if (op == OP_mul && cst == 1) {
1536             // 1 * X --> X
1537             sum = rp.second;
1538             result = r;
1539         } else if (op == OP_bior && cst == -1) {
1540             // (-1) | X -> -1
1541             result = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<uint64>(-1), cstTyp);
1542         } else if (op == OP_mul && rp.second.has_value() && *rp.second != 0) {
1543             // lConst * (X + konst) -> the pair [(lConst*X), (lConst*konst)]
1544             sum = cst * rp.second;
            // NOTE(review): the cvt below names PTY_i32 as its source type regardless of
            // rp.first's actual type. The OP_sub case above passes r->GetPrimType() instead.
            // For an unsigned 32-bit operand this presumably sign-extends rather than
            // zero-extends. Confirm this is intended.
1545             if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) {
1546                 rp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, rp.first);
1547             }
1548             result = NewBinaryNode(node, OP_mul, primType, lConst, rp.first);
1549         } else if ((op == OP_bior || op == OP_bxor) && cst == 0) {
1550             // 0 | X -> X
1551             // 0 ^ X -> X
1552             sum = rp.second;
1553             result = r;
1554         } else {
1555             result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1556         }
        // The simplified result may carry a different primitive type (e.g. cstTyp from
        // CreateIntConst), so convert it back to the node's type when required.
1557         if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1558             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1559         }
1560     } else if (rConst != nullptr && isInt) {
        // Case 3: integer op with a constant right operand `cst`. This mirrors case 2,
        // including the trailing cvt back to primType.
1561         MIRIntConst *mcst = safe_cast<MIRIntConst>(rConst->GetConstVal());
1562         ASSERT_NOT_NULL(mcst);
1563         PrimType cstTyp = mcst->GetType().GetPrimType();
1564         IntVal cst = mcst->GetValue();
1565         if (op == OP_add) {
1566             if (lp.second && IntegerOpIsOverflow(op, cstTyp, lp.second->GetExtValue(), cst.GetExtValue())) {
1567                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1568             } else {
1569                 result = l;
1570                 sum = lp.second + cst;
1571             }
1572         } else if (op == OP_sub && (!cst.IsSigned() || !cst.IsMinValue())) {
            // X - cst is kept as a pending negative addend. A signed minimum cst is excluded
            // by the condition above (its negation is not representable).
1573             result = l;
1574             sum = lp.second - cst;
1575         } else if ((op == OP_mul || op == OP_band) && cst == 0) {
1576             // X * 0 -> 0
1577             // X & 0 -> 0
1578             // X && 0 -> 0
1579             result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1580         } else if ((op == OP_mul || op == OP_div) && cst == 1) {
1581             // case [X * 1 -> X]
1582             // case [X / 1 = X]
1583             sum = lp.second;
1584             result = l;
1585         } else if (op == OP_div && !lp.second.has_value() && l->GetOpCode() == OP_mul &&
1586                 IsSignedInteger(primType) && IsSignedInteger(lPrimTypes) && IsSignedInteger(rPrimTypes)) {
1587             // temporary fix for constfold of mul/div in DejaGnu
1588             // Later we need a more formal interface for pattern match
1589             // X * Y / Y -> X
            // Only fires when exactly one multiplicand is a constant whose bit width,
            // signedness and value all match the divisor.
1590             BaseNode *x = l->Opnd(0);
1591             BaseNode *y = l->Opnd(1);
1592             ConstvalNode *xConst = safe_cast<ConstvalNode>(x);
1593             ConstvalNode *yConst = safe_cast<ConstvalNode>(y);
1594             bool foldMulDiv = false;
1595             if (yConst != nullptr && xConst == nullptr &&
1596                 IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1597                 MIRIntConst *yCst = safe_cast<MIRIntConst>(yConst->GetConstVal());
1598                 ASSERT_NOT_NULL(yCst);
1599                 IntVal mulCst = yCst->GetValue();
1600                 if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1601                     mulCst.GetExtValue() == cst.GetExtValue()) {
1602                     foldMulDiv = true;
1603                     result = x;
1604                 }
1605             } else if (xConst != nullptr && yConst == nullptr &&
1606                         IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1607                 MIRIntConst *xCst = safe_cast<MIRIntConst>(xConst->GetConstVal());
1608                 ASSERT_NOT_NULL(xCst);
1609                 IntVal mulCst = xCst->GetValue();
1610                 if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1611                     mulCst.GetExtValue() == cst.GetExtValue()) {
1612                     foldMulDiv = true;
1613                     result = y;
1614                 }
1615             }
1616             if (!foldMulDiv) {
1617                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1618             }
1619         } else if (op == OP_mul && lp.second.has_value() && *lp.second != 0 && lp.second->GetSXTValue() > -kMaxOffset) {
            // The GetSXTValue() > -kMaxOffset test (kMaxOffset is INT_MAX - 8) skips this
            // rewrite for very negative pending constants.
1620             // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)]
1621             sum = lp.second * cst;
            // NOTE(review): as in case 2, the cvt source type is hard-coded to PTY_i32 rather
            // than lp.first->GetPrimType(). Confirm this is intended for unsigned operands.
1622             if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) {
1623                 lp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, lp.first);
1624             }
1625             if (lp.first->GetOpCode() == OP_neg && cst == -1) {
1626                 // special case: ((-X) + konst) * (-1) -> the pair [(X), -konst]
1627                 result = lp.first->Opnd(0);
1628             } else {
1629                 result = NewBinaryNode(node, OP_mul, primType, lp.first, rConst);
1630             }
1631         } else if (op == OP_band && cst == -1) {
1632             // X & (-1) -> X
1633             sum = lp.second;
1634             result = l;
1635         } else if (op == OP_band && ContiguousBitsOf1(cst.GetZXTValue()) &&
1636                    (!lp.second.has_value() || lp.second == 0)) {
            // (X >> shrAmt) & mask, where mask is accepted by ContiguousBitsOf1 (presumably a
            // run of 1s starting at bit 0), becomes an unsigned extractbits of X at offset
            // shrAmt. The width bsize is the bit width of mask. This only happens when the
            // shift amount is a constant and the field fits inside primType; otherwise the
            // band is rebuilt.
1637             bool fold2extractbits = false;
1638             if (l->GetOpCode() == OP_ashr || l->GetOpCode() == OP_lshr) {
1639                 BinaryNode *shrNode = static_cast<BinaryNode *>(l);
1640                 if (shrNode->Opnd(1)->GetOpCode() == OP_constval) {
1641                     ConstvalNode *shrOpnd = static_cast<ConstvalNode *>(shrNode->Opnd(1));
1642                     int64 shrAmt = static_cast<MIRIntConst*>(shrOpnd->GetConstVal())->GetExtValue();
1643                     uint64 ucst = cst.GetZXTValue();
1644                     uint32 bsize = 0;
1645                     do {
1646                         bsize++;
1647                         ucst >>= 1;
1648                     } while (ucst != 0);
1649                     if (shrAmt + static_cast<int64>(bsize) <=
1650                         static_cast<int64>(GetPrimTypeSize(primType) * kBitSizePerByte) &&
1651                         static_cast<uint64>(shrAmt) < GetPrimTypeSize(primType) * kBitSizePerByte) {
1652                         fold2extractbits = true;
1653                         // change to use extractbits
1654                         result = mirModule->GetMIRBuilder()->CreateExprExtractbits(OP_extractbits,
1655                             GetUnsignedPrimType(primType), static_cast<uint32>(shrAmt), bsize, shrNode->Opnd(0));
1656                         sum = std::nullopt;
1657                     }
1658                 }
1659             }
1660             if (!fold2extractbits) {
1661                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1662                 sum = std::nullopt;
1663             }
1664         } else if (op == OP_bior && cst == -1) {
1665             // X | (-1) -> -1
1666             result = mirModule->GetMIRBuilder()->CreateIntConst(-1ULL, cstTyp);
1667         } else if ((op == OP_ashr || op == OP_lshr || op == OP_shl || op == OP_bior || op == OP_bxor) && cst == 0) {
1668             // X >> 0 -> X
1669             // X << 0 -> X
1670             // X | 0 -> X
1671             // X ^ 0 -> X
1672             sum = lp.second;
1673             result = l;
1674         } else if (op == OP_bxor && cst == 1 && primType != PTY_u1) {
1675             // bxor i32 (
1676             //   cvt i32 u1 (regread u1 %13),
1677             //  constValue i32 1),
1678             result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
            // When X is a cvt of a u1 value with no pending constant, perform the xor with 1
            // in u1 and convert the u1 result back to primType.
1679             if (l->GetOpCode() == OP_cvt && (!lp.second || lp.second == 0)) {
1680                 TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(l);
1681                 if (cvtNode->Opnd(0)->GetPrimType() == PTY_u1) {
1682                     BaseNode *base = cvtNode->Opnd(0);
1683                     BaseNode *constValue = mirModule->GetMIRBuilder()->CreateIntConst(1, base->GetPrimType());
1684                     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(base);
1685                     BinaryNode *temp = NewBinaryNode(node, op, PTY_u1, PairToExpr(base->GetPrimType(), p), constValue);
1686                     result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_u1, temp);
1687                 }
1688             }
1689         } else if (op == OP_rem && cst == 1) {
1690             // X % 1 -> 0
1691             result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1692         } else {
1693             result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1694         }
1695         if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1696             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1697         }
1698     } else if (isInt && (op == OP_add || op == OP_sub)) {
        // Case 4: integer add/sub with no constant operand. Keep both folded operands and
        // combine their pending constants into `sum`.
1699         if (op == OP_add) {
1700             result = NewBinaryNode(node, op, primType, l, r);
1701             sum = lp.second + rp.second;
1702         } else if (r != nullptr && node->Opnd(1)->GetOpCode() == OP_sub && r->GetOpCode() == OP_neg) {
1703             // if fold is (x - (y - z))    ->     (x - neg(z)) - y
1704             // (x - neg(z)) Could cross the int limit
1705             // return node
1706             result = node;
1707         } else {
1708             result = NewBinaryNode(node, op, primType, l, r);
1709             sum = lp.second - rp.second;
1710         }
1711     } else {
        // Case 5: non-integer type or an opcode without a rule here. Materialize any
        // pending constants into the operands and rebuild the node.
1712         result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1713     }
1714     return std::make_pair(result, sum);
1715 }
1716 
SimplifyDoubleConstvalCompare(CompareNode & node,bool isRConstval,bool isGtOrLt) const1717 BaseNode *ConstantFold::SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt) const
1718 {
1719     if (isRConstval) {
1720         ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(1));
1721         if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1722             const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(0));
1723             return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1724                 node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
1725         }
1726     } else {
1727         ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(0));
1728         if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1729             const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(1));
1730             if (isGtOrLt) {
1731                 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1732                     node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(1), compNode->Opnd(0));
1733             } else {
1734                 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1735                     node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
1736             }
1737         }
1738     }
1739     return &node;
1740 }
1741 
SimplifyDoubleCompare(CompareNode & compareNode) const1742 BaseNode *ConstantFold::SimplifyDoubleCompare(CompareNode &compareNode) const
1743 {
1744     // See arm manual B.cond(P2993) and FCMP(P1091)
1745     CompareNode *node = &compareNode;
1746     BaseNode *result = node;
1747     BaseNode *l = node->Opnd(0);
1748     BaseNode *r = node->Opnd(1);
1749     if (node->GetOpCode() == OP_ne || node->GetOpCode() == OP_eq) {
1750         if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
1751             (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
1752             result = SimplifyDoubleConstvalCompare(*node, (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval));
1753         } else if (node->GetOpCode() == OP_ne && r->GetOpCode() == OP_constval) {
1754             // ne (u1 x, constValue 0)  <==> x
1755             ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
1756             if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1757                 BaseNode *opnd = l;
1758                 do {
1759                     if (opnd->GetPrimType() == PTY_u1 || (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1760                         result = opnd;
1761                         break;
1762                     } else if (opnd->GetOpCode() == OP_cvt) {
1763                         TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(opnd);
1764                         opnd = cvtNode->Opnd(0);
1765                     } else {
1766                         opnd = nullptr;
1767                     }
1768                 } while (opnd != nullptr);
1769             }
1770         } else if (node->GetOpCode() == OP_eq && r->GetOpCode() == OP_constval) {
1771             ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
1772             if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero() &&
1773                 (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1774                 auto resOp = l->GetOpCode() == OP_ne ? OP_eq : OP_ne;
1775                 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1776                     resOp, l->GetPrimType(), static_cast<CompareNode*>(l)->GetOpndType(), l->Opnd(0), l->Opnd(1));
1777             }
1778         }
1779     } else if (node->GetOpCode() == OP_gt || node->GetOpCode() == OP_lt) {
1780         if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
1781             (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
1782             result = SimplifyDoubleConstvalCompare(*node,
1783                 (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval), true);
1784         }
1785     }
1786     return result;
1787 }
1788 
FoldCompare(CompareNode * node)1789 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldCompare(CompareNode *node)
1790 {
1791     CHECK_NULL_FATAL(node);
1792     BaseNode *result = nullptr;
1793     std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1794     std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1795     ConstvalNode *lConst = safe_cast<ConstvalNode>(lp.first);
1796     ConstvalNode *rConst = safe_cast<ConstvalNode>(rp.first);
1797     Opcode opcode = node->GetOpCode();
1798     if (lConst != nullptr && rConst != nullptr) {
1799         result = FoldConstComparison(node->GetOpCode(), node->GetPrimType(), node->GetOpndType(), *lConst, *rConst);
1800     } else if (lConst != nullptr && rConst == nullptr && opcode != OP_cmp &&
1801                lConst->GetConstVal()->GetKind() == kConstInt) {
1802         BaseNode *l = lp.first;
1803         BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
1804         result = FoldConstComparisonReverse(opcode, node->GetPrimType(), node->GetOpndType(), *l, *r);
1805     } else {
1806         BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), lp);
1807         BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
1808         if (l != node->Opnd(0) || r != node->Opnd(1)) {
1809             result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1810                 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->GetOpndType()), l, r);
1811         } else {
1812             result = node;
1813         }
1814         auto *compareNode = static_cast<CompareNode*>(result);
1815         CHECK_NULL_FATAL(compareNode);
1816         result = SimplifyDoubleCompare(*compareNode);
1817     }
1818     return std::make_pair(result, std::nullopt);
1819 }
1820 
Fold(BaseNode * node)1821 BaseNode *ConstantFold::Fold(BaseNode *node)
1822 {
1823     if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) {
1824         return nullptr;
1825     }
1826     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node);
1827     BaseNode *result = PairToExpr(node->GetPrimType(), p);
1828     if (result == node) {
1829         result = nullptr;
1830     }
1831     return result;
1832 }
1833 
1834 }  // namespace maple
1835