1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "constantfold.h"
17 #include <cfloat>
18
19 namespace maple {
20
21 namespace {
// Number of bytes in a 64-bit value.
constexpr uint32 kByteSizeOfBit64{8};
// Number of bits in one byte.
constexpr uint32 kBitSizePerByte{8};
// INT_MAX minus 8.
// NOTE(review): presumably an upper bound for folded offsets - confirm at the use sites.
constexpr maple::int32 kMaxOffset{INT_MAX - 8};

// Three-way comparison outcome used as the folded value of OP_cmp.
enum CompareRes : int64 {
    kLess = -1,
    kEqual,   // 0
    kGreater  // 1
};
27
operator *(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)28 std::optional<IntVal> operator*(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
29 {
30 if (!v1 && !v2) {
31 return std::nullopt;
32 }
33
34 // Perform all calculations in terms of the maximum available signed type.
35 // The value will be truncated for an appropriate type when constant is created in PairToExpr function
36 return v1 && v2 ? v1->Mul(*v2, PTY_i64) : IntVal(static_cast<uint64>(0), PTY_i64);
37 }
38
39 // Perform all calculations in terms of the maximum available signed type.
40 // The value will be truncated for an appropriate type when constant is created in PairToExpr function
AddSub(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2,bool isAdd)41 std::optional<IntVal> AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd)
42 {
43 if (!v1 && !v2) {
44 return std::nullopt;
45 }
46
47 if (v1 && v2) {
48 return isAdd ? v1->Add(*v2, PTY_i64) : v1->Sub(*v2, PTY_i64);
49 }
50
51 if (v1) {
52 return v1->TruncOrExtend(PTY_i64);
53 }
54
55 // !v1 && v2
56 return isAdd ? v2->TruncOrExtend(PTY_i64) : -(v2->TruncOrExtend(PTY_i64));
57 }
58
operator +(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)59 std::optional<IntVal> operator+(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
60 {
61 return AddSub(v1, v2, true);
62 }
63
operator -(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)64 std::optional<IntVal> operator-(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
65 {
66 return AddSub(v1, v2, false);
67 }
68
69 } // anonymous namespace
70
// This phase is designed to achieve compiler optimization by
// simplifying constant expressions. The constant expression
// is evaluated and replaced by the value calculated at compile
// time to save time at runtime.
//
// The main procedure is as follows:
// A. Analyze the expression type
// B. Analyze the operator type
// C. Replace the expression with the result of the operation
80
// Returns true if the set bits of x form exactly one run of contiguous 1's
// that starts at bit 0 (e.g. 0x1, 0x7, 0xFF, all-ones). Zero yields false.
static bool ContiguousBitsOf1(uint64 x)
{
    // For such a value, x + 1 shares no set bit with x
    // (x + 1 wraps to 0 when every bit of x is set).
    const uint64 successor = x + 1;
    return (x != 0) && ((x & successor) == 0);
}
90
// Returns true if num has exactly one bit set. Zero yields false.
inline bool IsPowerOf2(uint64 num)
{
    // Clearing the lowest set bit (num & (num - 1)) leaves nothing
    // only when a single bit was set.
    return (num != 0) && ((num & (num - 1)) == 0);
}
98
NewBinaryNode(BinaryNode * old,Opcode op,PrimType primType,BaseNode * lhs,BaseNode * rhs) const99 BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs,
100 BaseNode *rhs) const
101 {
102 CHECK_NULL_FATAL(old);
103 BinaryNode *result = nullptr;
104 if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == lhs && old->Opnd(1) == rhs) {
105 result = old;
106 } else {
107 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(op, primType, lhs, rhs);
108 }
109 return result;
110 }
111
NewUnaryNode(UnaryNode * old,Opcode op,PrimType primType,BaseNode * expr) const112 UnaryNode *ConstantFold::NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const
113 {
114 CHECK_NULL_FATAL(old);
115 UnaryNode *result = nullptr;
116 if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == expr) {
117 result = old;
118 } else {
119 result = mirModule->CurFuncCodeMemPool()->New<UnaryNode>(op, primType, expr);
120 }
121 return result;
122 }
123
PairToExpr(PrimType resultType,const std::pair<BaseNode *,std::optional<IntVal>> & pair) const124 BaseNode *ConstantFold::PairToExpr(PrimType resultType, const std::pair<BaseNode*, std::optional<IntVal>> &pair) const
125 {
126 CHECK_NULL_FATAL(pair.first);
127 BaseNode *result = pair.first;
128 if (!pair.second || *pair.second == 0 || GetPrimTypeSize(resultType) > k8ByteSize) {
129 return result;
130 }
131 if (pair.first->GetOpCode() == OP_neg && !pair.second->GetSignBit()) {
132 // -a, 5 -> 5 - a
133 ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
134 static_cast<uint64>(pair.second->GetExtValue()), resultType);
135 BaseNode *r = static_cast<UnaryNode*>(pair.first)->Opnd(0);
136 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, val, r);
137 } else {
138 if ((!pair.second->GetSignBit() &&
139 pair.second->GetSXTValue(static_cast<uint8>(GetPrimTypeBitSize(resultType))) > 0) ||
140 pair.second->TruncOrExtend(resultType).IsMinValue() ||
141 pair.second->GetSXTValue() == INT64_MIN) {
142 // +-a, 5 -> a + 5
143 ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
144 static_cast<uint64>(pair.second->GetExtValue()), resultType);
145 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_add, resultType, pair.first, val);
146 } else {
147 // +-a, -5 -> a + -5
148 ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
149 static_cast<uint64>((-pair.second.value()).GetExtValue()), resultType);
150 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, pair.first, val);
151 }
152 }
153 return result;
154 }
155
FoldBase(BaseNode * node) const156 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBase(BaseNode *node) const
157 {
158 return std::make_pair(node, std::nullopt);
159 }
160
DispatchFold(BaseNode * node)161 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::DispatchFold(BaseNode *node)
162 {
163 CHECK_NULL_FATAL(node);
164 if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
165 return {node, std::nullopt};
166 }
167 switch (node->GetOpCode()) {
168 case OP_abs:
169 case OP_bnot:
170 case OP_lnot:
171 case OP_neg:
172 case OP_sqrt:
173 return FoldUnary(static_cast<UnaryNode*>(node));
174 case OP_ceil:
175 case OP_floor:
176 case OP_trunc:
177 case OP_cvt:
178 return FoldTypeCvt(static_cast<TypeCvtNode*>(node));
179 case OP_sext:
180 case OP_zext:
181 case OP_extractbits:
182 return FoldExtractbits(static_cast<ExtractbitsNode*>(node));
183 case OP_iread:
184 return FoldIread(static_cast<IreadNode*>(node));
185 case OP_add:
186 case OP_ashr:
187 case OP_band:
188 case OP_bior:
189 case OP_bxor:
190 case OP_div:
191 case OP_lshr:
192 case OP_max:
193 case OP_min:
194 case OP_mul:
195 case OP_rem:
196 case OP_shl:
197 case OP_sub:
198 return FoldBinary(static_cast<BinaryNode*>(node));
199 case OP_eq:
200 case OP_ne:
201 case OP_ge:
202 case OP_gt:
203 case OP_le:
204 case OP_lt:
205 case OP_cmp:
206 return FoldCompare(static_cast<CompareNode*>(node));
207 case OP_retype:
208 return FoldRetype(static_cast<RetypeNode*>(node));
209 default:
210 return FoldBase(static_cast<BaseNode*>(node));
211 }
212 }
213
Negate(BaseNode * node) const214 BaseNode *ConstantFold::Negate(BaseNode *node) const
215 {
216 CHECK_NULL_FATAL(node);
217 return mirModule->CurFuncCodeMemPool()->New<UnaryNode>(OP_neg, PrimType(node->GetPrimType()), node);
218 }
219
Negate(UnaryNode * node) const220 BaseNode *ConstantFold::Negate(UnaryNode *node) const
221 {
222 CHECK_NULL_FATAL(node);
223 BaseNode *result = nullptr;
224 if (node->GetOpCode() == OP_neg) {
225 result = static_cast<BaseNode*>(node->Opnd(0));
226 } else {
227 BaseNode *n = static_cast<BaseNode*>(node);
228 result = NewUnaryNode(node, OP_neg, node->GetPrimType(), n);
229 }
230 return result;
231 }
232
Negate(const ConstvalNode * node) const233 BaseNode *ConstantFold::Negate(const ConstvalNode *node) const
234 {
235 CHECK_NULL_FATAL(node);
236 ConstvalNode *copy = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
237 CHECK_NULL_FATAL(copy);
238 copy->GetConstVal()->Neg();
239 return copy;
240 }
241
NegateTree(BaseNode * node) const242 BaseNode *ConstantFold::NegateTree(BaseNode *node) const
243 {
244 CHECK_NULL_FATAL(node);
245 if (node->IsUnaryNode()) {
246 return Negate(static_cast<UnaryNode*>(node));
247 } else if (node->GetOpCode() == OP_constval) {
248 return Negate(static_cast<ConstvalNode*>(node));
249 } else {
250 return Negate(static_cast<BaseNode*>(node));
251 }
252 }
253
FoldIntConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRIntConst & intConst0,const MIRIntConst & intConst1) const254 MIRIntConst *ConstantFold::FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
255 const MIRIntConst &intConst0,
256 const MIRIntConst &intConst1) const
257 {
258 uint64 result = 0;
259
260 bool greater = intConst0.GetValue().Greater(intConst1.GetValue(), opndType);
261 bool equal = intConst0.GetValue().Equal(intConst1.GetValue(), opndType);
262 bool less = intConst0.GetValue().Less(intConst1.GetValue(), opndType);
263
264 switch (opcode) {
265 case OP_eq: {
266 result = equal;
267 break;
268 }
269 case OP_ge: {
270 result = (greater || equal) ? 1 : 0;
271 break;
272 }
273 case OP_gt: {
274 result = greater;
275 break;
276 }
277 case OP_le: {
278 result = (less || equal) ? 1 : 0;
279 break;
280 }
281 case OP_lt: {
282 result = less;
283 break;
284 }
285 case OP_ne: {
286 result = !equal;
287 break;
288 }
289 case OP_cmp: {
290 if (greater) {
291 result = kGreater;
292 } else if (equal) {
293 result = kEqual;
294 } else if (less) {
295 result = static_cast<uint64>(kLess);
296 }
297 break;
298 }
299 default:
300 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstComparison");
301 break;
302 }
303 // determine the type
304 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
305 // form the constant
306 MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
307 return constValue;
308 }
309
FoldIntConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const310 ConstvalNode *ConstantFold::FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
311 const ConstvalNode &const0, const ConstvalNode &const1) const
312 {
313 const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
314 const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
315 CHECK_NULL_FATAL(intConst0);
316 CHECK_NULL_FATAL(intConst1);
317 MIRIntConst *constValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
318 // form the ConstvalNode
319 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
320 resultConst->SetPrimType(resultType);
321 resultConst->SetConstVal(constValue);
322 return resultConst;
323 }
324
FoldIntConstBinaryMIRConst(Opcode opcode,PrimType resultType,const MIRIntConst & intConst0,const MIRIntConst & intConst1)325 MIRConst *ConstantFold::FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0,
326 const MIRIntConst &intConst1)
327 {
328 IntVal intVal0 = intConst0.GetValue();
329 IntVal intVal1 = intConst1.GetValue();
330 IntVal result(static_cast<uint64>(0), resultType);
331
332 switch (opcode) {
333 case OP_add: {
334 result = intVal0.Add(intVal1, resultType);
335 break;
336 }
337 case OP_sub: {
338 result = intVal0.Sub(intVal1, resultType);
339 break;
340 }
341 case OP_mul: {
342 result = intVal0.Mul(intVal1, resultType);
343 break;
344 }
345 case OP_div: {
346 result = intVal0.Div(intVal1, resultType);
347 break;
348 }
349 case OP_rem: {
350 result = intVal0.Rem(intVal1, resultType);
351 break;
352 }
353 case OP_ashr: {
354 result = intVal0.AShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
355 break;
356 }
357 case OP_lshr: {
358 result = intVal0.LShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
359 break;
360 }
361 case OP_shl: {
362 result = intVal0.Shl(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
363 break;
364 }
365 case OP_max: {
366 result = Max(intVal0, intVal1, resultType);
367 break;
368 }
369 case OP_min: {
370 result = Min(intVal0, intVal1, resultType);
371 break;
372 }
373 case OP_band: {
374 result = intVal0.And(intVal1, resultType);
375 break;
376 }
377 case OP_bior: {
378 result = intVal0.Or(intVal1, resultType);
379 break;
380 }
381 case OP_bxor: {
382 result = intVal0.Xor(intVal1, resultType);
383 break;
384 }
385 default:
386 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstBinary");
387 break;
388 }
389 // determine the type
390 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
391 // form the constant
392 MIRIntConst *constValue =
393 GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
394 return constValue;
395 }
396
FoldIntConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const397 ConstvalNode *ConstantFold::FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
398 const ConstvalNode &const1) const
399 {
400 const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
401 const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
402 CHECK_NULL_FATAL(intConst0);
403 CHECK_NULL_FATAL(intConst1);
404 MIRConst *constValue = FoldIntConstBinaryMIRConst(opcode, resultType, *intConst0, *intConst1);
405 // form the ConstvalNode
406 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
407 resultConst->SetPrimType(resultType);
408 resultConst->SetConstVal(constValue);
409 return resultConst;
410 }
411
FoldFPConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const412 ConstvalNode *ConstantFold::FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
413 const ConstvalNode &const1) const
414 {
415 DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
416 const MIRDoubleConst *doubleConst0 = nullptr;
417 const MIRDoubleConst *doubleConst1 = nullptr;
418 const MIRFloatConst *floatConst0 = nullptr;
419 const MIRFloatConst *floatConst1 = nullptr;
420 bool useDouble = (const0.GetPrimType() == PTY_f64);
421 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
422 resultConst->SetPrimType(resultType);
423 if (useDouble) {
424 doubleConst0 = safe_cast<MIRDoubleConst>(const0.GetConstVal());
425 doubleConst1 = safe_cast<MIRDoubleConst>(const1.GetConstVal());
426 CHECK_NULL_FATAL(doubleConst0);
427 CHECK_NULL_FATAL(doubleConst1);
428 } else {
429 floatConst0 = safe_cast<MIRFloatConst>(const0.GetConstVal());
430 floatConst1 = safe_cast<MIRFloatConst>(const1.GetConstVal());
431 CHECK_NULL_FATAL(floatConst0);
432 CHECK_NULL_FATAL(floatConst1);
433 }
434 float constValueFloat = 0.0;
435 double constValueDouble = 0.0;
436 switch (opcode) {
437 case OP_add: {
438 if (useDouble) {
439 constValueDouble = doubleConst0->GetValue() + doubleConst1->GetValue();
440 } else {
441 constValueFloat = floatConst0->GetValue() + floatConst1->GetValue();
442 }
443 break;
444 }
445 case OP_sub: {
446 if (useDouble) {
447 constValueDouble = doubleConst0->GetValue() - doubleConst1->GetValue();
448 } else {
449 constValueFloat = floatConst0->GetValue() - floatConst1->GetValue();
450 }
451 break;
452 }
453 case OP_mul: {
454 if (useDouble) {
455 constValueDouble = doubleConst0->GetValue() * doubleConst1->GetValue();
456 } else {
457 constValueFloat = floatConst0->GetValue() * floatConst1->GetValue();
458 }
459 break;
460 }
461 case OP_div: {
462 // for floats div by 0 is well defined
463 if (useDouble) {
464 constValueDouble = doubleConst0->GetValue() / doubleConst1->GetValue();
465 } else {
466 constValueFloat = floatConst0->GetValue() / floatConst1->GetValue();
467 }
468 break;
469 }
470 case OP_max: {
471 if (useDouble) {
472 constValueDouble = (doubleConst0->GetValue() >= doubleConst1->GetValue()) ? doubleConst0->GetValue()
473 : doubleConst1->GetValue();
474 } else {
475 constValueFloat = (floatConst0->GetValue() >= floatConst1->GetValue()) ? floatConst0->GetValue()
476 : floatConst1->GetValue();
477 }
478 break;
479 }
480 case OP_min: {
481 if (useDouble) {
482 constValueDouble = (doubleConst0->GetValue() <= doubleConst1->GetValue()) ? doubleConst0->GetValue()
483 : doubleConst1->GetValue();
484 } else {
485 constValueFloat = (floatConst0->GetValue() <= floatConst1->GetValue()) ? floatConst0->GetValue()
486 : floatConst1->GetValue();
487 }
488 break;
489 }
490 case OP_rem:
491 case OP_ashr:
492 case OP_lshr:
493 case OP_shl:
494 case OP_band:
495 case OP_bior:
496 case OP_bxor: {
497 DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstBinary");
498 break;
499 }
500 default:
501 DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstBinary");
502 break;
503 }
504 if (resultType == PTY_f64) {
505 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValueDouble));
506 } else {
507 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValueFloat));
508 }
509 return resultConst;
510 }
511
ConstValueEqual(int64 leftValue,int64 rightValue) const512 bool ConstantFold::ConstValueEqual(int64 leftValue, int64 rightValue) const
513 {
514 return (leftValue == rightValue);
515 }
516
ConstValueEqual(float leftValue,float rightValue) const517 bool ConstantFold::ConstValueEqual(float leftValue, float rightValue) const
518 {
519 auto result = fabs(leftValue - rightValue);
520 return leftValue <= FLT_MIN && rightValue <= FLT_MIN ? result < FLT_MIN : result <= FLT_MIN;
521 }
522
ConstValueEqual(double leftValue,double rightValue) const523 bool ConstantFold::ConstValueEqual(double leftValue, double rightValue) const
524 {
525 auto result = fabs(leftValue - rightValue);
526 return leftValue <= DBL_MIN && rightValue <= DBL_MIN ? result < DBL_MIN : result <= DBL_MIN;
527 }
528
529 template<typename T>
FullyEqual(T leftValue,T rightValue) const530 bool ConstantFold::FullyEqual(T leftValue, T rightValue) const
531 {
532 if (std::isinf(leftValue) && std::isinf(rightValue)) {
533 // (inf == inf), add the judgement here in case of the subtraction between float type inf
534 return true;
535 } else {
536 return ConstValueEqual(leftValue, rightValue);
537 }
538 }
539
540 template<typename T>
ComparisonResult(Opcode op,T * leftConst,T * rightConst) const541 int64 ConstantFold::ComparisonResult(Opcode op, T *leftConst, T *rightConst) const
542 {
543 DEBUG_ASSERT(leftConst != nullptr, "leftConst should not be nullptr");
544 typename T::value_type leftValue = leftConst->GetValue();
545 DEBUG_ASSERT(rightConst != nullptr, "rightConst should not be nullptr");
546 typename T::value_type rightValue = rightConst->GetValue();
547 int64 result = 0;
548 switch (op) {
549 case OP_eq: {
550 result = FullyEqual(leftValue, rightValue);
551 break;
552 }
553 case OP_ge: {
554 result = (leftValue > rightValue) || FullyEqual(leftValue, rightValue);
555 break;
556 }
557 case OP_gt: {
558 result = (leftValue > rightValue);
559 break;
560 }
561 case OP_le: {
562 result = (leftValue < rightValue) || FullyEqual(leftValue, rightValue);
563 break;
564 }
565 case OP_lt: {
566 result = (leftValue < rightValue);
567 break;
568 }
569 case OP_ne: {
570 result = !FullyEqual(leftValue, rightValue);
571 break;
572 }
573 [[clang::fallthrough]];
574 case OP_cmp: {
575 if (leftValue > rightValue) {
576 result = kGreater;
577 } else if (FullyEqual(leftValue, rightValue)) {
578 result = kEqual;
579 } else {
580 result = kLess;
581 }
582 break;
583 }
584 default:
585 DEBUG_ASSERT(false, "Unknown opcode for Comparison");
586 break;
587 }
588 return result;
589 }
590
FoldFPConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRConst & leftConst,const MIRConst & rightConst) const591 MIRIntConst *ConstantFold::FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
592 const MIRConst &leftConst, const MIRConst &rightConst) const
593 {
594 int64 result = 0;
595 bool useDouble = (opndType == PTY_f64);
596 if (useDouble) {
597 result =
598 ComparisonResult(opcode, safe_cast<MIRDoubleConst>(&leftConst), safe_cast<MIRDoubleConst>(&rightConst));
599 } else {
600 result = ComparisonResult(opcode, safe_cast<MIRFloatConst>(&leftConst), safe_cast<MIRFloatConst>(&rightConst));
601 }
602 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
603 MIRIntConst *resultConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result), type);
604 return resultConst;
605 }
606
FoldFPConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const607 ConstvalNode *ConstantFold::FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
608 const ConstvalNode &const0, const ConstvalNode &const1) const
609 {
610 DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
611 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
612 resultConst->SetPrimType(resultType);
613 resultConst->SetConstVal(
614 FoldFPConstComparisonMIRConst(opcode, resultType, opndType, *const0.GetConstVal(), *const1.GetConstVal()));
615 return resultConst;
616 }
617
FoldConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRConst & const0,const MIRConst & const1) const618 MIRConst *ConstantFold::FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
619 const MIRConst &const0, const MIRConst &const1) const
620 {
621 MIRConst *returnValue = nullptr;
622 if (IsPrimitiveInteger(opndType)) {
623 const auto *intConst0 = safe_cast<MIRIntConst>(&const0);
624 const auto *intConst1 = safe_cast<MIRIntConst>(&const1);
625 ASSERT_NOT_NULL(intConst0);
626 ASSERT_NOT_NULL(intConst1);
627 returnValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
628 } else if (opndType == PTY_f32 || opndType == PTY_f64) {
629 returnValue = FoldFPConstComparisonMIRConst(opcode, resultType, opndType, const0, const1);
630 } else {
631 DEBUG_ASSERT(false, "Unhandled case for FoldConstComparisonMIRConst");
632 }
633 return returnValue;
634 }
635
FoldConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const636 ConstvalNode *ConstantFold::FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
637 const ConstvalNode &const0, const ConstvalNode &const1) const
638 {
639 ConstvalNode *returnValue = nullptr;
640 if (IsPrimitiveInteger(opndType)) {
641 returnValue = FoldIntConstComparison(opcode, resultType, opndType, const0, const1);
642 } else if (opndType == PTY_f32 || opndType == PTY_f64) {
643 returnValue = FoldFPConstComparison(opcode, resultType, opndType, const0, const1);
644 } else {
645 DEBUG_ASSERT(false, "Unhandled case for FoldConstComparison");
646 }
647 return returnValue;
648 }
649
FoldConstComparisonReverse(Opcode opcode,PrimType resultType,PrimType opndType,BaseNode & l,BaseNode & r) const650 CompareNode *ConstantFold::FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType,
651 BaseNode &l, BaseNode &r) const
652 {
653 CompareNode *result = nullptr;
654 Opcode op = opcode;
655 switch (opcode) {
656 case OP_gt: {
657 op = OP_lt;
658 break;
659 }
660 case OP_lt: {
661 op = OP_gt;
662 break;
663 }
664 case OP_ge: {
665 op = OP_le;
666 break;
667 }
668 case OP_le: {
669 op = OP_ge;
670 break;
671 }
672 case OP_eq: {
673 break;
674 }
675 case OP_ne: {
676 break;
677 }
678 default:
679 DEBUG_ASSERT(false, "Unknown opcode for FoldConstComparisonReverse");
680 break;
681 }
682
683 result =
684 mirModule->CurFuncCodeMemPool()->New<CompareNode>(Opcode(op), PrimType(resultType), PrimType(opndType), &r, &l);
685 return result;
686 }
687
FoldConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const688 ConstvalNode *ConstantFold::FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
689 const ConstvalNode &const1) const
690 {
691 ConstvalNode *returnValue = nullptr;
692 if (IsPrimitiveInteger(resultType)) {
693 returnValue = FoldIntConstBinary(opcode, resultType, const0, const1);
694 } else if (resultType == PTY_f32 || resultType == PTY_f64) {
695 returnValue = FoldFPConstBinary(opcode, resultType, const0, const1);
696 } else {
697 DEBUG_ASSERT(false, "Unhandled case for FoldConstBinary");
698 }
699 return returnValue;
700 }
701
FoldIntConstUnaryMIRConst(Opcode opcode,PrimType resultType,const MIRIntConst * constNode)702 MIRIntConst *ConstantFold::FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode)
703 {
704 CHECK_NULL_FATAL(constNode);
705 IntVal result = constNode->GetValue().TruncOrExtend(resultType);
706 switch (opcode) {
707 case OP_abs: {
708 if (IsSignedInteger(constNode->GetType().GetPrimType()) && result.GetSignBit()) {
709 result = -result;
710 }
711 break;
712 }
713 case OP_bnot: {
714 result = ~result;
715 break;
716 }
717 case OP_lnot: {
718 uint64 resultInt = result == 0 ? 1 : 0;
719 result = {resultInt, resultType};
720 break;
721 }
722 case OP_neg: {
723 result = -result;
724 break;
725 }
726 case OP_sext: // handled in FoldExtractbits
727 case OP_zext: // handled in FoldExtractbits
728 case OP_extractbits: // handled in FoldExtractbits
729 case OP_sqrt: {
730 DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstUnaryMIRConst");
731 break;
732 }
733 default:
734 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstUnaryMIRConst");
735 break;
736 }
737 // determine the type
738 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
739 // form the constant
740 MIRIntConst *constValue =
741 GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
742 return constValue;
743 }
744
745 template <typename T>
FoldFPConstUnary(Opcode opcode,PrimType resultType,ConstvalNode * constNode) const746 ConstvalNode *ConstantFold::FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const
747 {
748 CHECK_NULL_FATAL(constNode);
749 double constValue = 0;
750 T *fpCst = static_cast<T*>(constNode->GetConstVal());
751 switch (opcode) {
752 case OP_neg: {
753 constValue = typename T::value_type(-fpCst->GetValue());
754 break;
755 }
756 case OP_abs: {
757 constValue = typename T::value_type(fabs(fpCst->GetValue()));
758 break;
759 }
760 case OP_sqrt: {
761 constValue = typename T::value_type(sqrt(fpCst->GetValue()));
762 break;
763 }
764 case OP_bnot:
765 case OP_lnot:
766 case OP_sext:
767 case OP_zext:
768 case OP_extractbits: {
769 DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstUnary");
770 break;
771 }
772 default:
773 DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstUnary");
774 break;
775 }
776 auto *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
777 resultConst->SetPrimType(resultType);
778 if (resultType == PTY_f32) {
779 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(static_cast<float>(constValue)));
780 } else if (resultType == PTY_f64) {
781 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValue));
782 } else {
783 CHECK_FATAL(false, "PrimType for MIRFloatConst / MIRDoubleConst should be PTY_f32 / PTY_f64");
784 }
785 return resultConst;
786 }
787
FoldConstUnary(Opcode opcode,PrimType resultType,ConstvalNode & constNode) const788 ConstvalNode *ConstantFold::FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const
789 {
790 ConstvalNode *returnValue = nullptr;
791 if (IsPrimitiveInteger(resultType)) {
792 const MIRIntConst *cst = safe_cast<MIRIntConst>(constNode.GetConstVal());
793 auto constValue = FoldIntConstUnaryMIRConst(opcode, resultType, cst);
794 returnValue = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
795 returnValue->SetPrimType(resultType);
796 returnValue->SetConstVal(constValue);
797 } else if (resultType == PTY_f32) {
798 returnValue = FoldFPConstUnary<MIRFloatConst>(opcode, resultType, &constNode);
799 } else if (resultType == PTY_f64) {
800 returnValue = FoldFPConstUnary<MIRDoubleConst>(opcode, resultType, &constNode);
801 } else {
802 DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary");
803 }
804 return returnValue;
805 }
806
FoldRetype(RetypeNode * node)807 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldRetype(RetypeNode *node)
808 {
809 CHECK_NULL_FATAL(node);
810 BaseNode *result = node;
811 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
812 if (node->Opnd(0) != p.first) {
813 RetypeNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
814 CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldRetype");
815 newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
816 result = newRetNode;
817 }
818 return std::make_pair(result, std::nullopt);
819 }
820
FoldUnary(UnaryNode * node)821 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldUnary(UnaryNode *node)
822 {
823 CHECK_NULL_FATAL(node);
824 BaseNode *result = nullptr;
825 std::optional<IntVal> sum = std::nullopt;
826 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
827 ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
828 if (cst != nullptr) {
829 result = FoldConstUnary(node->GetOpCode(), node->GetPrimType(), *cst);
830 } else {
831 bool isInt = IsPrimitiveInteger(node->GetPrimType());
832 // The neg node will be recreated regardless of whether the folding is successful or not. And the neg node's
833 // primType will be set to opnd type. There will be problems in some cases. For example:
834 // before cf:
835 // neg i32 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))
836 // after cf:
837 // neg u1 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f)) # wrong!
838 // As a workaround, we exclude u1 opnd type
839 if (isInt && node->GetOpCode() == OP_neg && p.first->GetPrimType() != PTY_u1) {
840 result = NegateTree(p.first);
841 if (result->GetOpCode() == OP_neg) {
842 PrimType origPtyp = node->GetPrimType();
843 PrimType newPtyp = result->GetPrimType();
844 if (newPtyp == origPtyp) {
845 if (static_cast<UnaryNode*>(result)->Opnd(0) == node->Opnd(0)) {
846 // NegateTree returned an UnaryNode quivalent to `n`, so keep the
847 // original UnaryNode to preserve identity
848 result = node;
849 }
850 } else {
851 if (GetPrimTypeSize(newPtyp) != GetPrimTypeSize(origPtyp)) {
852 // do not fold explicit cvt
853 result = NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(),
854 PairToExpr(node->Opnd(0)->GetPrimType(), p));
855 return std::make_pair(result, std::nullopt);
856 } else {
857 result->SetPrimType(origPtyp);
858 }
859 }
860 }
861 if (p.second) {
862 sum = -(*p.second);
863 }
864 } else {
865 result =
866 NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(), PairToExpr(node->Opnd(0)->GetPrimType(), p));
867 }
868 }
869 return std::make_pair(result, sum);
870 }
871
// Reports whether converting fval to the integer type `totype` must be treated
// as overflowing, so the conversion is not folded.
//
// fval    value to convert
// totype  destination type; PTY_i64 and PTY_u64 use the 64-bit bounds, every
//         other type uses the 32-bit bounds
//
// Returns true for NaN, +/-infinity and any value outside the bounds.
static bool FloatToIntOverflow(float fval, PrimType totype)
{
    // 2^31 - 128 is the largest float below 2^31 and is exactly representable.
    constexpr float safeFloatMaxToInt32 = 2147483520.0f;
    constexpr float safeFloatMinToInt32 = -2147483520.0f;
    // A float has a 24-bit significand, so neighbouring values just below 2^63
    // are 2^39 apart: 2^63 - 128 is NOT representable and a literal of that value
    // rounds up to exactly 2^63, which does not fit in int64. The largest float
    // that does fit is 2^63 - 2^39.
    constexpr float safeFloatMaxToInt64 = 9223371487098961920.0f;   // 2^63 - 2^39
    // -2^63 is exactly representable both as a float and as an int64.
    constexpr float safeFloatMinToInt64 = -9223372036854775808.0f;  // -2^63

    if (!std::isfinite(fval)) {
        return true;
    }
    if (totype == PTY_i64 || totype == PTY_u64) {
        return fval < safeFloatMinToInt64 || fval > safeFloatMaxToInt64;
    }
    return fval < safeFloatMinToInt32 || fval > safeFloatMaxToInt32;
}
892
DoubleToIntOverflow(double dval,PrimType totype)893 static bool DoubleToIntOverflow(double dval, PrimType totype)
894 {
895 static const double safeDoubleMaxToInt32 = 2147482624.0; // 2^31 - 1024
896 static const double safeDoubleMinToInt32 = -2147482624.0;
897 static const double safeDoubleMaxToInt64 = 9223372036854774784.0; // 2^63 - 1024
898 static const double safeDoubleMinToInt64 = -9223372036854774784.0;
899 if (!std::isfinite(dval)) {
900 return true;
901 }
902 if (totype == PTY_i64 || totype == PTY_u64) {
903 if (dval < safeDoubleMinToInt64 || dval > safeDoubleMaxToInt64) {
904 return true;
905 }
906 } else {
907 if (dval < safeDoubleMinToInt32 || dval > safeDoubleMaxToInt32) {
908 return true;
909 }
910 }
911 return false;
912 }
913
// Folds `ceil` applied to the floating-point constant `cst`.
// fromType: PTY_f32 selects the float path; any other value is handled as a double constant.
// toType:   a float result type yields a float/double constant, otherwise an integer constant.
// Returns nullptr when the rounded value is rejected by the overflow check for `toType`.
ConstvalNode *ConstantFold::FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(constValue);
        float floatValue = ceil(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
        } else if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        } else {
            // Converting a negative floating-point value directly to an unsigned integer is undefined
            // behaviour, so negative values take the signed route first.
            uint64 intBits = (floatValue < 0) ? static_cast<uint64>(static_cast<int64>(floatValue))
                                              : static_cast<uint64>(floatValue);
            resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(intBits, resultType));
        }
    } else {
        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(constValue);
        double doubleValue = ceil(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
        } else if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        } else {
            // Same reasoning as the float path: never cast a negative value straight to uint64.
            uint64 intBits = (doubleValue < 0) ? static_cast<uint64>(static_cast<int64>(doubleValue))
                                               : static_cast<uint64>(doubleValue);
            resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(intBits, resultType));
        }
    }
    return resultConst;
}
946
947 template <class T>
CalIntValueFromFloatValue(T value,const MIRType & resultType) const948 T ConstantFold::CalIntValueFromFloatValue(T value, const MIRType &resultType) const
949 {
950 DEBUG_ASSERT(kByteSizeOfBit64 >= resultType.GetSize(), "unsupported type");
951 size_t shiftNum = (kByteSizeOfBit64 - resultType.GetSize()) * kBitSizePerByte;
952 bool isSigned = IsSignedInteger(resultType.GetPrimType());
953 int64 max = (IntVal(std::numeric_limits<int64>::max(), PTY_i64) >> shiftNum).GetExtValue();
954 uint64 umax = std::numeric_limits<uint64>::max() >> shiftNum;
955 int64 min = isSigned ? (IntVal(std::numeric_limits<int64>::min(), PTY_i64) >> shiftNum).GetExtValue() : 0;
956 if (isSigned && (value > max)) {
957 return static_cast<T>(max);
958 } else if (!isSigned && (value > umax)) {
959 return static_cast<T>(umax);
960 } else if (value < min) {
961 return static_cast<T>(min);
962 }
963 return value;
964 }
965
// Folds a floating-point constant into a constant of type `toType`.
// fromType: PTY_f32 selects the float path; any other value is handled as a double constant.
// isFloor:  when true the value is rounded down with floor() first; when false it is used as is
//           (the plain float-to-integer conversion in FoldTypeCvtMIRConst calls it this way).
// Returns a float/double constant for float result types, an integer constant otherwise, or
// nullptr when the value is rejected by the overflow check for `toType`.
MIRConst *ConstantFold::FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor) const
{
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const auto &constValue = static_cast<const MIRFloatConst&>(cst);
        float floatValue = constValue.GetValue();
        if (isFloor) {
            floatValue = floor(constValue.GetValue());
        }
        if (IsPrimitiveFloat(toType)) {
            return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
        }
        if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        }
        floatValue = CalIntValueFromFloatValue(floatValue, resultType);
        // Converting a negative floating-point value directly to an unsigned integer is undefined
        // behaviour; route negative values through the signed type, as the double path below does.
        uint64 intBits = (floatValue < 0) ? static_cast<uint64>(static_cast<int64>(floatValue))
                                          : static_cast<uint64>(floatValue);
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(intBits, resultType);
    } else {
        const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
        double doubleValue = constValue.GetValue();
        if (isFloor) {
            doubleValue = floor(constValue.GetValue());
        }
        if (IsPrimitiveFloat(toType)) {
            return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
        }
        if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        }
        doubleValue = CalIntValueFromFloatValue(doubleValue, resultType);
        // gcc/clang have bugs convert double to unsigned long, must convert to signed long first;
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(doubleValue), resultType);
    }
}
1000
// Folds `floor` applied to the constant node `cst` and wraps the result in a new ConstvalNode of
// type `toType`. Returns nullptr when FoldFloorMIRConst refuses to fold (it returns nullptr on
// overflow), so callers never receive a node that carries a null constant.
ConstvalNode *ConstantFold::FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *foldedValue = FoldFloorMIRConst(*cst.GetConstVal(), fromType, toType);
    if (foldedValue == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(foldedValue);
    return resultConst;
}
1008
// Folds a rounding conversion of the constant `cst`.
//  - f32/f64 source: the value is rounded and turned into an integer constant of `toType`;
//    nullptr is returned when the overflow check rejects it.
//  - integer source with f32/f64 target: the integer is converted to floating point and only
//    folded when converting back yields the original integer; otherwise nullptr is returned.
//  - any other combination returns nullptr.
MIRConst *ConstantFold::FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);

    if (fromType == PTY_f32) {
        float rounded = round(static_cast<const MIRFloatConst&>(cst).GetValue());
        if (FloatToIntOverflow(rounded, toType)) {
            return nullptr;
        }
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(rounded), resultType);
    }

    if (fromType == PTY_f64) {
        double rounded = round(static_cast<const MIRDoubleConst&>(cst).GetValue());
        if (DoubleToIntOverflow(rounded, toType)) {
            return nullptr;
        }
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(
            static_cast<uint64>(static_cast<int64>(rounded)), resultType);
    }

    if (!IsPrimitiveInteger(fromType)) {
        return nullptr;
    }

    if (toType == PTY_f32) {
        const auto &intConst = static_cast<const MIRIntConst&>(cst);
        if (IsSignedInteger(fromType)) {
            int64 source = intConst.GetExtValue();
            float converted = round(static_cast<float>(source));
            if (static_cast<int64>(converted) == source) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(converted);
            }
        } else {
            uint64 source = static_cast<uint64>(intConst.GetExtValue());
            float converted = round(static_cast<float>(source));
            if (static_cast<uint64>(converted) == source) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(converted);
            }
        }
        return nullptr;
    }

    if (toType == PTY_f64) {
        const auto &intConst = static_cast<const MIRIntConst&>(cst);
        if (IsSignedInteger(fromType)) {
            int64 source = intConst.GetExtValue();
            double converted = round(static_cast<double>(source));
            if (static_cast<int64>(converted) == source) {
                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(converted);
            }
        } else {
            uint64 source = static_cast<uint64>(intConst.GetExtValue());
            double converted = round(static_cast<double>(source));
            if (static_cast<uint64>(converted) == source) {
                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(converted);
            }
        }
        return nullptr;
    }

    return nullptr;
}
1060
// Folds a rounding conversion of the constant node `cst` and wraps the result in a new
// ConstvalNode of type `toType`. Returns nullptr when FoldRoundMIRConst refuses to fold (it
// returns nullptr on overflow, on an inexact integer-to-float conversion and for unsupported
// type pairs), so callers never receive a node that carries a null constant.
ConstvalNode *ConstantFold::FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *foldedValue = FoldRoundMIRConst(*cst.GetConstVal(), fromType, toType);
    if (foldedValue == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(foldedValue);
    return resultConst;
}
1068
// Folds `trunc` applied to the floating-point constant `cst`.
// fromType: PTY_f32 selects the float path; any other value is handled as a double constant.
// toType:   a float result type yields a float/double constant, otherwise an integer constant.
// Returns nullptr when the truncated value is rejected by the overflow check for `toType`.
ConstvalNode *ConstantFold::FoldTrunc(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(constValue);
        float floatValue = trunc(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
        } else if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        } else {
            // Converting a negative floating-point value directly to an unsigned integer is undefined
            // behaviour, so negative values take the signed route first.
            uint64 intBits = (floatValue < 0) ? static_cast<uint64>(static_cast<int64>(floatValue))
                                              : static_cast<uint64>(floatValue);
            resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(intBits, resultType));
        }
    } else {
        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(constValue);
        double doubleValue = trunc(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
        } else if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        } else {
            // Same reasoning as the float path: never cast a negative value straight to uint64.
            uint64 intBits = (doubleValue < 0) ? static_cast<uint64>(static_cast<int64>(doubleValue))
                                               : static_cast<uint64>(doubleValue);
            resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(intBits, resultType));
        }
    }
    return resultConst;
}
1101
// Folds a `cvt` of the constant `cst` from `fromType` to `toType`.
//  - integer -> integer: widening goes through FoldSignExtendMIRConst (sext for signed sources,
//    zext otherwise); same-size or narrowing re-creates the constant with the target type.
//  - float   -> float:   converts between the float and double constants.
//  - float   -> integer: delegates to FoldFloorMIRConst without flooring.
//  - integer -> float:   delegates to FoldRoundMIRConst.
// Any other combination is a fatal error.
MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
    const bool srcIsInt = IsPrimitiveInteger(fromType);
    const bool dstIsInt = IsPrimitiveInteger(toType);
    const bool srcIsFloat = IsPrimitiveFloat(fromType);
    const bool dstIsFloat = IsPrimitiveFloat(toType);

    if (srcIsInt && dstIsInt) {
        // GetPrimTypeBitSize(PTY_u1) will return 8, which is not expected here.
        uint32 srcBits = GetPrimTypeBitSize(fromType);
        if (fromType == PTY_u1) {
            srcBits = 1;
        }
        uint32 dstBits = GetPrimTypeBitSize(toType);
        if (toType == PTY_u1) {
            dstBits = 1;
        }
        const MIRIntConst *intConst = safe_cast<MIRIntConst>(cst);
        ASSERT_NOT_NULL(intConst);
        if (dstBits > srcBits) {
            Opcode extendOp = IsSignedInteger(fromType) ? OP_sext : OP_zext;
            return FoldSignExtendMIRConst(extendOp, toType, static_cast<uint8>(srcBits),
                                          intConst->GetValue().TruncOrExtend(fromType));
        }
        MIRType &targetType = *GlobalTables::GetTypeTable().GetPrimType(toType);
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(intConst->GetExtValue()),
                                                                    targetType);
    }

    if (srcIsFloat && dstIsFloat) {
        if (GetPrimTypeBitSize(toType) < GetPrimTypeBitSize(fromType)) {
            DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 32, "We suppot F32 and F64");  // just support 32 or 64
            const MIRDoubleConst *wideConst = safe_cast<MIRDoubleConst>(cst);
            ASSERT_NOT_NULL(wideConst);
            float narrowed = static_cast<float>(wideConst->GetValue());
            return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(narrowed);
        }
        DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 64, "We suppot F32 and F64");  // just support 32 or 64
        const MIRFloatConst *narrowConst = safe_cast<MIRFloatConst>(cst);
        ASSERT_NOT_NULL(narrowConst);
        double widened = static_cast<double>(narrowConst->GetValue());
        return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(widened);
    }

    if (srcIsFloat && dstIsInt) {
        return FoldFloorMIRConst(cst, fromType, toType, false);
    }
    if (srcIsInt && dstIsFloat) {
        return FoldRoundMIRConst(cst, fromType, toType);
    }
    CHECK_FATAL(false, "Unexpected case in ConstFoldTypeCvt");
    return nullptr;
}
1161
// Folds a `cvt` of the constant node `cst` and wraps the converted constant in a new
// ConstvalNode whose primitive type is taken from the converted constant. Returns nullptr when
// the conversion could not be folded.
ConstvalNode *ConstantFold::FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *converted = FoldTypeCvtMIRConst(*cst.GetConstVal(), fromType, toType);
    if (converted != nullptr) {
        ConstvalNode *wrapper = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
        wrapper->SetPrimType(converted->GetType().GetPrimType());
        wrapper->SetConstVal(converted);
        return wrapper;
    }
    return nullptr;
}
1173
1174 // return a primType with bit size >= bitSize (and the nearest one),
1175 // and its signed/float type is the same as ptyp
GetNearestSizePtyp(uint8 bitSize,PrimType ptyp)1176 PrimType GetNearestSizePtyp(uint8 bitSize, PrimType ptyp)
1177 {
1178 bool isSigned = IsSignedInteger(ptyp);
1179 bool isFloat = IsPrimitiveFloat(ptyp);
1180 if (bitSize == 1) { // 1 bit
1181 return PTY_u1;
1182 }
1183 if (bitSize <= 8) { // 8 bit
1184 return isSigned ? PTY_i8 : PTY_u8;
1185 }
1186 if (bitSize <= 16) { // 16 bit
1187 return isSigned ? PTY_i16 : PTY_u16;
1188 }
1189 if (bitSize <= 32) { // 32 bit
1190 return isFloat ? PTY_f32 : (isSigned ? PTY_i32 : PTY_u32);
1191 }
1192 if (bitSize <= 64) { // 64 bit
1193 return isFloat ? PTY_f64 : (isSigned ? PTY_i64 : PTY_u64);
1194 }
1195 return ptyp;
1196 }
1197
GetIntPrimTypeMax(PrimType ptyp)1198 size_t GetIntPrimTypeMax(PrimType ptyp)
1199 {
1200 switch (ptyp) {
1201 case PTY_u1:
1202 return 1;
1203 case PTY_u8:
1204 return UINT8_MAX;
1205 case PTY_i8:
1206 return INT8_MAX;
1207 case PTY_u16:
1208 return UINT16_MAX;
1209 case PTY_i16:
1210 return INT16_MAX;
1211 case PTY_u32:
1212 return UINT32_MAX;
1213 case PTY_i32:
1214 return INT32_MAX;
1215 case PTY_u64:
1216 return UINT64_MAX;
1217 case PTY_i64:
1218 return INT64_MAX;
1219 default:
1220 CHECK_FATAL(false, "NYI");
1221 }
1222 }
1223
GetIntPrimTypeMin(PrimType ptyp)1224 ssize_t GetIntPrimTypeMin(PrimType ptyp)
1225 {
1226 if (IsUnsignedInteger(ptyp)) {
1227 return 0;
1228 }
1229 switch (ptyp) {
1230 case PTY_i8:
1231 return INT8_MIN;
1232 case PTY_i16:
1233 return INT16_MIN;
1234 case PTY_i32:
1235 return INT32_MIN;
1236 case PTY_i64:
1237 return INT64_MIN;
1238 default:
1239 CHECK_FATAL(false, "NYI");
1240 }
1241 }
1242
IsCvtEliminatable(PrimType fromPtyp,PrimType destPtyp,Opcode op,Opcode opndOp)1243 static bool IsCvtEliminatable(PrimType fromPtyp, PrimType destPtyp, Opcode op, Opcode opndOp)
1244 {
1245 if (op != OP_cvt || (opndOp == OP_zext || opndOp == OP_sext)) {
1246 return false;
1247 }
1248 if (GetPrimTypeSize(fromPtyp) != GetPrimTypeSize(destPtyp)) {
1249 return false;
1250 }
1251 return (IsPossible64BitAddress(fromPtyp) && IsPossible64BitAddress(destPtyp)) ||
1252 (IsPossible32BitAddress(fromPtyp) && IsPossible32BitAddress(destPtyp)) ||
1253 (IsPrimitivePureScalar(fromPtyp) && IsPrimitivePureScalar(destPtyp));
1254 }
1255
FoldTypeCvt(TypeCvtNode * node)1256 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldTypeCvt(TypeCvtNode *node)
1257 {
1258 CHECK_NULL_FATAL(node);
1259 BaseNode *result = nullptr;
1260 if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
1261 return {node, std::nullopt};
1262 }
1263 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1264 ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1265 PrimType destPtyp = node->GetPrimType();
1266 PrimType fromPtyp = node->FromType();
1267 if (cst != nullptr) {
1268 switch (node->GetOpCode()) {
1269 case OP_ceil: {
1270 result = FoldCeil(*cst, fromPtyp, destPtyp);
1271 break;
1272 }
1273 case OP_cvt: {
1274 result = FoldTypeCvt(*cst, fromPtyp, destPtyp);
1275 break;
1276 }
1277 case OP_floor: {
1278 result = FoldFloor(*cst, fromPtyp, destPtyp);
1279 break;
1280 }
1281 case OP_trunc: {
1282 result = FoldTrunc(*cst, fromPtyp, destPtyp);
1283 break;
1284 }
1285 default:
1286 DEBUG_ASSERT(false, "Unexpected opcode in TypeCvtNodeConstFold");
1287 break;
1288 }
1289 } else if (IsCvtEliminatable(fromPtyp, destPtyp, node->GetOpCode(), p.first->GetOpCode())) {
1290 // the cvt is redundant
1291 return std::make_pair(p.first, p.second ? IntVal(*p.second, node->GetPrimType()) : p.second);
1292 }
1293 if (result == nullptr) {
1294 BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1295 if (e != node->Opnd(0)) {
1296 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(
1297 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->FromType()), e);
1298 } else {
1299 result = node;
1300 }
1301 }
1302 return std::make_pair(result, std::nullopt);
1303 }
1304
FoldSignExtendMIRConst(Opcode opcode,PrimType resultType,uint8 size,const IntVal & val) const1305 MIRConst *ConstantFold::FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const
1306 {
1307 uint64 result = opcode == OP_sext ? static_cast<uint64>(val.GetSXTValue(size)) : val.GetZXTValue(size);
1308 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
1309 MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
1310 return constValue;
1311 }
1312
FoldSignExtend(Opcode opcode,PrimType resultType,uint8 size,const ConstvalNode & cst) const1313 ConstvalNode *ConstantFold::FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size,
1314 const ConstvalNode &cst) const
1315 {
1316 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1317 const auto *intCst = safe_cast<MIRIntConst>(cst.GetConstVal());
1318 ASSERT_NOT_NULL(intCst);
1319 IntVal val = intCst->GetValue().TruncOrExtend(size, opcode == OP_sext);
1320 MIRConst *toConst = FoldSignExtendMIRConst(opcode, resultType, size, val);
1321 resultConst->SetPrimType(toConst->GetType().GetPrimType());
1322 resultConst->SetConstVal(toConst);
1323 return resultConst;
1324 }
1325
1326 // check if truncation is redundant due to dread or iread having same effect
ExtractbitsRedundant(const ExtractbitsNode & x,MIRFunction & f)1327 static bool ExtractbitsRedundant(const ExtractbitsNode &x, MIRFunction &f)
1328 {
1329 if (GetPrimTypeSize(x.GetPrimType()) == k8ByteSize) {
1330 return false; // this is trying to be conservative
1331 }
1332 BaseNode *opnd = x.Opnd(0);
1333 MIRType *mirType = nullptr;
1334 if (opnd->GetOpCode() == OP_dread) {
1335 DreadNode *dread = static_cast<DreadNode*>(opnd);
1336 MIRSymbol *sym = f.GetLocalOrGlobalSymbol(dread->GetStIdx());
1337 ASSERT_NOT_NULL(sym);
1338 mirType = sym->GetType();
1339 } else if (opnd->GetOpCode() == OP_iread) {
1340 IreadNode *iread = static_cast<IreadNode*>(opnd);
1341 MIRPtrType *ptrType =
1342 dynamic_cast<MIRPtrType*>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx()));
1343 if (ptrType == nullptr) {
1344 return false;
1345 }
1346 mirType = ptrType->GetPointedType();
1347 } else if (opnd->GetOpCode() == OP_extractbits &&
1348 x.GetBitsSize() > static_cast<ExtractbitsNode*>(opnd)->GetBitsSize()) {
1349 return (x.GetOpCode() == OP_zext && x.GetPrimType() == opnd->GetPrimType() &&
1350 IsUnsignedInteger(opnd->GetPrimType()));
1351 } else {
1352 return false;
1353 }
1354 return IsPrimitiveInteger(mirType->GetPrimType()) &&
1355 ((x.GetOpCode() == OP_zext && IsUnsignedInteger(opnd->GetPrimType())) ||
1356 (x.GetOpCode() == OP_sext && IsSignedInteger(opnd->GetPrimType()))) &&
1357 mirType->GetSize() * kBitSizePerByte == x.GetBitsSize() &&
1358 mirType->GetPrimType() == x.GetPrimType();
1359 }
1360
1361 // sext and zext also handled automatically
FoldExtractbits(ExtractbitsNode * node)1362 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldExtractbits(ExtractbitsNode *node)
1363 {
1364 CHECK_NULL_FATAL(node);
1365 BaseNode *result = nullptr;
1366 uint8 offset = node->GetBitsOffset();
1367 uint8 size = node->GetBitsSize();
1368 Opcode opcode = node->GetOpCode();
1369 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1370 ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1371 if (cst != nullptr && (opcode == OP_sext || opcode == OP_zext)) {
1372 result = FoldSignExtend(opcode, node->GetPrimType(), size, *cst);
1373 return std::make_pair(result, std::nullopt);
1374 }
1375 BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1376 if (e != node->Opnd(0)) {
1377 result = mirModule->CurFuncCodeMemPool()->New<ExtractbitsNode>(opcode, PrimType(node->GetPrimType()), offset,
1378 size, e);
1379 } else {
1380 result = node;
1381 }
1382 // check for consecutive and redundant extraction of same bits
1383 BaseNode *opnd = result->Opnd(0);
1384 DEBUG_ASSERT(opnd != nullptr, "opnd shoule not be null");
1385 Opcode opndOp = opnd->GetOpCode();
1386 if (opndOp == OP_extractbits || opndOp == OP_sext || opndOp == OP_zext) {
1387 uint8 opndOffset = static_cast<ExtractbitsNode*>(opnd)->GetBitsOffset();
1388 uint8 opndSize = static_cast<ExtractbitsNode*>(opnd)->GetBitsSize();
1389 if (offset == opndOffset && size == opndSize) {
1390 result->SetOpnd(opnd->Opnd(0), 0); // delete the redundant extraction
1391 }
1392 }
1393 if (offset == 0 && size >= k8ByteSize && IsPowerOf2(size)) {
1394 if (ExtractbitsRedundant(*static_cast<ExtractbitsNode*>(result), *mirModule->CurFunction())) {
1395 return std::make_pair(result->Opnd(0), std::nullopt);
1396 }
1397 }
1398 return std::make_pair(result, std::nullopt);
1399 }
1400
FoldIread(IreadNode * node)1401 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldIread(IreadNode *node)
1402 {
1403 CHECK_NULL_FATAL(node);
1404 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1405 BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1406 node->SetOpnd(e, 0);
1407 BaseNode *result = node;
1408 if (e->GetOpCode() != OP_addrof) {
1409 return std::make_pair(result, std::nullopt);
1410 }
1411
1412 AddrofNode *addrofNode = static_cast<AddrofNode*>(e);
1413 MIRSymbol *msy = mirModule->CurFunction()->GetLocalOrGlobalSymbol(addrofNode->GetStIdx());
1414 DEBUG_ASSERT(msy != nullptr, "nullptr check");
1415 TyIdx typeId = msy->GetTyIdx();
1416 CHECK_FATAL(!GlobalTables::GetTypeTable().GetTypeTable().empty(), "container check");
1417 MIRType *msyType = GlobalTables::GetTypeTable().GetTypeTable()[typeId];
1418 MIRPtrType *ptrType = static_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx()));
1419 // If the high level type of iaddrof/iread doesn't match
1420 // the type of addrof's rhs, this optimization cannot be done.
1421 if (ptrType->GetPointedType() != msyType) {
1422 return std::make_pair(result, std::nullopt);
1423 }
1424
1425 Opcode op = node->GetOpCode();
1426 if (op == OP_iread) {
1427 result = mirModule->CurFuncCodeMemPool()->New<AddrofNode>(OP_dread, node->GetPrimType(), addrofNode->GetStIdx(),
1428 node->GetFieldID() + addrofNode->GetFieldID());
1429 }
1430 return std::make_pair(result, std::nullopt);
1431 }
1432
IntegerOpIsOverflow(Opcode op,PrimType primType,int64 cstA,int64 cstB)1433 bool ConstantFold::IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB)
1434 {
1435 switch (op) {
1436 case OP_add: {
1437 int64 res = static_cast<int64>(static_cast<uint64>(cstA) + static_cast<uint64>(cstB));
1438 if (IsUnsignedInteger(primType)) {
1439 return static_cast<uint64>(res) < static_cast<uint64>(cstA);
1440 }
1441 auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1442 return (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1443 static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag) &&
1444 (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1445 static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag);
1446 }
1447 case OP_sub: {
1448 if (IsUnsignedInteger(primType)) {
1449 return cstA < cstB;
1450 }
1451 int64 res = static_cast<int64>(static_cast<uint64>(cstA) - static_cast<uint64>(cstB));
1452 auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1453 return (static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag !=
1454 static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag) &&
1455 (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1456 static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag);
1457 }
1458 default: {
1459 return false;
1460 }
1461 }
1462 }
1463
FoldBinary(BinaryNode * node)1464 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldBinary(BinaryNode *node)
1465 {
1466 CHECK_NULL_FATAL(node);
1467 BaseNode *result = nullptr;
1468 std::optional<IntVal> sum = std::nullopt;
1469 Opcode op = node->GetOpCode();
1470 PrimType primType = node->GetPrimType();
1471 PrimType lPrimTypes = node->Opnd(0)->GetPrimType();
1472 PrimType rPrimTypes = node->Opnd(1)->GetPrimType();
1473 std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1474 std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1475 BaseNode *l = lp.first;
1476 BaseNode *r = rp.first;
1477 ASSERT_NOT_NULL(r);
1478 ConstvalNode *lConst = safe_cast<ConstvalNode>(l);
1479 ConstvalNode *rConst = safe_cast<ConstvalNode>(r);
1480 bool isInt = IsPrimitiveInteger(primType);
1481
1482 if (lConst != nullptr && rConst != nullptr) {
1483 MIRConst *lConstVal = lConst->GetConstVal();
1484 MIRConst *rConstVal = rConst->GetConstVal();
1485 ASSERT_NOT_NULL(lConstVal);
1486 ASSERT_NOT_NULL(rConstVal);
1487 // Don't fold div by 0, for floats div by 0 is well defined.
1488 if ((op == OP_div || op == OP_rem) && isInt &&
1489 !IsDivSafe(static_cast<MIRIntConst &>(*lConstVal), static_cast<MIRIntConst &>(*rConstVal), primType)) {
1490 result = NewBinaryNode(node, op, primType, lConst, rConst);
1491 } else {
1492 // 4 + 2 -> return a pair(result = ConstValNode(6), sum = 0)
1493 // Create a new ConstvalNode for 6 but keep the sum = 0. This simplify the
1494 // logic since the alternative is to return pair(result = nullptr, sum = 6).
1495 // Doing so would introduce many nullptr checks in the code. See previous
1496 // commits that implemented that logic for a comparison.
1497 result = FoldConstBinary(op, primType, *lConst, *rConst);
1498 }
1499 } else if (lConst != nullptr && isInt) {
1500 MIRIntConst *mcst = safe_cast<MIRIntConst>(lConst->GetConstVal());
1501 ASSERT_NOT_NULL(mcst);
1502 PrimType cstTyp = mcst->GetType().GetPrimType();
1503 IntVal cst = mcst->GetValue();
1504 if (op == OP_add) {
1505 if (IsSignedInteger(cstTyp) && rp.second &&
1506 IntegerOpIsOverflow(OP_add, cstTyp, cst.GetExtValue(), rp.second->GetExtValue())) {
1507 // do not introduce signed integer overflow
1508 result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1509 } else {
1510 sum = cst + rp.second;
1511 result = r;
1512 }
1513 } else if (op == OP_sub && r->GetPrimType() != PTY_u1) {
1514 // We exclude u1 type for fixing the following wrong example:
1515 // before cf:
1516 // sub i32 (constval i32 17, eq u1 i32 (dread i32 %i, constval i32 16)))
1517 // after cf:
1518 // add i32 (cvt i32 u1 (neg u1 (eq u1 i32 (dread i32 %i, constval i32 16))), constval i32 17))
1519 sum = cst - rp.second;
1520 if (GetPrimTypeSize(r->GetPrimType()) < GetPrimTypeSize(primType)) {
1521 r = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, r->GetPrimType(), r);
1522 }
1523 result = NegateTree(r);
1524 } else if ((op == OP_mul || op == OP_div || op == OP_rem || op == OP_ashr || op == OP_lshr || op == OP_shl ||
1525 op == OP_band) &&
1526 cst == 0) {
1527 // 0 * X -> 0
1528 // 0 / X -> 0
1529 // 0 % X -> 0
1530 // 0 >> X -> 0
1531 // 0 << X -> 0
1532 // 0 & X -> 0
1533 // 0 && X -> 0
1534 result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1535 } else if (op == OP_mul && cst == 1) {
1536 // 1 * X --> X
1537 sum = rp.second;
1538 result = r;
1539 } else if (op == OP_bior && cst == -1) {
1540 // (-1) | X -> -1
1541 result = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<uint64>(-1), cstTyp);
1542 } else if (op == OP_mul && rp.second.has_value() && *rp.second != 0) {
1543 // lConst * (X + konst) -> the pair [(lConst*X), (lConst*konst)]
1544 sum = cst * rp.second;
1545 if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) {
1546 rp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, rp.first);
1547 }
1548 result = NewBinaryNode(node, OP_mul, primType, lConst, rp.first);
1549 } else if ((op == OP_bior || op == OP_bxor) && cst == 0) {
1550 // 0 | X -> X
1551 // 0 ^ X -> X
1552 sum = rp.second;
1553 result = r;
1554 } else {
1555 result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1556 }
1557 if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1558 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1559 }
1560 } else if (rConst != nullptr && isInt) {
1561 MIRIntConst *mcst = safe_cast<MIRIntConst>(rConst->GetConstVal());
1562 ASSERT_NOT_NULL(mcst);
1563 PrimType cstTyp = mcst->GetType().GetPrimType();
1564 IntVal cst = mcst->GetValue();
1565 if (op == OP_add) {
1566 if (lp.second && IntegerOpIsOverflow(op, cstTyp, lp.second->GetExtValue(), cst.GetExtValue())) {
1567 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1568 } else {
1569 result = l;
1570 sum = lp.second + cst;
1571 }
1572 } else if (op == OP_sub && (!cst.IsSigned() || !cst.IsMinValue())) {
1573 result = l;
1574 sum = lp.second - cst;
1575 } else if ((op == OP_mul || op == OP_band) && cst == 0) {
1576 // X * 0 -> 0
1577 // X & 0 -> 0
1578 // X && 0 -> 0
1579 result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1580 } else if ((op == OP_mul || op == OP_div) && cst == 1) {
1581 // case [X * 1 -> X]
1582 // case [X / 1 = X]
1583 sum = lp.second;
1584 result = l;
1585 } else if (op == OP_div && !lp.second.has_value() && l->GetOpCode() == OP_mul &&
1586 IsSignedInteger(primType) && IsSignedInteger(lPrimTypes) && IsSignedInteger(rPrimTypes)) {
1587 // temporary fix for constfold of mul/div in DejaGnu
1588 // Later we need a more formal interface for pattern match
1589 // X * Y / Y -> X
1590 BaseNode *x = l->Opnd(0);
1591 BaseNode *y = l->Opnd(1);
1592 ConstvalNode *xConst = safe_cast<ConstvalNode>(x);
1593 ConstvalNode *yConst = safe_cast<ConstvalNode>(y);
1594 bool foldMulDiv = false;
1595 if (yConst != nullptr && xConst == nullptr &&
1596 IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1597 MIRIntConst *yCst = safe_cast<MIRIntConst>(yConst->GetConstVal());
1598 ASSERT_NOT_NULL(yCst);
1599 IntVal mulCst = yCst->GetValue();
1600 if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1601 mulCst.GetExtValue() == cst.GetExtValue()) {
1602 foldMulDiv = true;
1603 result = x;
1604 }
1605 } else if (xConst != nullptr && yConst == nullptr &&
1606 IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1607 MIRIntConst *xCst = safe_cast<MIRIntConst>(xConst->GetConstVal());
1608 ASSERT_NOT_NULL(xCst);
1609 IntVal mulCst = xCst->GetValue();
1610 if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1611 mulCst.GetExtValue() == cst.GetExtValue()) {
1612 foldMulDiv = true;
1613 result = y;
1614 }
1615 }
1616 if (!foldMulDiv) {
1617 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1618 }
1619 } else if (op == OP_mul && lp.second.has_value() && *lp.second != 0 && lp.second->GetSXTValue() > -kMaxOffset) {
1620 // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)]
1621 sum = lp.second * cst;
1622 if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) {
1623 lp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, lp.first);
1624 }
1625 if (lp.first->GetOpCode() == OP_neg && cst == -1) {
1626 // special case: ((-X) + konst) * (-1) -> the pair [(X), -konst]
1627 result = lp.first->Opnd(0);
1628 } else {
1629 result = NewBinaryNode(node, OP_mul, primType, lp.first, rConst);
1630 }
1631 } else if (op == OP_band && cst == -1) {
1632 // X & (-1) -> X
1633 sum = lp.second;
1634 result = l;
1635 } else if (op == OP_band && ContiguousBitsOf1(cst.GetZXTValue()) &&
1636 (!lp.second.has_value() || lp.second == 0)) {
1637 bool fold2extractbits = false;
1638 if (l->GetOpCode() == OP_ashr || l->GetOpCode() == OP_lshr) {
1639 BinaryNode *shrNode = static_cast<BinaryNode *>(l);
1640 if (shrNode->Opnd(1)->GetOpCode() == OP_constval) {
1641 ConstvalNode *shrOpnd = static_cast<ConstvalNode *>(shrNode->Opnd(1));
1642 int64 shrAmt = static_cast<MIRIntConst*>(shrOpnd->GetConstVal())->GetExtValue();
1643 uint64 ucst = cst.GetZXTValue();
1644 uint32 bsize = 0;
1645 do {
1646 bsize++;
1647 ucst >>= 1;
1648 } while (ucst != 0);
1649 if (shrAmt + static_cast<int64>(bsize) <=
1650 static_cast<int64>(GetPrimTypeSize(primType) * kBitSizePerByte) &&
1651 static_cast<uint64>(shrAmt) < GetPrimTypeSize(primType) * kBitSizePerByte) {
1652 fold2extractbits = true;
1653 // change to use extractbits
1654 result = mirModule->GetMIRBuilder()->CreateExprExtractbits(OP_extractbits,
1655 GetUnsignedPrimType(primType), static_cast<uint32>(shrAmt), bsize, shrNode->Opnd(0));
1656 sum = std::nullopt;
1657 }
1658 }
1659 }
1660 if (!fold2extractbits) {
1661 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1662 sum = std::nullopt;
1663 }
1664 } else if (op == OP_bior && cst == -1) {
1665 // X | (-1) -> -1
1666 result = mirModule->GetMIRBuilder()->CreateIntConst(-1ULL, cstTyp);
1667 } else if ((op == OP_ashr || op == OP_lshr || op == OP_shl || op == OP_bior || op == OP_bxor) && cst == 0) {
1668 // X >> 0 -> X
1669 // X << 0 -> X
1670 // X | 0 -> X
1671 // X ^ 0 -> X
1672 sum = lp.second;
1673 result = l;
1674 } else if (op == OP_bxor && cst == 1 && primType != PTY_u1) {
1675 // bxor i32 (
1676 // cvt i32 u1 (regread u1 %13),
1677 // constValue i32 1),
1678 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1679 if (l->GetOpCode() == OP_cvt && (!lp.second || lp.second == 0)) {
1680 TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(l);
1681 if (cvtNode->Opnd(0)->GetPrimType() == PTY_u1) {
1682 BaseNode *base = cvtNode->Opnd(0);
1683 BaseNode *constValue = mirModule->GetMIRBuilder()->CreateIntConst(1, base->GetPrimType());
1684 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(base);
1685 BinaryNode *temp = NewBinaryNode(node, op, PTY_u1, PairToExpr(base->GetPrimType(), p), constValue);
1686 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_u1, temp);
1687 }
1688 }
1689 } else if (op == OP_rem && cst == 1) {
1690 // X % 1 -> 0
1691 result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1692 } else {
1693 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1694 }
1695 if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1696 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1697 }
1698 } else if (isInt && (op == OP_add || op == OP_sub)) {
1699 if (op == OP_add) {
1700 result = NewBinaryNode(node, op, primType, l, r);
1701 sum = lp.second + rp.second;
1702 } else if (r != nullptr && node->Opnd(1)->GetOpCode() == OP_sub && r->GetOpCode() == OP_neg) {
1703 // if fold is (x - (y - z)) -> (x - neg(z)) - y
1704 // (x - neg(z)) Could cross the int limit
1705 // return node
1706 result = node;
1707 } else {
1708 result = NewBinaryNode(node, op, primType, l, r);
1709 sum = lp.second - rp.second;
1710 }
1711 } else {
1712 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1713 }
1714 return std::make_pair(result, sum);
1715 }
1716
SimplifyDoubleConstvalCompare(CompareNode & node,bool isRConstval,bool isGtOrLt) const1717 BaseNode *ConstantFold::SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt) const
1718 {
1719 if (isRConstval) {
1720 ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(1));
1721 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1722 const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(0));
1723 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1724 node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
1725 }
1726 } else {
1727 ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(0));
1728 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1729 const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(1));
1730 if (isGtOrLt) {
1731 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1732 node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(1), compNode->Opnd(0));
1733 } else {
1734 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1735 node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
1736 }
1737 }
1738 }
1739 return &node;
1740 }
1741
SimplifyDoubleCompare(CompareNode & compareNode) const1742 BaseNode *ConstantFold::SimplifyDoubleCompare(CompareNode &compareNode) const
1743 {
1744 // See arm manual B.cond(P2993) and FCMP(P1091)
1745 CompareNode *node = &compareNode;
1746 BaseNode *result = node;
1747 BaseNode *l = node->Opnd(0);
1748 BaseNode *r = node->Opnd(1);
1749 if (node->GetOpCode() == OP_ne || node->GetOpCode() == OP_eq) {
1750 if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
1751 (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
1752 result = SimplifyDoubleConstvalCompare(*node, (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval));
1753 } else if (node->GetOpCode() == OP_ne && r->GetOpCode() == OP_constval) {
1754 // ne (u1 x, constValue 0) <==> x
1755 ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
1756 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1757 BaseNode *opnd = l;
1758 do {
1759 if (opnd->GetPrimType() == PTY_u1 || (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1760 result = opnd;
1761 break;
1762 } else if (opnd->GetOpCode() == OP_cvt) {
1763 TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(opnd);
1764 opnd = cvtNode->Opnd(0);
1765 } else {
1766 opnd = nullptr;
1767 }
1768 } while (opnd != nullptr);
1769 }
1770 } else if (node->GetOpCode() == OP_eq && r->GetOpCode() == OP_constval) {
1771 ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
1772 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero() &&
1773 (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1774 auto resOp = l->GetOpCode() == OP_ne ? OP_eq : OP_ne;
1775 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1776 resOp, l->GetPrimType(), static_cast<CompareNode*>(l)->GetOpndType(), l->Opnd(0), l->Opnd(1));
1777 }
1778 }
1779 } else if (node->GetOpCode() == OP_gt || node->GetOpCode() == OP_lt) {
1780 if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
1781 (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
1782 result = SimplifyDoubleConstvalCompare(*node,
1783 (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval), true);
1784 }
1785 }
1786 return result;
1787 }
1788
FoldCompare(CompareNode * node)1789 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldCompare(CompareNode *node)
1790 {
1791 CHECK_NULL_FATAL(node);
1792 BaseNode *result = nullptr;
1793 std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1794 std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1795 ConstvalNode *lConst = safe_cast<ConstvalNode>(lp.first);
1796 ConstvalNode *rConst = safe_cast<ConstvalNode>(rp.first);
1797 Opcode opcode = node->GetOpCode();
1798 if (lConst != nullptr && rConst != nullptr) {
1799 result = FoldConstComparison(node->GetOpCode(), node->GetPrimType(), node->GetOpndType(), *lConst, *rConst);
1800 } else if (lConst != nullptr && rConst == nullptr && opcode != OP_cmp &&
1801 lConst->GetConstVal()->GetKind() == kConstInt) {
1802 BaseNode *l = lp.first;
1803 BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
1804 result = FoldConstComparisonReverse(opcode, node->GetPrimType(), node->GetOpndType(), *l, *r);
1805 } else {
1806 BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), lp);
1807 BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
1808 if (l != node->Opnd(0) || r != node->Opnd(1)) {
1809 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1810 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->GetOpndType()), l, r);
1811 } else {
1812 result = node;
1813 }
1814 auto *compareNode = static_cast<CompareNode*>(result);
1815 CHECK_NULL_FATAL(compareNode);
1816 result = SimplifyDoubleCompare(*compareNode);
1817 }
1818 return std::make_pair(result, std::nullopt);
1819 }
1820
Fold(BaseNode * node)1821 BaseNode *ConstantFold::Fold(BaseNode *node)
1822 {
1823 if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) {
1824 return nullptr;
1825 }
1826 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node);
1827 BaseNode *result = PairToExpr(node->GetPrimType(), p);
1828 if (result == node) {
1829 result = nullptr;
1830 }
1831 return result;
1832 }
1833
1834 } // namespace maple
1835