1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "constantfold.h"
17 #include <cmath>
18 #include <cfloat>
19 #include <climits>
20 #include <type_traits>
21 #include "mpl_logging.h"
22 #include "mir_function.h"
23 #include "mir_builder.h"
24 #include "global_tables.h"
25 #include "me_option.h"
26 #include "maple_phase_manager.h"
27 #include "mir_type.h"
28
29 namespace maple {
30
31 namespace {
32
33 constexpr uint64 kJsTypeNumber = 4;
34 constexpr uint64 kJsTypeNumberInHigh32Bit = kJsTypeNumber << 32; // set high 32 bit as JSTYPE_NUMBER
35 constexpr uint32 kByteSizeOfBit64 = 8; // byte number for 64 bit
36 constexpr uint32 kBitSizePerByte = 8;
37 constexpr maple::int32 kMaxOffset = INT_MAX - 8;
38
39 enum CompareRes : int64 { kLess = -1, kEqual = 0, kGreater = 1 };
40
operator *(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)41 std::optional<IntVal> operator*(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
42 {
43 if (!v1 && !v2) {
44 return std::nullopt;
45 }
46
47 // Perform all calculations in terms of the maximum available signed type.
48 // The value will be truncated for an appropriate type when constant is created in PairToExpr function
49 return v1 && v2 ? v1->Mul(*v2, PTY_i64) : IntVal(static_cast<uint64>(0), PTY_i64);
50 }
51
52 // Perform all calculations in terms of the maximum available signed type.
53 // The value will be truncated for an appropriate type when constant is created in PairToExpr function
AddSub(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2,bool isAdd)54 std::optional<IntVal> AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd)
55 {
56 if (!v1 && !v2) {
57 return std::nullopt;
58 }
59
60 if (v1 && v2) {
61 return isAdd ? v1->Add(*v2, PTY_i64) : v1->Sub(*v2, PTY_i64);
62 }
63
64 if (v1) {
65 return v1->TruncOrExtend(PTY_i64);
66 }
67
68 // !v1 && v2
69 return isAdd ? v2->TruncOrExtend(PTY_i64) : -(v2->TruncOrExtend(PTY_i64));
70 }
71
operator +(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)72 std::optional<IntVal> operator+(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
73 {
74 return AddSub(v1, v2, true);
75 }
76
operator -(const std::optional<IntVal> & v1,const std::optional<IntVal> & v2)77 std::optional<IntVal> operator-(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
78 {
79 return AddSub(v1, v2, false);
80 }
81
82 } // anonymous namespace
83
// This phase optimizes code by simplifying constant expressions:
// each constant expression is evaluated at compile time and
// replaced by its computed value, so the work is not repeated at
// runtime.
//
// The main procedure is as follows:
// A. Analyze the expression type
// B. Analyze the operator type
// C. Replace the expression with the result of the operation
94 // true if the constant's bits are made of only one group of contiguous 1's
95 // starting at bit 0
ContiguousBitsOf1(uint64 x)96 static bool ContiguousBitsOf1(uint64 x)
97 {
98 if (x == 0) {
99 return false;
100 }
101 return (~x & (x + 1)) == (x + 1);
102 }
103
IsPowerOf2(uint64 num)104 inline bool IsPowerOf2(uint64 num)
105 {
106 if (num == 0) {
107 return false;
108 }
109 return (~(num - 1) & num) == num;
110 }
111
NewBinaryNode(BinaryNode * old,Opcode op,PrimType primType,BaseNode * lhs,BaseNode * rhs) const112 BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs,
113 BaseNode *rhs) const
114 {
115 CHECK_NULL_FATAL(old);
116 BinaryNode *result = nullptr;
117 if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == lhs && old->Opnd(1) == rhs) {
118 result = old;
119 } else {
120 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(op, primType, lhs, rhs);
121 }
122 return result;
123 }
124
NewUnaryNode(UnaryNode * old,Opcode op,PrimType primType,BaseNode * expr) const125 UnaryNode *ConstantFold::NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const
126 {
127 CHECK_NULL_FATAL(old);
128 UnaryNode *result = nullptr;
129 if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == expr) {
130 result = old;
131 } else {
132 result = mirModule->CurFuncCodeMemPool()->New<UnaryNode>(op, primType, expr);
133 }
134 return result;
135 }
136
// Rebuilds MIR for a folder result of the form (expression, optional integer constant),
// i.e. materializes "pair.first + *pair.second".
//   resultType - primitive type of the rebuilt expression and of the created constant
//   pair       - first: the non-constant part (must not be null); second: the constant part, if any
// Returns pair.first unchanged when there is no constant part, the constant is zero, or
// resultType is wider than 8 bytes. Otherwise returns a newly allocated OP_add/OP_sub node.
BaseNode *ConstantFold::PairToExpr(PrimType resultType, const std::pair<BaseNode*, std::optional<IntVal>> &pair) const
{
    CHECK_NULL_FATAL(pair.first);
    BaseNode *result = pair.first;
    if (!pair.second || *pair.second == 0 || GetPrimTypeSize(resultType) > k8ByteSize) {
        return result;
    }
    if (pair.first->GetOpCode() == OP_neg && !pair.second->GetSignBit()) {
        // -a, 5 -> 5 - a
        // A negated expression plus a non-negative constant becomes a subtraction from the
        // constant, which drops the OP_neg node.
        ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
            static_cast<uint64>(pair.second->GetExtValue()), resultType);
        BaseNode *r = static_cast<UnaryNode*>(pair.first)->Opnd(0);
        result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, val, r);
    } else {
        // Keep the "a + c" form when c is non-negative and still positive after sign-extension
        // from resultType's bit width, or when c is a minimum value (of resultType or of int64).
        // Presumably the minimum values are excluded from the "a - (-c)" form below because
        // negating them would overflow back to the same value.
        if ((!pair.second->GetSignBit() &&
            pair.second->GetSXTValue(static_cast<uint8>(GetPrimTypeBitSize(resultType))) > 0) ||
            pair.second->TruncOrExtend(resultType).IsMinValue() ||
            pair.second->GetSXTValue() == INT64_MIN) {
            // +-a, 5 -> a + 5
            ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
                static_cast<uint64>(pair.second->GetExtValue()), resultType);
            result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_add, resultType, pair.first, val);
        } else {
            // +-a, -5 -> a - 5
            // A negative constant is emitted as a subtraction of its (positive) negation.
            ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
                static_cast<uint64>((-pair.second.value()).GetExtValue()), resultType);
            result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, pair.first, val);
        }
    }
    return result;
}
168
FoldBase(BaseNode * node) const169 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBase(BaseNode *node) const
170 {
171 return std::make_pair(node, std::nullopt);
172 }
173
// Simplifies one statement by dispatching on its opcode to the matching Simplify* helper.
// Returns whatever that helper returns (the simplified statement). Statements with an
// opcode not listed below are returned unchanged.
StmtNode *ConstantFold::Simplify(StmtNode *node)
{
    CHECK_NULL_FATAL(node);
    switch (node->GetOpCode()) {
        case OP_dassign:
        case OP_maydassign:
            return SimplifyDassign(static_cast<DassignNode*>(node));
        case OP_iassign:
            return SimplifyIassign(static_cast<IassignNode*>(node));
        case OP_block:
            return SimplifyBlock(static_cast<BlockNode*>(node));
        case OP_if:
            return SimplifyIf(static_cast<IfStmtNode*>(node));
        case OP_dowhile:
        case OP_while:
            return SimplifyWhile(static_cast<WhileStmtNode*>(node));
        case OP_switch:
            return SimplifySwitch(static_cast<SwitchNode*>(node));
        // Statements with a single operand.
        case OP_eval:
        case OP_throw:
        case OP_free:
        case OP_decref:
        case OP_incref:
        case OP_decrefreset:
        case OP_regassign:
        // NOTE(review): macro defined elsewhere; presumably expands to further `case` labels for
        // the non-null assertion opcodes — confirm.
        CASE_OP_ASSERT_NONNULL
        case OP_igoto:
            return SimplifyUnary(static_cast<UnaryStmtNode*>(node));
        case OP_brfalse:
        case OP_brtrue:
            return SimplifyCondGoto(static_cast<CondGotoNode*>(node));
        // Statements with a variable number of operands (returns, sync, direct/virtual/intrinsic calls).
        case OP_return:
        case OP_syncenter:
        case OP_syncexit:
        case OP_call:
        case OP_virtualcall:
        case OP_superclasscall:
        case OP_interfacecall:
        case OP_customcall:
        case OP_polymorphiccall:
        case OP_intrinsiccall:
        case OP_xintrinsiccall:
        case OP_intrinsiccallwithtype:
        case OP_callassigned:
        case OP_virtualcallassigned:
        case OP_superclasscallassigned:
        case OP_interfacecallassigned:
        case OP_customcallassigned:
        case OP_polymorphiccallassigned:
        case OP_intrinsiccallassigned:
        case OP_intrinsiccallwithtypeassigned:
        case OP_xintrinsiccallassigned:
        case OP_callinstant:
        case OP_callinstantassigned:
        case OP_virtualcallinstant:
        case OP_virtualcallinstantassigned:
        case OP_superclasscallinstant:
        case OP_superclasscallinstantassigned:
        case OP_interfacecallinstant:
        case OP_interfacecallinstantassigned:
        // NOTE(review): macro defined elsewhere; presumably expands to further `case` labels for
        // the boundary assertion opcodes — confirm.
        CASE_OP_ASSERT_BOUNDARY
            return SimplifyNary(static_cast<NaryStmtNode*>(node));
        case OP_icall:
        case OP_icallassigned:
        case OP_icallproto:
        case OP_icallprotoassigned:
            return SimplifyIcall(static_cast<IcallNode*>(node));
        case OP_asm:
            return SimplifyAsm(static_cast<AsmNode*>(node));
        default:
            return node;
    }
}
247
// Folds one expression by dispatching on its opcode to the matching Fold* helper.
// Returns a pair (expression, optional integer constant). PairToExpr turns such a pair back
// into "expression + constant". Expressions whose primitive type is wider than 8 bytes are
// not folded and are returned as {node, nullopt}. Unlisted opcodes go to FoldBase.
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::DispatchFold(BaseNode *node)
{
    CHECK_NULL_FATAL(node);
    if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
        return {node, std::nullopt};
    }
    switch (node->GetOpCode()) {
        case OP_sizeoftype:
            return FoldSizeoftype(static_cast<SizeoftypeNode*>(node));
        case OP_abs:
        case OP_bnot:
        case OP_lnot:
        case OP_neg:
        case OP_recip:
        case OP_sqrt:
            return FoldUnary(static_cast<UnaryNode*>(node));
        // Rounding and conversion opcodes.
        case OP_ceil:
        case OP_floor:
        case OP_round:
        case OP_trunc:
        case OP_cvt:
            return FoldTypeCvt(static_cast<TypeCvtNode*>(node));
        case OP_sext:
        case OP_zext:
        case OP_extractbits:
            return FoldExtractbits(static_cast<ExtractbitsNode*>(node));
        case OP_iaddrof:
        case OP_iread:
            return FoldIread(static_cast<IreadNode*>(node));
        case OP_add:
        case OP_ashr:
        case OP_band:
        case OP_bior:
        case OP_bxor:
        case OP_cand:
        case OP_cior:
        case OP_div:
        case OP_land:
        case OP_lior:
        case OP_lshr:
        case OP_max:
        case OP_min:
        case OP_mul:
        case OP_rem:
        case OP_shl:
        case OP_sub:
            return FoldBinary(static_cast<BinaryNode*>(node));
        case OP_eq:
        case OP_ne:
        case OP_ge:
        case OP_gt:
        case OP_le:
        case OP_lt:
        case OP_cmp:
            return FoldCompare(static_cast<CompareNode*>(node));
        case OP_depositbits:
            return FoldDepositbits(static_cast<DepositbitsNode*>(node));
        case OP_select:
            return FoldTernary(static_cast<TernaryNode*>(node));
        case OP_array:
            return FoldArray(static_cast<ArrayNode*>(node));
        case OP_retype:
            return FoldRetype(static_cast<RetypeNode*>(node));
        case OP_gcmallocjarray:
        case OP_gcpermallocjarray:
            return FoldGcmallocjarray(static_cast<JarrayMallocNode*>(node));
        default:
            return FoldBase(static_cast<BaseNode*>(node));
    }
}
318
Negate(BaseNode * node) const319 BaseNode *ConstantFold::Negate(BaseNode *node) const
320 {
321 CHECK_NULL_FATAL(node);
322 return mirModule->CurFuncCodeMemPool()->New<UnaryNode>(OP_neg, PrimType(node->GetPrimType()), node);
323 }
324
Negate(UnaryNode * node) const325 BaseNode *ConstantFold::Negate(UnaryNode *node) const
326 {
327 CHECK_NULL_FATAL(node);
328 BaseNode *result = nullptr;
329 if (node->GetOpCode() == OP_neg) {
330 result = static_cast<BaseNode*>(node->Opnd(0));
331 } else {
332 BaseNode *n = static_cast<BaseNode*>(node);
333 result = NewUnaryNode(node, OP_neg, node->GetPrimType(), n);
334 }
335 return result;
336 }
337
Negate(const ConstvalNode * node) const338 BaseNode *ConstantFold::Negate(const ConstvalNode *node) const
339 {
340 CHECK_NULL_FATAL(node);
341 ConstvalNode *copy = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
342 CHECK_NULL_FATAL(copy);
343 copy->GetConstVal()->Neg();
344 return copy;
345 }
346
NegateTree(BaseNode * node) const347 BaseNode *ConstantFold::NegateTree(BaseNode *node) const
348 {
349 CHECK_NULL_FATAL(node);
350 if (node->IsUnaryNode()) {
351 return Negate(static_cast<UnaryNode*>(node));
352 } else if (node->GetOpCode() == OP_constval) {
353 return Negate(static_cast<ConstvalNode*>(node));
354 } else {
355 return Negate(static_cast<BaseNode*>(node));
356 }
357 }
358
FoldIntConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRIntConst & intConst0,const MIRIntConst & intConst1) const359 MIRIntConst *ConstantFold::FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
360 const MIRIntConst &intConst0,
361 const MIRIntConst &intConst1) const
362 {
363 uint64 result = 0;
364
365 bool greater = intConst0.GetValue().Greater(intConst1.GetValue(), opndType);
366 bool equal = intConst0.GetValue().Equal(intConst1.GetValue(), opndType);
367 bool less = intConst0.GetValue().Less(intConst1.GetValue(), opndType);
368
369 switch (opcode) {
370 case OP_eq: {
371 result = equal;
372 break;
373 }
374 case OP_ge: {
375 result = (greater || equal) ? 1 : 0;
376 break;
377 }
378 case OP_gt: {
379 result = greater;
380 break;
381 }
382 case OP_le: {
383 result = (less || equal) ? 1 : 0;
384 break;
385 }
386 case OP_lt: {
387 result = less;
388 break;
389 }
390 case OP_ne: {
391 result = !equal;
392 break;
393 }
394 case OP_cmp: {
395 if (greater) {
396 result = kGreater;
397 } else if (equal) {
398 result = kEqual;
399 } else if (less) {
400 result = static_cast<uint64>(kLess);
401 }
402 break;
403 }
404 default:
405 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstComparison");
406 break;
407 }
408 // determine the type
409 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
410 // form the constant
411 MIRIntConst *constValue = nullptr;
412 if (type.GetPrimType() == PTY_dyni32) {
413 constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
414 constValue->SetValue(static_cast<int64>(kJsTypeNumberInHigh32Bit | result));
415 } else {
416 constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
417 }
418 return constValue;
419 }
420
FoldIntConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const421 ConstvalNode *ConstantFold::FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
422 const ConstvalNode &const0, const ConstvalNode &const1) const
423 {
424 const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
425 const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
426 CHECK_NULL_FATAL(intConst0);
427 CHECK_NULL_FATAL(intConst1);
428 MIRIntConst *constValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
429 // form the ConstvalNode
430 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
431 resultConst->SetPrimType(resultType);
432 resultConst->SetConstVal(constValue);
433 return resultConst;
434 }
435
// Folds a binary integer operation on two constants into an integer constant.
//   opcode     - binary opcode to evaluate (OP_depositbits is handled by FoldDepositbits instead)
//   resultType - type the IntVal arithmetic is performed in and the type of the produced constant
//   intConst0  - left operand
//   intConst1  - right operand
// Returns a constant obtained from the global int-constant table. For PTY_dyni32 the value is
// OR-ed with kJsTypeNumberInHigh32Bit. For an unexpected opcode, DEBUG_ASSERT fires and the
// result stays 0.
MIRConst *ConstantFold::FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0,
                                                   const MIRIntConst &intConst1)
{
    IntVal intVal0 = intConst0.GetValue();
    IntVal intVal1 = intConst1.GetValue();
    // Default result, kept when the opcode is not folded below.
    IntVal result(static_cast<uint64>(0), resultType);

    switch (opcode) {
        case OP_add: {
            result = intVal0.Add(intVal1, resultType);
            break;
        }
        case OP_sub: {
            result = intVal0.Sub(intVal1, resultType);
            break;
        }
        case OP_mul: {
            result = intVal0.Mul(intVal1, resultType);
            break;
        }
        case OP_div: {
            // NOTE(review): no zero-divisor check is visible here; presumably callers or IntVal::Div
            // handle a zero divisor — confirm.
            result = intVal0.Div(intVal1, resultType);
            break;
        }
        case OP_rem: {
            result = intVal0.Rem(intVal1, resultType);
            break;
        }
        // For the shifts below, the shift amount is the zero-extended right operand taken modulo
        // the aligned bit width of resultType.
        case OP_ashr: {
            result = intVal0.AShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
            break;
        }
        case OP_lshr: {
            result = intVal0.LShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
            break;
        }
        case OP_shl: {
            result = intVal0.Shl(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
            break;
        }
        case OP_max: {
            result = Max(intVal0, intVal1, resultType);
            break;
        }
        case OP_min: {
            result = Min(intVal0, intVal1, resultType);
            break;
        }
        case OP_band: {
            result = intVal0.And(intVal1, resultType);
            break;
        }
        case OP_bior: {
            result = intVal0.Or(intVal1, resultType);
            break;
        }
        case OP_bxor: {
            result = intVal0.Xor(intVal1, resultType);
            break;
        }
        // Logical and/or: an operand counts as true when its GetExtValue() is non-zero; result is 0 or 1.
        case OP_cand:
        case OP_land: {
            result = IntVal(intVal0.GetExtValue() && intVal1.GetExtValue(), resultType);
            break;
        }
        case OP_cior:
        case OP_lior: {
            result = IntVal(intVal0.GetExtValue() || intVal1.GetExtValue(), resultType);
            break;
        }
        case OP_depositbits: {
            // handled in FoldDepositbits
            DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstBinary");
            break;
        }
        default:
            DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstBinary");
            break;
    }
    // determine the type
    MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
    // form the constant
    MIRIntConst *constValue = nullptr;
    if (type.GetPrimType() == PTY_dyni32) {
        // Dynamic int32: the value is stored with the number-type tag in its high 32 bits.
        constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
        constValue->SetValue(static_cast<int64>(kJsTypeNumberInHigh32Bit | static_cast<uint64>(result.GetExtValue())));
    } else {
        constValue =
            GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
    }
    return constValue;
}
528
FoldIntConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const529 ConstvalNode *ConstantFold::FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
530 const ConstvalNode &const1) const
531 {
532 const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
533 const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
534 CHECK_NULL_FATAL(intConst0);
535 CHECK_NULL_FATAL(intConst1);
536 MIRConst *constValue = FoldIntConstBinaryMIRConst(opcode, resultType, *intConst0, *intConst1);
537 // form the ConstvalNode
538 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
539 resultConst->SetPrimType(resultType);
540 resultConst->SetConstVal(constValue);
541 return resultConst;
542 }
543
// Folds a binary floating-point operation on two constant nodes of the same type
// (PTY_f32 or PTY_f64) into a new ConstvalNode of resultType.
// Only add, sub, mul, div, max and min are folded. Any other opcode triggers DEBUG_ASSERT
// and leaves the value at 0.0.
// NOTE(review): the operand width (useDouble, from const0's type) and the stored result
// (from resultType) are chosen independently. If they ever disagree, the stored value is the
// other, still-zero accumulator. Presumably callers guarantee that they match — confirm.
ConstvalNode *ConstantFold::FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
                                              const ConstvalNode &const1) const
{
    DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
    const MIRDoubleConst *doubleConst0 = nullptr;
    const MIRDoubleConst *doubleConst1 = nullptr;
    const MIRFloatConst *floatConst0 = nullptr;
    const MIRFloatConst *floatConst1 = nullptr;
    bool useDouble = (const0.GetPrimType() == PTY_f64);
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(resultType);
    if (useDouble) {
        doubleConst0 = safe_cast<MIRDoubleConst>(const0.GetConstVal());
        doubleConst1 = safe_cast<MIRDoubleConst>(const1.GetConstVal());
        CHECK_NULL_FATAL(doubleConst0);
        CHECK_NULL_FATAL(doubleConst1);
    } else {
        floatConst0 = safe_cast<MIRFloatConst>(const0.GetConstVal());
        floatConst1 = safe_cast<MIRFloatConst>(const1.GetConstVal());
        CHECK_NULL_FATAL(floatConst0);
        CHECK_NULL_FATAL(floatConst1);
    }
    // Exactly one of these accumulators is written, depending on useDouble.
    float constValueFloat = 0.0;
    double constValueDouble = 0.0;
    switch (opcode) {
        case OP_add: {
            if (useDouble) {
                constValueDouble = doubleConst0->GetValue() + doubleConst1->GetValue();
            } else {
                constValueFloat = floatConst0->GetValue() + floatConst1->GetValue();
            }
            break;
        }
        case OP_sub: {
            if (useDouble) {
                constValueDouble = doubleConst0->GetValue() - doubleConst1->GetValue();
            } else {
                constValueFloat = floatConst0->GetValue() - floatConst1->GetValue();
            }
            break;
        }
        case OP_mul: {
            if (useDouble) {
                constValueDouble = doubleConst0->GetValue() * doubleConst1->GetValue();
            } else {
                constValueFloat = floatConst0->GetValue() * floatConst1->GetValue();
            }
            break;
        }
        case OP_div: {
            // for floats div by 0 is well defined
            if (useDouble) {
                constValueDouble = doubleConst0->GetValue() / doubleConst1->GetValue();
            } else {
                constValueFloat = floatConst0->GetValue() / floatConst1->GetValue();
            }
            break;
        }
        // For max/min: if either operand is NaN the comparison is false and the second operand is chosen.
        case OP_max: {
            if (useDouble) {
                constValueDouble = (doubleConst0->GetValue() >= doubleConst1->GetValue()) ? doubleConst0->GetValue()
                                                                                           : doubleConst1->GetValue();
            } else {
                constValueFloat = (floatConst0->GetValue() >= floatConst1->GetValue()) ? floatConst0->GetValue()
                                                                                        : floatConst1->GetValue();
            }
            break;
        }
        case OP_min: {
            if (useDouble) {
                constValueDouble = (doubleConst0->GetValue() <= doubleConst1->GetValue()) ? doubleConst0->GetValue()
                                                                                           : doubleConst1->GetValue();
            } else {
                constValueFloat = (floatConst0->GetValue() <= floatConst1->GetValue()) ? floatConst0->GetValue()
                                                                                        : floatConst1->GetValue();
            }
            break;
        }
        // Integer-only opcodes: never expected for floating-point operands.
        case OP_rem:
        case OP_ashr:
        case OP_lshr:
        case OP_shl:
        case OP_band:
        case OP_bior:
        case OP_bxor:
        case OP_cand:
        case OP_land:
        case OP_cior:
        case OP_lior:
        case OP_depositbits: {
            DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstBinary");
            break;
        }
        default:
            DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstBinary");
            break;
    }
    if (resultType == PTY_f64) {
        resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValueDouble));
    } else {
        resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValueFloat));
    }
    return resultConst;
}
648
ConstValueEqual(int64 leftValue,int64 rightValue) const649 bool ConstantFold::ConstValueEqual(int64 leftValue, int64 rightValue) const
650 {
651 return (leftValue == rightValue);
652 }
653
ConstValueEqual(float leftValue,float rightValue) const654 bool ConstantFold::ConstValueEqual(float leftValue, float rightValue) const
655 {
656 auto result = fabs(leftValue - rightValue);
657 return leftValue <= FLT_MIN && rightValue <= FLT_MIN ? result < FLT_MIN : result <= FLT_MIN;
658 }
659
ConstValueEqual(double leftValue,double rightValue) const660 bool ConstantFold::ConstValueEqual(double leftValue, double rightValue) const
661 {
662 auto result = fabs(leftValue - rightValue);
663 return leftValue <= DBL_MIN && rightValue <= DBL_MIN ? result < DBL_MIN : result <= DBL_MIN;
664 }
665
666 template<typename T>
FullyEqual(T leftValue,T rightValue) const667 bool ConstantFold::FullyEqual(T leftValue, T rightValue) const
668 {
669 if (std::isinf(leftValue) && std::isinf(rightValue)) {
670 // (inf == inf), add the judgement here in case of the subtraction between float type inf
671 return true;
672 } else {
673 return ConstValueEqual(leftValue, rightValue);
674 }
675 }
676
677 template<typename T>
ComparisonResult(Opcode op,T * leftConst,T * rightConst) const678 int64 ConstantFold::ComparisonResult(Opcode op, T *leftConst, T *rightConst) const
679 {
680 DEBUG_ASSERT(leftConst != nullptr, "leftConst should not be nullptr");
681 typename T::value_type leftValue = leftConst->GetValue();
682 DEBUG_ASSERT(rightConst != nullptr, "rightConst should not be nullptr");
683 typename T::value_type rightValue = rightConst->GetValue();
684 int64 result = 0;
685 switch (op) {
686 case OP_eq: {
687 result = FullyEqual(leftValue, rightValue);
688 break;
689 }
690 case OP_ge: {
691 result = (leftValue > rightValue) || FullyEqual(leftValue, rightValue);
692 break;
693 }
694 case OP_gt: {
695 result = (leftValue > rightValue);
696 break;
697 }
698 case OP_le: {
699 result = (leftValue < rightValue) || FullyEqual(leftValue, rightValue);
700 break;
701 }
702 case OP_lt: {
703 result = (leftValue < rightValue);
704 break;
705 }
706 case OP_ne: {
707 result = !FullyEqual(leftValue, rightValue);
708 break;
709 }
710 case OP_cmpl:
711 case OP_cmpg: {
712 if (std::isnan(leftValue) || std::isnan(rightValue)) {
713 result = (op == OP_cmpg) ? kGreater : kLess;
714 break;
715 }
716 }
717 [[clang::fallthrough]];
718 case OP_cmp: {
719 if (leftValue > rightValue) {
720 result = kGreater;
721 } else if (FullyEqual(leftValue, rightValue)) {
722 result = kEqual;
723 } else {
724 result = kLess;
725 }
726 break;
727 }
728 default:
729 DEBUG_ASSERT(false, "Unknown opcode for Comparison");
730 break;
731 }
732 return result;
733 }
734
FoldFPConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRConst & leftConst,const MIRConst & rightConst) const735 MIRIntConst *ConstantFold::FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
736 const MIRConst &leftConst, const MIRConst &rightConst) const
737 {
738 int64 result = 0;
739 bool useDouble = (opndType == PTY_f64);
740 if (useDouble) {
741 result =
742 ComparisonResult(opcode, safe_cast<MIRDoubleConst>(&leftConst), safe_cast<MIRDoubleConst>(&rightConst));
743 } else {
744 result = ComparisonResult(opcode, safe_cast<MIRFloatConst>(&leftConst), safe_cast<MIRFloatConst>(&rightConst));
745 }
746 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
747 MIRIntConst *resultConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result), type);
748 return resultConst;
749 }
750
FoldFPConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const751 ConstvalNode *ConstantFold::FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
752 const ConstvalNode &const0, const ConstvalNode &const1) const
753 {
754 DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
755 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
756 resultConst->SetPrimType(resultType);
757 resultConst->SetConstVal(
758 FoldFPConstComparisonMIRConst(opcode, resultType, opndType, *const0.GetConstVal(), *const1.GetConstVal()));
759 return resultConst;
760 }
761
FoldConstComparisonMIRConst(Opcode opcode,PrimType resultType,PrimType opndType,const MIRConst & const0,const MIRConst & const1) const762 MIRConst *ConstantFold::FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
763 const MIRConst &const0, const MIRConst &const1) const
764 {
765 MIRConst *returnValue = nullptr;
766 if (IsPrimitiveInteger(opndType) || IsPrimitiveDynInteger(opndType)) {
767 const auto *intConst0 = safe_cast<MIRIntConst>(&const0);
768 const auto *intConst1 = safe_cast<MIRIntConst>(&const1);
769 ASSERT_NOT_NULL(intConst0);
770 ASSERT_NOT_NULL(intConst1);
771 returnValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
772 } else if (opndType == PTY_f32 || opndType == PTY_f64) {
773 returnValue = FoldFPConstComparisonMIRConst(opcode, resultType, opndType, const0, const1);
774 } else {
775 DEBUG_ASSERT(false, "Unhandled case for FoldConstComparisonMIRConst");
776 }
777 return returnValue;
778 }
779
FoldConstComparison(Opcode opcode,PrimType resultType,PrimType opndType,const ConstvalNode & const0,const ConstvalNode & const1) const780 ConstvalNode *ConstantFold::FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
781 const ConstvalNode &const0, const ConstvalNode &const1) const
782 {
783 ConstvalNode *returnValue = nullptr;
784 if (IsPrimitiveInteger(opndType) || IsPrimitiveDynInteger(opndType)) {
785 returnValue = FoldIntConstComparison(opcode, resultType, opndType, const0, const1);
786 } else if (opndType == PTY_f32 || opndType == PTY_f64) {
787 returnValue = FoldFPConstComparison(opcode, resultType, opndType, const0, const1);
788 } else {
789 DEBUG_ASSERT(false, "Unhandled case for FoldConstComparison");
790 }
791 return returnValue;
792 }
793
FoldConstComparisonReverse(Opcode opcode,PrimType resultType,PrimType opndType,BaseNode & l,BaseNode & r) const794 CompareNode *ConstantFold::FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType,
795 BaseNode &l, BaseNode &r) const
796 {
797 CompareNode *result = nullptr;
798 Opcode op = opcode;
799 switch (opcode) {
800 case OP_gt: {
801 op = OP_lt;
802 break;
803 }
804 case OP_lt: {
805 op = OP_gt;
806 break;
807 }
808 case OP_ge: {
809 op = OP_le;
810 break;
811 }
812 case OP_le: {
813 op = OP_ge;
814 break;
815 }
816 case OP_eq: {
817 break;
818 }
819 case OP_ne: {
820 break;
821 }
822 default:
823 DEBUG_ASSERT(false, "Unknown opcode for FoldConstComparisonReverse");
824 break;
825 }
826
827 result =
828 mirModule->CurFuncCodeMemPool()->New<CompareNode>(Opcode(op), PrimType(resultType), PrimType(opndType), &r, &l);
829 return result;
830 }
831
FoldConstBinary(Opcode opcode,PrimType resultType,const ConstvalNode & const0,const ConstvalNode & const1) const832 ConstvalNode *ConstantFold::FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
833 const ConstvalNode &const1) const
834 {
835 ConstvalNode *returnValue = nullptr;
836 if (IsPrimitiveInteger(resultType) || IsPrimitiveDynInteger(resultType)) {
837 returnValue = FoldIntConstBinary(opcode, resultType, const0, const1);
838 } else if (resultType == PTY_f32 || resultType == PTY_f64) {
839 returnValue = FoldFPConstBinary(opcode, resultType, const0, const1);
840 } else {
841 DEBUG_ASSERT(false, "Unhandled case for FoldConstBinary");
842 }
843 return returnValue;
844 }
845
// Folds a unary integer operation (abs, bnot, lnot, neg) on a constant into an integer constant.
//   opcode     - unary opcode to evaluate
//   resultType - type the operand is truncated/extended to before the operation, and the type
//                of the produced constant
//   constNode  - operand constant (must not be null)
// For PTY_dyni32 the value is OR-ed with kJsTypeNumberInHigh32Bit. For opcodes not folded
// here, DEBUG_ASSERT fires and the converted operand is returned unchanged.
MIRIntConst *ConstantFold::FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode)
{
    CHECK_NULL_FATAL(constNode);
    // Start from the operand converted to resultType; the cases below modify it in place.
    IntVal result = constNode->GetValue().TruncOrExtend(resultType);
    switch (opcode) {
        case OP_abs: {
            // Negate only if the operand's own type is signed and the converted value has its sign bit set.
            if (IsSignedInteger(constNode->GetType().GetPrimType()) && result.GetSignBit()) {
                result = -result;
            }
            break;
        }
        case OP_bnot: {
            result = ~result;
            break;
        }
        case OP_lnot: {
            // Logical not: 1 for a zero operand, 0 otherwise.
            uint64 resultInt = result == 0 ? 1 : 0;
            result = {resultInt, resultType};
            break;
        }
        case OP_neg: {
            result = -result;
            break;
        }
        case OP_sext:         // handled in FoldExtractbits
        case OP_zext:         // handled in FoldExtractbits
        case OP_extractbits:  // handled in FoldExtractbits
        case OP_recip:
        case OP_sqrt: {
            DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstUnaryMIRConst");
            break;
        }
        default:
            DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstUnaryMIRConst");
            break;
    }
    // determine the type
    MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
    // form the constant
    MIRIntConst *constValue = nullptr;
    if (type.GetPrimType() == PTY_dyni32) {
        // Dynamic int32: the value is stored with the number-type tag in its high 32 bits.
        constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
        constValue->SetValue(static_cast<int64>(kJsTypeNumberInHigh32Bit | static_cast<uint64>(result.GetExtValue())));
    } else {
        constValue =
            GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
    }
    return constValue;
}
895
// Folds a unary floating-point operation (recip, neg, abs, sqrt) on a constant node into a new
// ConstvalNode of resultType.
//   T          - MIRFloatConst or MIRDoubleConst, the class of constNode's value
//   opcode     - unary opcode to evaluate
//   resultType - must be PTY_f32 or PTY_f64; anything else hits CHECK_FATAL
//   constNode  - operand node (must not be null)
// Each result is first converted to T::value_type and then held in a double until it is stored.
// For opcodes not folded here, DEBUG_ASSERT fires and the value stays 0.
template <typename T>
ConstvalNode *ConstantFold::FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const
{
    CHECK_NULL_FATAL(constNode);
    double constValue = 0;
    T *fpCst = static_cast<T*>(constNode->GetConstVal());
    switch (opcode) {
        case OP_recip: {
            // Division is done in long double, then narrowed to the operand's value type.
            constValue = typename T::value_type(1.0L / fpCst->GetValue());
            break;
        }
        case OP_neg: {
            constValue = typename T::value_type(-fpCst->GetValue());
            break;
        }
        case OP_abs: {
            constValue = typename T::value_type(fabs(fpCst->GetValue()));
            break;
        }
        case OP_sqrt: {
            constValue = typename T::value_type(sqrt(fpCst->GetValue()));
            break;
        }
        case OP_bnot:
        case OP_lnot:
        case OP_sext:
        case OP_zext:
        case OP_extractbits: {
            DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstUnary");
            break;
        }
        default:
            DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstUnary");
            break;
    }
    auto *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(resultType);
    if (resultType == PTY_f32) {
        resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(static_cast<float>(constValue)));
    } else if (resultType == PTY_f64) {
        resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValue));
    } else {
        CHECK_FATAL(false, "PrimType for MIRFloatConst / MIRDoubleConst should be PTY_f32 / PTY_f64");
    }
    return resultConst;
}
942
FoldConstUnary(Opcode opcode,PrimType resultType,ConstvalNode & constNode) const943 ConstvalNode *ConstantFold::FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const
944 {
945 ConstvalNode *returnValue = nullptr;
946 if (IsPrimitiveInteger(resultType) || IsPrimitiveDynInteger(resultType)) {
947 const MIRIntConst *cst = safe_cast<MIRIntConst>(constNode.GetConstVal());
948 auto constValue = FoldIntConstUnaryMIRConst(opcode, resultType, cst);
949 returnValue = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
950 returnValue->SetPrimType(resultType);
951 returnValue->SetConstVal(constValue);
952 } else if (resultType == PTY_f32) {
953 returnValue = FoldFPConstUnary<MIRFloatConst>(opcode, resultType, &constNode);
954 } else if (resultType == PTY_f64) {
955 returnValue = FoldFPConstUnary<MIRDoubleConst>(opcode, resultType, &constNode);
956 } else if (resultType == PTY_f128) {
957 DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary");
958 } else {
959 DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary");
960 }
961 return returnValue;
962 }
963
FoldSizeoftype(SizeoftypeNode * node) const964 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldSizeoftype(SizeoftypeNode *node) const
965 {
966 CHECK_NULL_FATAL(node);
967 BaseNode *result = node;
968 MIRType *argType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx());
969 if (argType->GetKind() == kTypeScalar) {
970 MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(node->GetPrimType());
971 uint32 size = GetPrimTypeSize(argType->GetPrimType());
972 ConstvalNode *constValueNode = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
973 constValueNode->SetPrimType(node->GetPrimType());
974 constValueNode->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(size, resultType));
975 result = constValueNode;
976 }
977 return std::make_pair(result, std::nullopt);
978 }
979
FoldRetype(RetypeNode * node)980 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldRetype(RetypeNode *node)
981 {
982 CHECK_NULL_FATAL(node);
983 BaseNode *result = node;
984 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
985 if (node->Opnd(0) != p.first) {
986 RetypeNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
987 CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldRetype");
988 newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
989 result = newRetNode;
990 }
991 return std::make_pair(result, std::nullopt);
992 }
993
FoldGcmallocjarray(JarrayMallocNode * node)994 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldGcmallocjarray(JarrayMallocNode *node)
995 {
996 CHECK_NULL_FATAL(node);
997 BaseNode *result = node;
998 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
999 if (node->Opnd(0) != p.first) {
1000 JarrayMallocNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
1001 CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldGcmallocjarray");
1002 newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
1003 result = newRetNode;
1004 }
1005 return std::make_pair(result, std::nullopt);
1006 }
1007
FoldUnary(UnaryNode * node)1008 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldUnary(UnaryNode *node)
1009 {
1010 CHECK_NULL_FATAL(node);
1011 BaseNode *result = nullptr;
1012 std::optional<IntVal> sum = std::nullopt;
1013 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1014 ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1015 if (cst != nullptr) {
1016 result = FoldConstUnary(node->GetOpCode(), node->GetPrimType(), *cst);
1017 } else {
1018 bool isInt = IsPrimitiveInteger(node->GetPrimType());
1019 // The neg node will be recreated regardless of whether the folding is successful or not. And the neg node's
1020 // primType will be set to opnd type. There will be problems in some cases. For example:
1021 // before cf:
1022 // neg i32 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))
1023 // after cf:
1024 // neg u1 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f)) # wrong!
1025 // As a workaround, we exclude u1 opnd type
1026 if (isInt && node->GetOpCode() == OP_neg && p.first->GetPrimType() != PTY_u1) {
1027 result = NegateTree(p.first);
1028 if (result->GetOpCode() == OP_neg) {
1029 PrimType origPtyp = node->GetPrimType();
1030 PrimType newPtyp = result->GetPrimType();
1031 if (newPtyp == origPtyp) {
1032 if (static_cast<UnaryNode*>(result)->Opnd(0) == node->Opnd(0)) {
1033 // NegateTree returned an UnaryNode quivalent to `n`, so keep the
1034 // original UnaryNode to preserve identity
1035 result = node;
1036 }
1037 } else {
1038 if (GetPrimTypeSize(newPtyp) != GetPrimTypeSize(origPtyp)) {
1039 // do not fold explicit cvt
1040 result = NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(),
1041 PairToExpr(node->Opnd(0)->GetPrimType(), p));
1042 return std::make_pair(result, std::nullopt);
1043 } else {
1044 result->SetPrimType(origPtyp);
1045 }
1046 }
1047 }
1048 if (p.second) {
1049 sum = -(*p.second);
1050 }
1051 } else {
1052 result =
1053 NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(), PairToExpr(node->Opnd(0)->GetPrimType(), p));
1054 }
1055 }
1056 return std::make_pair(result, sum);
1057 }
1058
FloatToIntOverflow(float fval,PrimType totype)1059 static bool FloatToIntOverflow(float fval, PrimType totype)
1060 {
1061 static const float safeFloatMaxToInt32 = 2147483520.0f; // 2^31 - 128
1062 static const float safeFloatMinToInt32 = -2147483520.0f;
1063 static const float safeFloatMaxToInt64 = 9223372036854775680.0f; // 2^63 - 128
1064 static const float safeFloatMinToInt64 = -9223372036854775680.0f;
1065 if (!std::isfinite(fval)) {
1066 return true;
1067 }
1068 if (totype == PTY_i64 || totype == PTY_u64) {
1069 if (fval < safeFloatMinToInt64 || fval > safeFloatMaxToInt64) {
1070 return true;
1071 }
1072 } else {
1073 if (fval < safeFloatMinToInt32 || fval > safeFloatMaxToInt32) {
1074 return true;
1075 }
1076 }
1077 return false;
1078 }
1079
DoubleToIntOverflow(double dval,PrimType totype)1080 static bool DoubleToIntOverflow(double dval, PrimType totype)
1081 {
1082 static const double safeDoubleMaxToInt32 = 2147482624.0; // 2^31 - 1024
1083 static const double safeDoubleMinToInt32 = -2147482624.0;
1084 static const double safeDoubleMaxToInt64 = 9223372036854774784.0; // 2^63 - 1024
1085 static const double safeDoubleMinToInt64 = -9223372036854774784.0;
1086 if (!std::isfinite(dval)) {
1087 return true;
1088 }
1089 if (totype == PTY_i64 || totype == PTY_u64) {
1090 if (dval < safeDoubleMinToInt64 || dval > safeDoubleMaxToInt64) {
1091 return true;
1092 }
1093 } else {
1094 if (dval < safeDoubleMinToInt32 || dval > safeDoubleMaxToInt32) {
1095 return true;
1096 }
1097 }
1098 return false;
1099 }
1100
// Folds OP_ceil applied to the floating-point constant cst.
// A PTY_f32 source is read as MIRFloatConst; any other source is read as MIRDoubleConst.
// For a floating toType, the rounded value is stored in the source precision. For an integer
// toType, nullptr is returned when the rounded value fails the overflow check, so the caller can
// keep the unfolded node.
ConstvalNode *ConstantFold::FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(constValue);
        float floatValue = ceil(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
        } else if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        } else {
            // Converting a negative floating value directly to uint64 is undefined behavior.
            // Negative values go through int64 instead; the overflow check above keeps them in range.
            uint64 intValue = floatValue < 0 ? static_cast<uint64>(static_cast<int64>(floatValue))
                                             : static_cast<uint64>(floatValue);
            resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(intValue, resultType));
        }
    } else {
        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(constValue);
        double doubleValue = ceil(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
        } else if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        } else {
            // Same sign-aware conversion as the f32 path (negative double -> uint64 is undefined).
            uint64 intValue = doubleValue < 0 ? static_cast<uint64>(static_cast<int64>(doubleValue))
                                              : static_cast<uint64>(doubleValue);
            resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(intValue, resultType));
        }
    }
    return resultConst;
}
1133
1134 template <class T>
CalIntValueFromFloatValue(T value,const MIRType & resultType) const1135 T ConstantFold::CalIntValueFromFloatValue(T value, const MIRType &resultType) const
1136 {
1137 DEBUG_ASSERT(kByteSizeOfBit64 >= resultType.GetSize(), "unsupported type");
1138 size_t shiftNum = (kByteSizeOfBit64 - resultType.GetSize()) * kBitSizePerByte;
1139 bool isSigned = IsSignedInteger(resultType.GetPrimType());
1140 int64 max = (IntVal(std::numeric_limits<int64>::max(), PTY_i64) >> shiftNum).GetExtValue();
1141 uint64 umax = std::numeric_limits<uint64>::max() >> shiftNum;
1142 int64 min = isSigned ? (IntVal(std::numeric_limits<int64>::min(), PTY_i64) >> shiftNum).GetExtValue() : 0;
1143 if (isSigned && (value > max)) {
1144 return static_cast<T>(max);
1145 } else if (!isSigned && (value > umax)) {
1146 return static_cast<T>(umax);
1147 } else if (value < min) {
1148 return static_cast<T>(min);
1149 }
1150 return value;
1151 }
1152
// Converts the floating-point constant cst to toType. When isFloor is set, the value is first
// rounded toward negative infinity; otherwise the integer conversion truncates toward zero.
// A PTY_f32 source is read as MIRFloatConst; anything else is read as MIRDoubleConst.
// Floating destinations keep the source precision. Integer destinations return nullptr when the
// value fails the overflow check. Otherwise the value is clamped to the destination's range by
// CalIntValueFromFloatValue before being converted.
MIRConst *ConstantFold::FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor) const
{
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const auto &constValue = static_cast<const MIRFloatConst&>(cst);
        float floatValue = constValue.GetValue();
        if (isFloor) {
            floatValue = floor(constValue.GetValue());
        }
        if (IsPrimitiveFloat(toType)) {
            return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
        }
        if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        }
        floatValue = CalIntValueFromFloatValue(floatValue, resultType);
        // Signed destinations can yield negative values, and converting those directly to uint64
        // is undefined behavior. Route negative values through int64 (kept in range by the checks above).
        uint64 intValue = floatValue < 0 ? static_cast<uint64>(static_cast<int64>(floatValue))
                                         : static_cast<uint64>(floatValue);
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(intValue, resultType);
    } else {
        const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
        double doubleValue = constValue.GetValue();
        if (isFloor) {
            doubleValue = floor(constValue.GetValue());
        }
        if (IsPrimitiveFloat(toType)) {
            return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
        }
        if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        }
        doubleValue = CalIntValueFromFloatValue(doubleValue, resultType);
        // gcc/clang have bugs convert double to unsigned long, must convert to signed long first;
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(doubleValue), resultType);
    }
}
1187
// Folds OP_floor over the constant cst into a new ConstvalNode of type toType.
// Returns nullptr when FoldFloorMIRConst declines (e.g. the value fails the overflow check).
// Callers test the result against nullptr, so returning a node without a constant would defeat
// their fallback to the unfolded node.
ConstvalNode *ConstantFold::FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *folded = FoldFloorMIRConst(*cst.GetConstVal(), fromType, toType);
    if (folded == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(folded);
    return resultConst;
}
1195
// Folds a rounding conversion of the constant cst; it also serves integer-to-float conversions.
//  - f32/f64 source: the value is rounded with round() and returned as an integer constant of
//    toType, or nullptr if it fails the overflow check.
//  - integer source with an f32/f64 destination: the floating constant is returned only when the
//    conversion is exact, i.e. converting it back yields the original integer; otherwise nullptr.
//  - f128 on either side, and every other combination: nullptr (not folded).
MIRConst *ConstantFold::FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
    if (fromType == PTY_f128 || toType == PTY_f128) {
        // folding while rounding float128 is not supported yet
        return nullptr;
    }
    // Exclusive upper bounds of int64 (2^63) and uint64 (2^64). Both are powers of two, so they are
    // exact in float and double. An integer near the type maximum can round up to such a bound
    // (e.g. float(INT64_MAX) == 2^63), and converting that back to the integer type is undefined
    // behavior. The exactness checks below are therefore guarded by these bounds.
    constexpr double kInt64UpperBound = 9223372036854775808.0;
    constexpr double kUint64UpperBound = 18446744073709551616.0;

    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const auto &constValue = static_cast<const MIRFloatConst&>(cst);
        float floatValue = round(constValue.GetValue());
        if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        }
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(floatValue), resultType);
    } else if (fromType == PTY_f64) {
        const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
        double doubleValue = round(constValue.GetValue());
        if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        }
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(
            static_cast<uint64>(static_cast<int64>(doubleValue)), resultType);
    } else if (toType == PTY_f32 && IsPrimitiveInteger(fromType)) {
        const auto &constValue = static_cast<const MIRIntConst&>(cst);
        if (IsSignedInteger(fromType)) {
            int64 fromValue = constValue.GetExtValue();
            float floatValue = round(static_cast<float>(fromValue));
            if (floatValue < kInt64UpperBound && static_cast<int64>(floatValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
            }
        } else {
            uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
            float floatValue = round(static_cast<float>(fromValue));
            if (floatValue < kUint64UpperBound && static_cast<uint64>(floatValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
            }
        }
    } else if (toType == PTY_f64 && IsPrimitiveInteger(fromType)) {
        const auto &constValue = static_cast<const MIRIntConst&>(cst);
        if (IsSignedInteger(fromType)) {
            int64 fromValue = constValue.GetExtValue();
            double doubleValue = round(static_cast<double>(fromValue));
            if (doubleValue < kInt64UpperBound && static_cast<int64>(doubleValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
            }
        } else {
            uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
            double doubleValue = round(static_cast<double>(fromValue));
            if (doubleValue < kUint64UpperBound && static_cast<uint64>(doubleValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
            }
        }
    }
    return nullptr;
}
1252
// Folds OP_round over the constant cst into a new ConstvalNode of type toType.
// Returns nullptr when FoldRoundMIRConst declines (f128, overflow, inexact integer-to-float, or
// unsupported type pairs). Callers test the result against nullptr, so returning a node without
// a constant would defeat their fallback to the unfolded node.
ConstvalNode *ConstantFold::FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *folded = FoldRoundMIRConst(*cst.GetConstVal(), fromType, toType);
    if (folded == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(folded);
    return resultConst;
}
1260
// Folds OP_trunc (round toward zero) applied to the floating-point constant cst.
// A PTY_f32 source is read as MIRFloatConst; any other source is read as MIRDoubleConst.
// For a floating toType, the truncated value is stored in the source precision. For an integer
// toType, nullptr is returned when the value fails the overflow check, so the caller can keep the
// unfolded node.
ConstvalNode *ConstantFold::FoldTrunc(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(constValue);
        float floatValue = trunc(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
        } else if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        } else {
            // Converting a negative floating value directly to uint64 is undefined behavior.
            // Negative values go through int64 instead; the overflow check above keeps them in range.
            uint64 intValue = floatValue < 0 ? static_cast<uint64>(static_cast<int64>(floatValue))
                                             : static_cast<uint64>(floatValue);
            resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(intValue, resultType));
        }
    } else {
        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(constValue);
        double doubleValue = trunc(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
        } else if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        } else {
            // Same sign-aware conversion as the f32 path (negative double -> uint64 is undefined).
            uint64 intValue = doubleValue < 0 ? static_cast<uint64>(static_cast<int64>(doubleValue))
                                              : static_cast<uint64>(doubleValue);
            resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(intValue, resultType));
        }
    }
    return resultConst;
}
1293
// Folds the conversion of constant cst from fromType to toType. Returns the new constant, or
// nullptr when the conversion is not folded: dynamic, vector or f128 types, or when the delegated
// float<->integer folders decline.
//  - integer -> integer: widening sign- or zero-extends according to fromType's signedness;
//    otherwise the value is re-created as a constant of toType.
//  - float -> float: f64 -> f32 narrowing or f32 -> f64 widening.
//  - float -> integer: FoldFloorMIRConst without flooring.
//  - integer -> float: FoldRoundMIRConst.
MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
    const bool dynOrVector = IsPrimitiveDynType(fromType) || IsPrimitiveDynType(toType) ||
                             IsPrimitiveVector(fromType) || IsPrimitiveVector(toType);
    if (dynOrVector) {
        // do not fold
        return nullptr;
    }
    if (fromType == PTY_f128 || toType == PTY_f128) {
        // folding while Cvt float128 is not supported yet
        return nullptr;
    }

    const bool fromInt = IsPrimitiveInteger(fromType);
    const bool toInt = IsPrimitiveInteger(toType);
    if (fromInt && toInt) {
        // GetPrimTypeBitSize(PTY_u1) reports 8 bits; a u1 value really occupies a single bit.
        const uint32 fromBits = (fromType == PTY_u1) ? 1 : GetPrimTypeBitSize(fromType);
        const uint32 toBits = (toType == PTY_u1) ? 1 : GetPrimTypeBitSize(toType);
        const MIRIntConst *intConst = safe_cast<MIRIntConst>(cst);
        ASSERT_NOT_NULL(intConst);
        if (toBits > fromBits) {
            const Opcode extOp = IsSignedInteger(fromType) ? OP_sext : OP_zext;
            return FoldSignExtendMIRConst(extOp, toType, static_cast<uint8>(fromBits),
                                          intConst->GetValue().TruncOrExtend(fromType));
        }
        MIRType &toMirType = *GlobalTables::GetTypeTable().GetPrimType(toType);
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(
            static_cast<uint64>(intConst->GetExtValue()), toMirType);
    }

    const bool fromFloat = IsPrimitiveFloat(fromType);
    const bool toFloat = IsPrimitiveFloat(toType);
    if (fromFloat && toFloat) {
        if (GetPrimTypeBitSize(toType) >= GetPrimTypeBitSize(fromType)) {
            DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 64, "We suppot F32 and F64"); // just support 32 or 64
            const MIRFloatConst *floatConst = safe_cast<MIRFloatConst>(cst);
            ASSERT_NOT_NULL(floatConst);
            return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(
                static_cast<double>(floatConst->GetValue()));
        }
        DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 32, "We suppot F32 and F64"); // just support 32 or 64
        const MIRDoubleConst *doubleConst = safe_cast<MIRDoubleConst>(cst);
        ASSERT_NOT_NULL(doubleConst);
        return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(static_cast<float>(doubleConst->GetValue()));
    }

    if (fromFloat && toInt) {
        return FoldFloorMIRConst(cst, fromType, toType, false);
    }
    if (fromInt && toFloat) {
        return FoldRoundMIRConst(cst, fromType, toType);
    }
    CHECK_FATAL(false, "Unexpected case in ConstFoldTypeCvt");
    return nullptr;
}
1363
// Folds a cvt over the constant cst into a new ConstvalNode. Returns nullptr when
// FoldTypeCvtMIRConst declines. The node's primitive type is taken from the created constant.
ConstvalNode *ConstantFold::FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *converted = FoldTypeCvtMIRConst(*cst.GetConstVal(), fromType, toType);
    if (converted == nullptr) {
        return nullptr;
    }
    auto *cvtNode = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    PrimType convertedPtyp = converted->GetType().GetPrimType();
    cvtNode->SetPrimType(convertedPtyp);
    cvtNode->SetConstVal(converted);
    return cvtNode;
}
1375
1376 // return a primType with bit size >= bitSize (and the nearest one),
1377 // and its signed/float type is the same as ptyp
GetNearestSizePtyp(uint8 bitSize,PrimType ptyp)1378 PrimType GetNearestSizePtyp(uint8 bitSize, PrimType ptyp)
1379 {
1380 bool isSigned = IsSignedInteger(ptyp);
1381 bool isFloat = IsPrimitiveFloat(ptyp);
1382 if (bitSize == 1) { // 1 bit
1383 return PTY_u1;
1384 }
1385 if (bitSize <= 8) { // 8 bit
1386 return isSigned ? PTY_i8 : PTY_u8;
1387 }
1388 if (bitSize <= 16) { // 16 bit
1389 return isSigned ? PTY_i16 : PTY_u16;
1390 }
1391 if (bitSize <= 32) { // 32 bit
1392 return isFloat ? PTY_f32 : (isSigned ? PTY_i32 : PTY_u32);
1393 }
1394 if (bitSize <= 64) { // 64 bit
1395 return isFloat ? PTY_f64 : (isSigned ? PTY_i64 : PTY_u64);
1396 }
1397 if (bitSize <= 128) { // 128 bit
1398 return isFloat ? PTY_f128 : (isSigned ? PTY_i128 : PTY_u128);
1399 }
1400 return ptyp;
1401 }
1402
GetIntPrimTypeMax(PrimType ptyp)1403 size_t GetIntPrimTypeMax(PrimType ptyp)
1404 {
1405 switch (ptyp) {
1406 case PTY_u1:
1407 return 1;
1408 case PTY_u8:
1409 return UINT8_MAX;
1410 case PTY_i8:
1411 return INT8_MAX;
1412 case PTY_u16:
1413 return UINT16_MAX;
1414 case PTY_i16:
1415 return INT16_MAX;
1416 case PTY_u32:
1417 return UINT32_MAX;
1418 case PTY_i32:
1419 return INT32_MAX;
1420 case PTY_u64:
1421 return UINT64_MAX;
1422 case PTY_i64:
1423 return INT64_MAX;
1424 default:
1425 CHECK_FATAL(false, "NYI");
1426 }
1427 }
1428
GetIntPrimTypeMin(PrimType ptyp)1429 ssize_t GetIntPrimTypeMin(PrimType ptyp)
1430 {
1431 if (IsUnsignedInteger(ptyp)) {
1432 return 0;
1433 }
1434 switch (ptyp) {
1435 case PTY_i8:
1436 return INT8_MIN;
1437 case PTY_i16:
1438 return INT16_MIN;
1439 case PTY_i32:
1440 return INT32_MIN;
1441 case PTY_i64:
1442 return INT64_MIN;
1443 default:
1444 CHECK_FATAL(false, "NYI");
1445 }
1446 }
1447
1448 // return a primtype to represent value range of expr
GetExprValueRangePtyp(BaseNode * expr)1449 PrimType GetExprValueRangePtyp(BaseNode *expr)
1450 {
1451 PrimType ptyp = expr->GetPrimType();
1452 Opcode op = expr->GetOpCode();
1453 if (expr->IsLeaf()) {
1454 return ptyp;
1455 }
1456 if (kOpcodeInfo.IsTypeCvt(op)) {
1457 auto *node = static_cast<TypeCvtNode *>(expr);
1458 if (GetPrimTypeSize(node->FromType()) < GetPrimTypeSize(node->GetPrimType())) {
1459 return GetExprValueRangePtyp(expr->Opnd(0));
1460 }
1461 return ptyp;
1462 }
1463 if (op == OP_sext || op == OP_zext || op == OP_extractbits) {
1464 auto *node = static_cast<ExtractbitsNode *>(expr);
1465 uint8 size = node->GetBitsSize();
1466 return GetNearestSizePtyp(size, expr->GetPrimType());
1467 }
1468 // find max size primtype of opnds.
1469 size_t maxTypeSize = 1;
1470 size_t ptypSize = GetPrimTypeSize(ptyp);
1471 for (size_t i = 0; i < expr->GetNumOpnds(); ++i) {
1472 PrimType opndPtyp = GetExprValueRangePtyp(expr->Opnd(i));
1473 size_t opndSize = GetPrimTypeSize(opndPtyp);
1474 if (ptypSize <= opndSize) {
1475 return ptyp;
1476 }
1477 if (maxTypeSize < opndSize) {
1478 maxTypeSize = opndSize;
1479 constexpr size_t intMaxSize = 8;
1480 if (maxTypeSize == intMaxSize) {
1481 break;
1482 }
1483 }
1484 }
1485 return GetNearestSizePtyp(static_cast<uint8>(maxTypeSize), ptyp);
1486 }
1487
IsCvtEliminatable(PrimType fromPtyp,PrimType destPtyp,Opcode op,Opcode opndOp)1488 static bool IsCvtEliminatable(PrimType fromPtyp, PrimType destPtyp, Opcode op, Opcode opndOp)
1489 {
1490 if (op != OP_cvt || (opndOp == OP_zext || opndOp == OP_sext)) {
1491 return false;
1492 }
1493 if (GetPrimTypeSize(fromPtyp) != GetPrimTypeSize(destPtyp)) {
1494 return false;
1495 }
1496 return (IsPossible64BitAddress(fromPtyp) && IsPossible64BitAddress(destPtyp)) ||
1497 (IsPossible32BitAddress(fromPtyp) && IsPossible32BitAddress(destPtyp)) ||
1498 (IsPrimitivePureScalar(fromPtyp) && IsPrimitivePureScalar(destPtyp));
1499 }
1500
FoldTypeCvt(TypeCvtNode * node)1501 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldTypeCvt(TypeCvtNode *node)
1502 {
1503 CHECK_NULL_FATAL(node);
1504 BaseNode *result = nullptr;
1505 if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
1506 return {node, std::nullopt};
1507 }
1508 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1509 ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1510 PrimType destPtyp = node->GetPrimType();
1511 PrimType fromPtyp = node->FromType();
1512 if (cst != nullptr) {
1513 switch (node->GetOpCode()) {
1514 case OP_ceil: {
1515 result = FoldCeil(*cst, fromPtyp, destPtyp);
1516 break;
1517 }
1518 case OP_cvt: {
1519 result = FoldTypeCvt(*cst, fromPtyp, destPtyp);
1520 break;
1521 }
1522 case OP_floor: {
1523 result = FoldFloor(*cst, fromPtyp, destPtyp);
1524 break;
1525 }
1526 case OP_round: {
1527 result = FoldRound(*cst, fromPtyp, destPtyp);
1528 break;
1529 }
1530 case OP_trunc: {
1531 result = FoldTrunc(*cst, fromPtyp, destPtyp);
1532 break;
1533 }
1534 default:
1535 DEBUG_ASSERT(false, "Unexpected opcode in TypeCvtNodeConstFold");
1536 break;
1537 }
1538 } else if (IsCvtEliminatable(fromPtyp, destPtyp, node->GetOpCode(), p.first->GetOpCode())) {
1539 // the cvt is redundant
1540 return std::make_pair(p.first, p.second ? IntVal(*p.second, node->GetPrimType()) : p.second);
1541 }
1542 if (result == nullptr) {
1543 BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1544 if (e != node->Opnd(0)) {
1545 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(
1546 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->FromType()), e);
1547 } else {
1548 result = node;
1549 }
1550 }
1551 return std::make_pair(result, std::nullopt);
1552 }
1553
FoldSignExtendMIRConst(Opcode opcode,PrimType resultType,uint8 size,const IntVal & val) const1554 MIRConst *ConstantFold::FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const
1555 {
1556 uint64 result = opcode == OP_sext ? static_cast<uint64>(val.GetSXTValue(size)) : val.GetZXTValue(size);
1557 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
1558 MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
1559 return constValue;
1560 }
1561
FoldSignExtend(Opcode opcode,PrimType resultType,uint8 size,const ConstvalNode & cst) const1562 ConstvalNode *ConstantFold::FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size,
1563 const ConstvalNode &cst) const
1564 {
1565 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1566 const auto *intCst = safe_cast<MIRIntConst>(cst.GetConstVal());
1567 ASSERT_NOT_NULL(intCst);
1568 IntVal val = intCst->GetValue().TruncOrExtend(size, opcode == OP_sext);
1569 MIRConst *toConst = FoldSignExtendMIRConst(opcode, resultType, size, val);
1570 resultConst->SetPrimType(toConst->GetType().GetPrimType());
1571 resultConst->SetConstVal(toConst);
1572 return resultConst;
1573 }
1574
1575 // check if truncation is redundant due to dread or iread having same effect
ExtractbitsRedundant(const ExtractbitsNode & x,MIRFunction & f)1576 static bool ExtractbitsRedundant(const ExtractbitsNode &x, MIRFunction &f)
1577 {
1578 if (GetPrimTypeSize(x.GetPrimType()) == k8ByteSize) {
1579 return false; // this is trying to be conservative
1580 }
1581 BaseNode *opnd = x.Opnd(0);
1582 MIRType *mirType = nullptr;
1583 if (opnd->GetOpCode() == OP_dread) {
1584 DreadNode *dread = static_cast<DreadNode*>(opnd);
1585 MIRSymbol *sym = f.GetLocalOrGlobalSymbol(dread->GetStIdx());
1586 ASSERT_NOT_NULL(sym);
1587 mirType = sym->GetType();
1588 if (dread->GetFieldID() != 0) {
1589 MIRStructType *structType = dynamic_cast<MIRStructType*>(mirType);
1590 if (structType == nullptr) {
1591 return false;
1592 }
1593 mirType = structType->GetFieldType(dread->GetFieldID());
1594 }
1595 } else if (opnd->GetOpCode() == OP_iread) {
1596 IreadNode *iread = static_cast<IreadNode*>(opnd);
1597 MIRPtrType *ptrType =
1598 dynamic_cast<MIRPtrType*>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx()));
1599 if (ptrType == nullptr) {
1600 return false;
1601 }
1602 mirType = ptrType->GetPointedType();
1603 if (iread->GetFieldID() != 0) {
1604 MIRStructType *structType = dynamic_cast<MIRStructType*>(mirType);
1605 if (structType == nullptr) {
1606 return false;
1607 }
1608 mirType = structType->GetFieldType(iread->GetFieldID());
1609 }
1610 } else if (opnd->GetOpCode() == OP_extractbits &&
1611 x.GetBitsSize() > static_cast<ExtractbitsNode*>(opnd)->GetBitsSize()) {
1612 return (x.GetOpCode() == OP_zext && x.GetPrimType() == opnd->GetPrimType() &&
1613 IsUnsignedInteger(opnd->GetPrimType()));
1614 } else {
1615 return false;
1616 }
1617 return IsPrimitiveInteger(mirType->GetPrimType()) &&
1618 ((x.GetOpCode() == OP_zext && IsUnsignedInteger(opnd->GetPrimType())) ||
1619 (x.GetOpCode() == OP_sext && IsSignedInteger(opnd->GetPrimType()))) &&
1620 mirType->GetSize() * kBitSizePerByte == x.GetBitsSize() &&
1621 mirType->GetPrimType() == x.GetPrimType();
1622 }
1623
1624 // sext and zext also handled automatically
FoldExtractbits(ExtractbitsNode * node)1625 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldExtractbits(ExtractbitsNode *node)
1626 {
1627 CHECK_NULL_FATAL(node);
1628 BaseNode *result = nullptr;
1629 uint8 offset = node->GetBitsOffset();
1630 uint8 size = node->GetBitsSize();
1631 Opcode opcode = node->GetOpCode();
1632 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1633 ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1634 if (cst != nullptr && (opcode == OP_sext || opcode == OP_zext)) {
1635 result = FoldSignExtend(opcode, node->GetPrimType(), size, *cst);
1636 return std::make_pair(result, std::nullopt);
1637 }
1638 BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1639 if (e != node->Opnd(0)) {
1640 result = mirModule->CurFuncCodeMemPool()->New<ExtractbitsNode>(opcode, PrimType(node->GetPrimType()), offset,
1641 size, e);
1642 } else {
1643 result = node;
1644 }
1645 // check for consecutive and redundant extraction of same bits
1646 BaseNode *opnd = result->Opnd(0);
1647 DEBUG_ASSERT(opnd != nullptr, "opnd shoule not be null");
1648 Opcode opndOp = opnd->GetOpCode();
1649 if (opndOp == OP_extractbits || opndOp == OP_sext || opndOp == OP_zext) {
1650 uint8 opndOffset = static_cast<ExtractbitsNode*>(opnd)->GetBitsOffset();
1651 uint8 opndSize = static_cast<ExtractbitsNode*>(opnd)->GetBitsSize();
1652 if (offset == opndOffset && size == opndSize) {
1653 result->SetOpnd(opnd->Opnd(0), 0); // delete the redundant extraction
1654 }
1655 }
1656 if (offset == 0 && size >= k8ByteSize && IsPowerOf2(size)) {
1657 if (ExtractbitsRedundant(*static_cast<ExtractbitsNode*>(result), *mirModule->CurFunction())) {
1658 return std::make_pair(result->Opnd(0), std::nullopt);
1659 }
1660 }
1661 return std::make_pair(result, std::nullopt);
1662 }
1663
FoldIread(IreadNode * node)1664 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldIread(IreadNode *node)
1665 {
1666 CHECK_NULL_FATAL(node);
1667 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1668 BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1669 node->SetOpnd(e, 0);
1670 BaseNode *result = node;
1671 if (e->GetOpCode() != OP_addrof) {
1672 return std::make_pair(result, std::nullopt);
1673 }
1674
1675 AddrofNode *addrofNode = static_cast<AddrofNode*>(e);
1676 MIRSymbol *msy = mirModule->CurFunction()->GetLocalOrGlobalSymbol(addrofNode->GetStIdx());
1677 DEBUG_ASSERT(msy != nullptr, "nullptr check");
1678 TyIdx typeId = msy->GetTyIdx();
1679 CHECK_FATAL(!GlobalTables::GetTypeTable().GetTypeTable().empty(), "container check");
1680 MIRType *msyType = GlobalTables::GetTypeTable().GetTypeTable()[typeId];
1681 if (addrofNode->GetFieldID() != 0) {
1682 CHECK_FATAL(msyType->IsStructType(), "must be");
1683 msyType = static_cast<MIRStructType*>(msyType)->GetFieldType(addrofNode->GetFieldID());
1684 }
1685 MIRPtrType *ptrType = static_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx()));
1686 // If the high level type of iaddrof/iread doesn't match
1687 // the type of addrof's rhs, this optimization cannot be done.
1688 if (ptrType->GetPointedType() != msyType) {
1689 return std::make_pair(result, std::nullopt);
1690 }
1691
1692 Opcode op = node->GetOpCode();
1693 FieldID fieldID = node->GetFieldID();
1694 if (op == OP_iaddrof) {
1695 AddrofNode *newAddrof = addrofNode->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
1696 CHECK_NULL_FATAL(newAddrof);
1697 newAddrof->SetFieldID(newAddrof->GetFieldID() + fieldID);
1698 result = newAddrof;
1699 } else if (op == OP_iread) {
1700 result = mirModule->CurFuncCodeMemPool()->New<AddrofNode>(OP_dread, node->GetPrimType(), addrofNode->GetStIdx(),
1701 node->GetFieldID() + addrofNode->GetFieldID());
1702 }
1703 return std::make_pair(result, std::nullopt);
1704 }
1705
IntegerOpIsOverflow(Opcode op,PrimType primType,int64 cstA,int64 cstB)1706 bool ConstantFold::IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB)
1707 {
1708 switch (op) {
1709 case OP_add: {
1710 int64 res = static_cast<int64>(static_cast<uint64>(cstA) + static_cast<uint64>(cstB));
1711 if (IsUnsignedInteger(primType)) {
1712 return static_cast<uint64>(res) < static_cast<uint64>(cstA);
1713 }
1714 auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1715 return (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1716 static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag) &&
1717 (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1718 static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag);
1719 }
1720 case OP_sub: {
1721 if (IsUnsignedInteger(primType)) {
1722 return cstA < cstB;
1723 }
1724 int64 res = static_cast<int64>(static_cast<uint64>(cstA) - static_cast<uint64>(cstB));
1725 auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1726 return (static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag !=
1727 static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag) &&
1728 (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1729 static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag);
1730 }
1731 default: {
1732 return false;
1733 }
1734 }
1735 }
1736
FoldBinary(BinaryNode * node)1737 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldBinary(BinaryNode *node)
1738 {
1739 CHECK_NULL_FATAL(node);
1740 BaseNode *result = nullptr;
1741 std::optional<IntVal> sum = std::nullopt;
1742 Opcode op = node->GetOpCode();
1743 PrimType primType = node->GetPrimType();
1744 PrimType lPrimTypes = node->Opnd(0)->GetPrimType();
1745 PrimType rPrimTypes = node->Opnd(1)->GetPrimType();
1746 if (lPrimTypes == PTY_f128 || rPrimTypes == PTY_f128 || node->GetPrimType() == PTY_f128) {
1747 // folding of non-unary float128 is not supported yet
1748 return std::make_pair(static_cast<BaseNode*>(node), std::nullopt);
1749 }
1750 std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1751 std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1752 BaseNode *l = lp.first;
1753 BaseNode *r = rp.first;
1754 ASSERT_NOT_NULL(r);
1755 ConstvalNode *lConst = safe_cast<ConstvalNode>(l);
1756 ConstvalNode *rConst = safe_cast<ConstvalNode>(r);
1757 bool isInt = IsPrimitiveInteger(primType);
1758
1759 if (lConst != nullptr && rConst != nullptr) {
1760 MIRConst *lConstVal = lConst->GetConstVal();
1761 MIRConst *rConstVal = rConst->GetConstVal();
1762 ASSERT_NOT_NULL(lConstVal);
1763 ASSERT_NOT_NULL(rConstVal);
1764 // Don't fold div by 0, for floats div by 0 is well defined.
1765 if ((op == OP_div || op == OP_rem) && isInt &&
1766 !IsDivSafe(static_cast<MIRIntConst &>(*lConstVal), static_cast<MIRIntConst &>(*rConstVal), primType)) {
1767 result = NewBinaryNode(node, op, primType, lConst, rConst);
1768 } else {
1769 // 4 + 2 -> return a pair(result = ConstValNode(6), sum = 0)
1770 // Create a new ConstvalNode for 6 but keep the sum = 0. This simplify the
1771 // logic since the alternative is to return pair(result = nullptr, sum = 6).
1772 // Doing so would introduce many nullptr checks in the code. See previous
1773 // commits that implemented that logic for a comparison.
1774 result = FoldConstBinary(op, primType, *lConst, *rConst);
1775 }
1776 } else if (lConst != nullptr && isInt) {
1777 MIRIntConst *mcst = safe_cast<MIRIntConst>(lConst->GetConstVal());
1778 ASSERT_NOT_NULL(mcst);
1779 PrimType cstTyp = mcst->GetType().GetPrimType();
1780 IntVal cst = mcst->GetValue();
1781 if (op == OP_add) {
1782 if (IsSignedInteger(cstTyp) && rp.second &&
1783 IntegerOpIsOverflow(OP_add, cstTyp, cst.GetExtValue(), rp.second->GetExtValue())) {
1784 // do not introduce signed integer overflow
1785 result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1786 } else {
1787 sum = cst + rp.second;
1788 result = r;
1789 }
1790 } else if (op == OP_sub && r->GetPrimType() != PTY_u1) {
1791 // We exclude u1 type for fixing the following wrong example:
1792 // before cf:
1793 // sub i32 (constval i32 17, eq u1 i32 (dread i32 %i, constval i32 16)))
1794 // after cf:
1795 // add i32 (cvt i32 u1 (neg u1 (eq u1 i32 (dread i32 %i, constval i32 16))), constval i32 17))
1796 sum = cst - rp.second;
1797 if (GetPrimTypeSize(r->GetPrimType()) < GetPrimTypeSize(primType)) {
1798 r = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, r->GetPrimType(), r);
1799 }
1800 result = NegateTree(r);
1801 } else if ((op == OP_mul || op == OP_div || op == OP_rem || op == OP_ashr || op == OP_lshr || op == OP_shl ||
1802 op == OP_band || op == OP_cand || op == OP_land) &&
1803 cst == 0) {
1804 // 0 * X -> 0
1805 // 0 / X -> 0
1806 // 0 % X -> 0
1807 // 0 >> X -> 0
1808 // 0 << X -> 0
1809 // 0 & X -> 0
1810 // 0 && X -> 0
1811 result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1812 } else if (op == OP_mul && cst == 1) {
1813 // 1 * X --> X
1814 sum = rp.second;
1815 result = r;
1816 } else if (op == OP_bior && cst == -1) {
1817 // (-1) | X -> -1
1818 result = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<uint64>(-1), cstTyp);
1819 } else if (op == OP_mul && rp.second.has_value() && *rp.second != 0) {
1820 // lConst * (X + konst) -> the pair [(lConst*X), (lConst*konst)]
1821 sum = cst * rp.second;
1822 if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) {
1823 rp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, rp.first);
1824 }
1825 result = NewBinaryNode(node, OP_mul, primType, lConst, rp.first);
1826 } else if (op == OP_lior || op == OP_cior) {
1827 if (cst != 0) {
1828 // 5 || X -> 1
1829 result = mirModule->GetMIRBuilder()->CreateIntConst(1, cstTyp);
1830 } else {
1831 // when cst is zero
1832 // 0 || X -> (X != 0);
1833 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1834 OP_ne, primType, r->GetPrimType(), r,
1835 mirModule->GetMIRBuilder()->CreateIntConst(0, r->GetPrimType()));
1836 }
1837 } else if ((op == OP_cand || op == OP_land) && cst != 0) {
1838 // 5 && X -> (X != 0)
1839 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1840 OP_ne, primType, r->GetPrimType(), r, mirModule->GetMIRBuilder()->CreateIntConst(0, r->GetPrimType()));
1841 } else if ((op == OP_bior || op == OP_bxor) && cst == 0) {
1842 // 0 | X -> X
1843 // 0 ^ X -> X
1844 sum = rp.second;
1845 result = r;
1846 } else {
1847 result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1848 }
1849 if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1850 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1851 }
1852 } else if (rConst != nullptr && isInt) {
1853 MIRIntConst *mcst = safe_cast<MIRIntConst>(rConst->GetConstVal());
1854 ASSERT_NOT_NULL(mcst);
1855 PrimType cstTyp = mcst->GetType().GetPrimType();
1856 IntVal cst = mcst->GetValue();
1857 if (op == OP_add) {
1858 if (lp.second && IntegerOpIsOverflow(op, cstTyp, lp.second->GetExtValue(), cst.GetExtValue())) {
1859 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1860 } else {
1861 result = l;
1862 sum = lp.second + cst;
1863 }
1864 } else if (op == OP_sub && (!cst.IsSigned() || !cst.IsMinValue())) {
1865 result = l;
1866 sum = lp.second - cst;
1867 } else if ((op == OP_mul || op == OP_band || op == OP_cand || op == OP_land) && cst == 0) {
1868 // X * 0 -> 0
1869 // X & 0 -> 0
1870 // X && 0 -> 0
1871 result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1872 } else if ((op == OP_mul || op == OP_div) && cst == 1) {
1873 // case [X * 1 -> X]
1874 // case [X / 1 = X]
1875 sum = lp.second;
1876 result = l;
1877 } else if (op == OP_div && !lp.second.has_value() && l->GetOpCode() == OP_mul &&
1878 IsSignedInteger(primType) && IsSignedInteger(lPrimTypes) && IsSignedInteger(rPrimTypes)) {
1879 // temporary fix for constfold of mul/div in DejaGnu
1880 // Later we need a more formal interface for pattern match
1881 // X * Y / Y -> X
1882 BaseNode *x = l->Opnd(0);
1883 BaseNode *y = l->Opnd(1);
1884 ConstvalNode *xConst = safe_cast<ConstvalNode>(x);
1885 ConstvalNode *yConst = safe_cast<ConstvalNode>(y);
1886 bool foldMulDiv = false;
1887 if (yConst != nullptr && xConst == nullptr &&
1888 IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1889 MIRIntConst *yCst = safe_cast<MIRIntConst>(yConst->GetConstVal());
1890 ASSERT_NOT_NULL(yCst);
1891 IntVal mulCst = yCst->GetValue();
1892 if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1893 mulCst.GetExtValue() == cst.GetExtValue()) {
1894 foldMulDiv = true;
1895 result = x;
1896 }
1897 } else if (xConst != nullptr && yConst == nullptr &&
1898 IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1899 MIRIntConst *xCst = safe_cast<MIRIntConst>(xConst->GetConstVal());
1900 ASSERT_NOT_NULL(xCst);
1901 IntVal mulCst = xCst->GetValue();
1902 if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1903 mulCst.GetExtValue() == cst.GetExtValue()) {
1904 foldMulDiv = true;
1905 result = y;
1906 }
1907 }
1908 if (!foldMulDiv) {
1909 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1910 }
1911 } else if (op == OP_mul && lp.second.has_value() && *lp.second != 0 && lp.second->GetSXTValue() > -kMaxOffset) {
1912 // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)]
1913 sum = lp.second * cst;
1914 if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) {
1915 lp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, lp.first);
1916 }
1917 if (lp.first->GetOpCode() == OP_neg && cst == -1) {
1918 // special case: ((-X) + konst) * (-1) -> the pair [(X), -konst]
1919 result = lp.first->Opnd(0);
1920 } else {
1921 result = NewBinaryNode(node, OP_mul, primType, lp.first, rConst);
1922 }
1923 } else if (op == OP_band && cst == -1) {
1924 // X & (-1) -> X
1925 sum = lp.second;
1926 result = l;
1927 } else if (op == OP_band && ContiguousBitsOf1(cst.GetZXTValue()) &&
1928 (!lp.second.has_value() || lp.second == 0)) {
1929 bool fold2extractbits = false;
1930 if (l->GetOpCode() == OP_ashr || l->GetOpCode() == OP_lshr) {
1931 BinaryNode *shrNode = static_cast<BinaryNode *>(l);
1932 if (shrNode->Opnd(1)->GetOpCode() == OP_constval) {
1933 ConstvalNode *shrOpnd = static_cast<ConstvalNode *>(shrNode->Opnd(1));
1934 int64 shrAmt = static_cast<MIRIntConst*>(shrOpnd->GetConstVal())->GetExtValue();
1935 uint64 ucst = cst.GetZXTValue();
1936 uint32 bsize = 0;
1937 do {
1938 bsize++;
1939 ucst >>= 1;
1940 } while (ucst != 0);
1941 if (shrAmt + static_cast<int64>(bsize) <=
1942 static_cast<int64>(GetPrimTypeSize(primType) * kBitSizePerByte) &&
1943 static_cast<uint64>(shrAmt) < GetPrimTypeSize(primType) * kBitSizePerByte) {
1944 fold2extractbits = true;
1945 // change to use extractbits
1946 result = mirModule->GetMIRBuilder()->CreateExprExtractbits(OP_extractbits,
1947 GetUnsignedPrimType(primType), static_cast<uint32>(shrAmt), bsize, shrNode->Opnd(0));
1948 sum = std::nullopt;
1949 }
1950 }
1951 }
1952 if (!fold2extractbits) {
1953 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1954 sum = std::nullopt;
1955 }
1956 } else if (op == OP_bior && cst == -1) {
1957 // X | (-1) -> -1
1958 result = mirModule->GetMIRBuilder()->CreateIntConst(-1ULL, cstTyp);
1959 } else if ((op == OP_lior || op == OP_cior)) {
1960 if (cst == 0) {
1961 // X || 0 -> X
1962 sum = lp.second;
1963 result = l;
1964 } else if (!cst.GetSignBit()) {
1965 // X || 5 -> 1
1966 result = mirModule->GetMIRBuilder()->CreateIntConst(1, cstTyp);
1967 } else {
1968 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1969 }
1970 } else if ((op == OP_ashr || op == OP_lshr || op == OP_shl || op == OP_bior || op == OP_bxor) && cst == 0) {
1971 // X >> 0 -> X
1972 // X << 0 -> X
1973 // X | 0 -> X
1974 // X ^ 0 -> X
1975 sum = lp.second;
1976 result = l;
1977 } else if (op == OP_bxor && cst == 1 && primType != PTY_u1) {
1978 // bxor i32 (
1979 // cvt i32 u1 (regread u1 %13),
1980 // constValue i32 1),
1981 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1982 if (l->GetOpCode() == OP_cvt && (!lp.second || lp.second == 0)) {
1983 TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(l);
1984 if (cvtNode->Opnd(0)->GetPrimType() == PTY_u1) {
1985 BaseNode *base = cvtNode->Opnd(0);
1986 BaseNode *constValue = mirModule->GetMIRBuilder()->CreateIntConst(1, base->GetPrimType());
1987 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(base);
1988 BinaryNode *temp = NewBinaryNode(node, op, PTY_u1, PairToExpr(base->GetPrimType(), p), constValue);
1989 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_u1, temp);
1990 }
1991 }
1992 } else if (op == OP_rem && cst == 1) {
1993 // X % 1 -> 0
1994 result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1995 } else {
1996 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1997 }
1998 if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1999 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
2000 }
2001 } else if (isInt && (op == OP_add || op == OP_sub)) {
2002 if (op == OP_add) {
2003 result = NewBinaryNode(node, op, primType, l, r);
2004 sum = lp.second + rp.second;
2005 } else if (r != nullptr && node->Opnd(1)->GetOpCode() == OP_sub && r->GetOpCode() == OP_neg) {
2006 // if fold is (x - (y - z)) -> (x - neg(z)) - y
2007 // (x - neg(z)) Could cross the int limit
2008 // return node
2009 result = node;
2010 } else {
2011 result = NewBinaryNode(node, op, primType, l, r);
2012 sum = lp.second - rp.second;
2013 }
2014 } else {
2015 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
2016 }
2017 return std::make_pair(result, sum);
2018 }
2019
SimplifyDoubleConstvalCompare(CompareNode & node,bool isRConstval,bool isGtOrLt) const2020 BaseNode *ConstantFold::SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt) const
2021 {
2022 if (isRConstval) {
2023 ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(1));
2024 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
2025 const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(0));
2026 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
2027 node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
2028 }
2029 } else {
2030 ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(0));
2031 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
2032 const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(1));
2033 if (isGtOrLt) {
2034 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
2035 node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(1), compNode->Opnd(0));
2036 } else {
2037 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
2038 node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
2039 }
2040 }
2041 }
2042 return &node;
2043 }
2044
SimplifyDoubleCompare(CompareNode & compareNode) const2045 BaseNode *ConstantFold::SimplifyDoubleCompare(CompareNode &compareNode) const
2046 {
2047 // See arm manual B.cond(P2993) and FCMP(P1091)
2048 CompareNode *node = &compareNode;
2049 BaseNode *result = node;
2050 BaseNode *l = node->Opnd(0);
2051 BaseNode *r = node->Opnd(1);
2052 if (node->GetOpCode() == OP_ne || node->GetOpCode() == OP_eq) {
2053 if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
2054 (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
2055 result = SimplifyDoubleConstvalCompare(*node, (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval));
2056 } else if (node->GetOpCode() == OP_ne && r->GetOpCode() == OP_constval) {
2057 // ne (u1 x, constValue 0) <==> x
2058 ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
2059 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
2060 BaseNode *opnd = l;
2061 do {
2062 if (opnd->GetPrimType() == PTY_u1 || (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
2063 result = opnd;
2064 break;
2065 } else if (opnd->GetOpCode() == OP_cvt) {
2066 TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(opnd);
2067 opnd = cvtNode->Opnd(0);
2068 } else {
2069 opnd = nullptr;
2070 }
2071 } while (opnd != nullptr);
2072 }
2073 } else if (node->GetOpCode() == OP_eq && r->GetOpCode() == OP_constval) {
2074 ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
2075 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero() &&
2076 (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
2077 auto resOp = l->GetOpCode() == OP_ne ? OP_eq : OP_ne;
2078 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
2079 resOp, l->GetPrimType(), static_cast<CompareNode*>(l)->GetOpndType(), l->Opnd(0), l->Opnd(1));
2080 }
2081 }
2082 } else if (node->GetOpCode() == OP_gt || node->GetOpCode() == OP_lt) {
2083 if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
2084 (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
2085 result = SimplifyDoubleConstvalCompare(*node,
2086 (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval), true);
2087 }
2088 }
2089 return result;
2090 }
2091
FoldCompare(CompareNode * node)2092 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldCompare(CompareNode *node)
2093 {
2094 CHECK_NULL_FATAL(node);
2095 BaseNode *result = nullptr;
2096 std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
2097 std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
2098 ConstvalNode *lConst = safe_cast<ConstvalNode>(lp.first);
2099 ConstvalNode *rConst = safe_cast<ConstvalNode>(rp.first);
2100 if (node->GetOpndType() == PTY_f128 || node->GetPrimType() == PTY_f128) {
2101 // folding of non-unary float128 is not supported yet
2102 return std::make_pair(static_cast<BaseNode*>(node), std::nullopt);
2103 }
2104 Opcode opcode = node->GetOpCode();
2105 if (lConst != nullptr && rConst != nullptr && !IsPrimitiveDynType(node->GetOpndType())) {
2106 result = FoldConstComparison(node->GetOpCode(), node->GetPrimType(), node->GetOpndType(), *lConst, *rConst);
2107 } else if (lConst != nullptr && rConst == nullptr && opcode != OP_cmp &&
2108 lConst->GetConstVal()->GetKind() == kConstInt) {
2109 BaseNode *l = lp.first;
2110 BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
2111 result = FoldConstComparisonReverse(opcode, node->GetPrimType(), node->GetOpndType(), *l, *r);
2112 } else {
2113 BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), lp);
2114 BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
2115 if (l != node->Opnd(0) || r != node->Opnd(1)) {
2116 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
2117 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->GetOpndType()), l, r);
2118 } else {
2119 result = node;
2120 }
2121 auto *compareNode = static_cast<CompareNode*>(result);
2122 CHECK_NULL_FATAL(compareNode);
2123 result = SimplifyDoubleCompare(*compareNode);
2124 }
2125 return std::make_pair(result, std::nullopt);
2126 }
2127
Fold(BaseNode * node)2128 BaseNode *ConstantFold::Fold(BaseNode *node)
2129 {
2130 if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) {
2131 return nullptr;
2132 }
2133 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node);
2134 BaseNode *result = PairToExpr(node->GetPrimType(), p);
2135 if (result == node) {
2136 result = nullptr;
2137 }
2138 return result;
2139 }
2140
FoldDepositbits(DepositbitsNode * node)2141 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldDepositbits(DepositbitsNode *node)
2142 {
2143 CHECK_NULL_FATAL(node);
2144 BaseNode *result = nullptr;
2145 uint8 bitsOffset = node->GetBitsOffset();
2146 uint8 bitsSize = node->GetBitsSize();
2147 std::pair<BaseNode*, std::optional<IntVal>> leftPair = DispatchFold(node->Opnd(0));
2148 std::pair<BaseNode*, std::optional<IntVal>> rightPair = DispatchFold(node->Opnd(1));
2149 ConstvalNode *leftConst = safe_cast<ConstvalNode>(leftPair.first);
2150 ConstvalNode *rightConst = safe_cast<ConstvalNode>(rightPair.first);
2151 if (leftConst != nullptr && rightConst != nullptr) {
2152 MIRIntConst *intConst0 = safe_cast<MIRIntConst>(leftConst->GetConstVal());
2153 MIRIntConst *intConst1 = safe_cast<MIRIntConst>(rightConst->GetConstVal());
2154 ASSERT_NOT_NULL(intConst0);
2155 ASSERT_NOT_NULL(intConst1);
2156 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
2157 resultConst->SetPrimType(node->GetPrimType());
2158 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(node->GetPrimType());
2159 MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, type);
2160 uint64 op0ExtractVal = 0;
2161 uint64 op1ExtractVal = 0;
2162 uint64 mask0 = (1LLU << (bitsSize + bitsOffset)) - 1;
2163 uint64 mask1 = (1LLU << bitsOffset) - 1;
2164 uint64 op0Mask = ~(mask0 ^ mask1);
2165 op0ExtractVal = (static_cast<uint64>(intConst0->GetExtValue()) & op0Mask);
2166 op1ExtractVal = (static_cast<uint64>(intConst1->GetExtValue()) << bitsOffset) &
2167 ((1ULL << (bitsSize + bitsOffset)) - 1);
2168 constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(
2169 (op0ExtractVal | op1ExtractVal), constValue->GetType());
2170 resultConst->SetConstVal(constValue);
2171 result = resultConst;
2172 } else {
2173 BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), leftPair);
2174 BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rightPair);
2175 if (l != node->Opnd(0) || r != node->Opnd(1)) {
2176 result = mirModule->CurFuncCodeMemPool()->New<DepositbitsNode>(
2177 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), bitsOffset, bitsSize, l, r);
2178 } else {
2179 result = node;
2180 }
2181 }
2182 return std::make_pair(result, std::nullopt);
2183 }
2184
FoldArray(ArrayNode * node)2185 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldArray(ArrayNode *node)
2186 {
2187 CHECK_NULL_FATAL(node);
2188 BaseNode *result = nullptr;
2189 size_t i = 0;
2190 bool isFolded = false;
2191 ArrayNode *arrNode = mirModule->CurFuncCodeMemPool()->New<ArrayNode>(*mirModule, PrimType(node->GetPrimType()),
2192 node->GetTyIdx(), node->GetBoundsCheck());
2193 for (i = 0; i < node->GetNopndSize(); i++) {
2194 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->GetNopndAt(i));
2195 BaseNode *tmpNode = PairToExpr(node->GetNopndAt(i)->GetPrimType(), p);
2196 if (tmpNode != node->GetNopndAt(i)) {
2197 isFolded = true;
2198 }
2199 arrNode->GetNopnd().push_back(tmpNode);
2200 arrNode->SetNumOpnds(arrNode->GetNumOpnds() + 1);
2201 }
2202 if (isFolded) {
2203 result = arrNode;
2204 } else {
2205 result = node;
2206 }
2207 return std::make_pair(result, std::nullopt);
2208 }
2209
FoldTernary(TernaryNode * node)2210 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldTernary(TernaryNode *node)
2211 {
2212 CHECK_NULL_FATAL(node);
2213 constexpr size_t kFirst = 0;
2214 constexpr size_t kSecond = 1;
2215 constexpr size_t kThird = 2;
2216 BaseNode *result = node;
2217 std::vector<PrimType> primTypes;
2218 std::vector<std::pair<BaseNode*, std::optional<IntVal>>> p;
2219 for (size_t i = 0; i < node->NumOpnds(); i++) {
2220 BaseNode *tempNopnd = node->Opnd(i);
2221 CHECK_NULL_FATAL(tempNopnd);
2222 primTypes.push_back(tempNopnd->GetPrimType());
2223 p.push_back(DispatchFold(tempNopnd));
2224 }
2225 if (node->GetOpCode() == OP_select) {
2226 ConstvalNode *const0 = safe_cast<ConstvalNode>(p[kFirst].first);
2227 if (const0 != nullptr) {
2228 MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0->GetConstVal());
2229 ASSERT_NOT_NULL(intConst0);
2230 // Selecting the first value if not 0, selecting the second value otherwise.
2231 if (!intConst0->IsZero()) {
2232 result = PairToExpr(primTypes[kSecond], p[kSecond]);
2233 } else {
2234 result = PairToExpr(primTypes[kThird], p[kThird]);
2235 }
2236 } else {
2237 ConstvalNode *const1 = safe_cast<ConstvalNode>(p[kSecond].first);
2238 ConstvalNode *const2 = safe_cast<ConstvalNode>(p[kThird].first);
2239 if (const1 != nullptr && const2 != nullptr) {
2240 MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1->GetConstVal());
2241 MIRIntConst *intConst2 = safe_cast<MIRIntConst>(const2->GetConstVal());
2242 double dconst1 = 0.0;
2243 double dconst2 = 0.0;
2244 // for fpconst
2245 if (intConst1 == nullptr || intConst2 == nullptr) {
2246 PrimType ptyp = const1->GetPrimType();
2247 if (ptyp == PTY_f64) {
2248 MIRDoubleConst *dConst1 = safe_cast<MIRDoubleConst>(const1->GetConstVal());
2249 dconst1 = dConst1->GetValue();
2250 MIRDoubleConst *dConst2 = safe_cast<MIRDoubleConst>(const2->GetConstVal());
2251 dconst2 = dConst2->GetValue();
2252 } else if (ptyp == PTY_f32) {
2253 MIRFloatConst *fConst1 = safe_cast<MIRFloatConst>(const1->GetConstVal());
2254 dconst1 = static_cast<double>(fConst1->GetFloatValue());
2255 MIRFloatConst *fConst2 = safe_cast<MIRFloatConst>(const2->GetConstVal());
2256 dconst2 = static_cast<double>(fConst2->GetFloatValue());
2257 }
2258 } else {
2259 dconst1 = static_cast<double>(intConst1->GetExtValue());
2260 dconst2 = static_cast<double>(intConst2->GetExtValue());
2261 }
2262 PrimType foldedPrimType = primTypes[kSecond];
2263 if (!IsPrimitiveInteger(foldedPrimType)) {
2264 foldedPrimType = primTypes[kThird];
2265 }
2266 if (dconst1 == 1.0 && dconst2 == 0.0 && GetPrimTypeActualBitSize(primTypes[0]) == 1) {
2267 if (IsPrimitiveInteger(foldedPrimType)) {
2268 result = PairToExpr(foldedPrimType, p[0]);
2269 } else {
2270 result = PairToExpr(primTypes[0], p[0]);
2271 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, foldedPrimType, primTypes[0],
2272 result);
2273 }
2274 return std::make_pair(result, std::nullopt);
2275 }
2276 if (dconst1 == 0.0 && dconst2 == 1.0 && GetPrimTypeActualBitSize(primTypes[0]) == 1) {
2277 BaseNode *lnot = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
2278 OP_eq, primTypes[0], primTypes[0], PairToExpr(primTypes[0], p[0]),
2279 mirModule->GetMIRBuilder()->CreateIntConst(0, primTypes[0]));
2280 std::pair<BaseNode*, std::optional<IntVal>> pairTemp = DispatchFold(lnot);
2281 if (IsPrimitiveInteger(foldedPrimType)) {
2282 result = PairToExpr(foldedPrimType, pairTemp);
2283 } else {
2284 result = PairToExpr(primTypes[0], pairTemp);
2285 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, foldedPrimType, primTypes[0],
2286 result);
2287 }
2288 return std::make_pair(result, std::nullopt);
2289 }
2290 }
2291 }
2292 }
2293 BaseNode *e0 = PairToExpr(primTypes[kFirst], p[kFirst]);
2294 BaseNode *e1 = PairToExpr(primTypes[kSecond], p[kSecond]);
2295 BaseNode *e2 = PairToExpr(primTypes[kThird], p[kThird]); // count up to 3 for ternary node
2296 if (e0 != node->Opnd(kFirst) || e1 != node->Opnd(kSecond) || e2 != node->Opnd(kThird)) {
2297 result = mirModule->CurFuncCodeMemPool()->New<TernaryNode>(Opcode(node->GetOpCode()),
2298 PrimType(node->GetPrimType()), e0, e1, e2);
2299 }
2300 return std::make_pair(result, std::nullopt);
2301 }
2302
SimplifyDassign(DassignNode * node)2303 StmtNode *ConstantFold::SimplifyDassign(DassignNode *node)
2304 {
2305 CHECK_NULL_FATAL(node);
2306 BaseNode *returnValue = nullptr;
2307 returnValue = Fold(node->GetRHS());
2308 if (returnValue != nullptr) {
2309 node->SetRHS(returnValue);
2310 }
2311 return node;
2312 }
2313
SimplifyIassignWithAddrofBaseNode(IassignNode & node,const AddrofNode & base) const2314 StmtNode *ConstantFold::SimplifyIassignWithAddrofBaseNode(IassignNode &node, const AddrofNode &base) const
2315 {
2316 auto *mirTypeOfIass = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node.GetTyIdx());
2317 if (!mirTypeOfIass->IsMIRPtrType()) {
2318 return &node;
2319 }
2320 auto *iassPtType = static_cast<MIRPtrType*>(mirTypeOfIass);
2321
2322 MIRSymbol *lhsSym = mirModule->CurFunction()->GetLocalOrGlobalSymbol(base.GetStIdx());
2323 TyIdx lhsTyIdx = lhsSym->GetTyIdx();
2324 if (base.GetFieldID() != 0) {
2325 auto *mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(lhsTyIdx);
2326 if (!mirType->IsStructType()) {
2327 return &node;
2328 }
2329 lhsTyIdx = static_cast<MIRStructType*>(mirType)->GetFieldType(base.GetFieldID())->GetTypeIndex();
2330 }
2331 if (iassPtType->GetPointedTyIdx() == lhsTyIdx) {
2332 DassignNode *dassignNode = mirModule->CurFuncCodeMemPool()->New<DassignNode>();
2333 dassignNode->SetStIdx(base.GetStIdx());
2334 dassignNode->SetRHS(node.GetRHS());
2335 dassignNode->SetFieldID(base.GetFieldID() + node.GetFieldID());
2336 // reuse stmtid to maintain stmtFreqs if profileUse is on
2337 dassignNode->SetStmtID(node.GetStmtID());
2338 return dassignNode;
2339 }
2340 return &node;
2341 }
2342
SimplifyIassignWithIaddrofBaseNode(IassignNode & node,const IaddrofNode & base)2343 StmtNode *ConstantFold::SimplifyIassignWithIaddrofBaseNode(IassignNode &node, const IaddrofNode &base)
2344 {
2345 auto *mirTypeOfIass = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node.GetTyIdx());
2346 if (!mirTypeOfIass->IsMIRPtrType()) {
2347 return &node;
2348 }
2349 auto *iassPtType = static_cast<MIRPtrType*>(mirTypeOfIass);
2350
2351 if (base.GetFieldID() == 0) {
2352 // this iaddrof is redundant
2353 node.SetAddrExpr(base.Opnd(0));
2354 return &node;
2355 }
2356
2357 auto *mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(base.GetTyIdx());
2358 if (!mirType->IsMIRPtrType()) {
2359 return &node;
2360 }
2361 auto *iaddrofPtType = static_cast<MIRPtrType*>(mirType);
2362
2363 MIRStructType *lhsStructTy =
2364 static_cast<MIRStructType*>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iaddrofPtType->GetPointedTyIdx()));
2365 TyIdx lhsTyIdx = lhsStructTy->GetFieldType(base.GetFieldID())->GetTypeIndex();
2366 if (iassPtType->GetPointedTyIdx() == lhsTyIdx) {
2367 // eliminate the iaddrof by updating the iassign's fieldID and tyIdx
2368 node.SetFieldID(node.GetFieldID() + base.GetFieldID());
2369 node.SetTyIdx(base.GetTyIdx());
2370 node.SetOpnd(base.Opnd(0), 0);
2371 // recursive call for the new iassign
2372 return SimplifyIassign(&node);
2373 }
2374 return &node;
2375 }
2376
SimplifyIassign(IassignNode * node)2377 StmtNode *ConstantFold::SimplifyIassign(IassignNode *node)
2378 {
2379 CHECK_NULL_FATAL(node);
2380 BaseNode *returnValue = nullptr;
2381 returnValue = Fold(node->Opnd(0));
2382 if (returnValue != nullptr) {
2383 node->SetOpnd(returnValue, 0);
2384 }
2385 returnValue = Fold(node->GetRHS());
2386 if (returnValue != nullptr) {
2387 node->SetRHS(returnValue);
2388 }
2389 auto *mirTypeOfIass = GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx());
2390 if (!mirTypeOfIass->IsMIRPtrType()) {
2391 return node;
2392 }
2393
2394 auto *opnd = node->Opnd(0);
2395 ASSERT_NOT_NULL(opnd);
2396 switch (opnd->GetOpCode()) {
2397 case OP_addrof: {
2398 return SimplifyIassignWithAddrofBaseNode(*node, static_cast<AddrofNode&>(*opnd));
2399 }
2400 case OP_iaddrof: {
2401 return SimplifyIassignWithIaddrofBaseNode(*node, static_cast<IreadNode&>(*opnd));
2402 }
2403 default:
2404 break;
2405 }
2406 return node;
2407 }
2408
SimplifyCondGoto(CondGotoNode * node)2409 StmtNode *ConstantFold::SimplifyCondGoto(CondGotoNode *node)
2410 {
2411 CHECK_NULL_FATAL(node);
2412 // optimize condgoto need to update frequency, skip here
2413 if (Options::profileUse && mirModule->CurFunction()->GetFuncProfData()) {
2414 return node;
2415 }
2416 BaseNode *returnValue = nullptr;
2417 returnValue = Fold(node->Opnd(0));
2418 returnValue = (returnValue == nullptr) ? node : returnValue;
2419 if (returnValue == node && node->Opnd(0)->GetOpCode() == OP_select) {
2420 return SimplifyCondGotoSelect(node);
2421 } else {
2422 if (returnValue != node) {
2423 node->SetOpnd(returnValue, 0);
2424 }
2425 ConstvalNode *cst = safe_cast<ConstvalNode>(node->Opnd(0));
2426 if (cst == nullptr) {
2427 return node;
2428 }
2429 MIRIntConst *intConst = safe_cast<MIRIntConst>(cst->GetConstVal());
2430 ASSERT_NOT_NULL(intConst);
2431 if ((node->GetOpCode() == OP_brtrue && !intConst->IsZero()) ||
2432 (node->GetOpCode() == OP_brfalse && intConst->IsZero())) {
2433 uint32 freq = static_cast<uint32>(mirModule->CurFunction()->GetFreqFromLastStmt(node->GetStmtID()));
2434 GotoNode *gotoNode = mirModule->CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
2435 gotoNode->SetOffset(node->GetOffset());
2436 if (Options::profileUse && mirModule->CurFunction()->GetFuncProfData()) {
2437 gotoNode->SetStmtID(node->GetStmtID()); // reuse condnode stmtid
2438 }
2439 mirModule->CurFunction()->SetLastFreqMap(gotoNode->GetStmtID(), freq);
2440 return gotoNode;
2441 } else {
2442 return nullptr;
2443 }
2444 }
2445 return node;
2446 }
2447
SimplifyCondGotoSelect(CondGotoNode * node) const2448 StmtNode *ConstantFold::SimplifyCondGotoSelect(CondGotoNode *node) const
2449 {
2450 CHECK_NULL_FATAL(node);
2451 TernaryNode *sel = static_cast<TernaryNode*>(node->Opnd(0));
2452 if (sel == nullptr || sel->GetOpCode() != OP_select) {
2453 return node;
2454 }
2455 ConstvalNode *const1 = safe_cast<ConstvalNode>(sel->Opnd(1));
2456 ConstvalNode *const2 = safe_cast<ConstvalNode>(sel->Opnd(2));
2457 if (const1 != nullptr && const2 != nullptr) {
2458 MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1->GetConstVal());
2459 MIRIntConst *intConst2 = safe_cast<MIRIntConst>(const2->GetConstVal());
2460 ASSERT_NOT_NULL(intConst1);
2461 ASSERT_NOT_NULL(intConst2);
2462 if (intConst1->GetValue() == 1 && intConst2->GetValue() == 0) {
2463 node->SetOpnd(sel->Opnd(0), 0);
2464 } else if (intConst1->GetValue() == 0 && intConst2->GetValue() == 1) {
2465 node->SetOpCode((node->GetOpCode() == OP_brfalse) ? OP_brtrue : OP_brfalse);
2466 node->SetOpnd(sel->Opnd(0), 0);
2467 }
2468 }
2469 return node;
2470 }
2471
SimplifySwitch(SwitchNode * node)2472 StmtNode *ConstantFold::SimplifySwitch(SwitchNode *node)
2473 {
2474 CHECK_NULL_FATAL(node);
2475 BaseNode *returnValue = nullptr;
2476 returnValue = Fold(node->GetSwitchOpnd());
2477 if (returnValue != nullptr) {
2478 node->SetSwitchOpnd(returnValue);
2479 ConstvalNode *cst = safe_cast<ConstvalNode>(node->GetSwitchOpnd());
2480 if (cst == nullptr) {
2481 return node;
2482 }
2483 MIRIntConst *intConst = safe_cast<MIRIntConst>(cst->GetConstVal());
2484 ASSERT_NOT_NULL(intConst);
2485 GotoNode *gotoNode = mirModule->CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
2486 bool isdefault = true;
2487 for (unsigned i = 0; i < node->GetSwitchTable().size(); i++) {
2488 if (node->GetCasePair(i).first == intConst->GetValue()) {
2489 isdefault = false;
2490 gotoNode->SetOffset(static_cast<LabelIdx>(node->GetCasePair(i).second));
2491 break;
2492 }
2493 }
2494 if (isdefault) {
2495 gotoNode->SetOffset(node->GetDefaultLabel());
2496 }
2497 return gotoNode;
2498 }
2499 return node;
2500 }
2501
SimplifyUnary(UnaryStmtNode * node)2502 StmtNode *ConstantFold::SimplifyUnary(UnaryStmtNode *node)
2503 {
2504 CHECK_NULL_FATAL(node);
2505 BaseNode *returnValue = nullptr;
2506 if (node->Opnd(0) == nullptr) {
2507 return node;
2508 }
2509 returnValue = Fold(node->Opnd(0));
2510 if (returnValue != nullptr) {
2511 node->SetOpnd(returnValue, 0);
2512 }
2513 return node;
2514 }
2515
SimplifyBinary(BinaryStmtNode * node)2516 StmtNode *ConstantFold::SimplifyBinary(BinaryStmtNode *node)
2517 {
2518 CHECK_NULL_FATAL(node);
2519 BaseNode *returnValue = nullptr;
2520 returnValue = Fold(node->GetBOpnd(0));
2521 if (returnValue != nullptr) {
2522 node->SetBOpnd(returnValue, 0);
2523 }
2524 returnValue = Fold(node->GetBOpnd(1));
2525 if (returnValue != nullptr) {
2526 node->SetBOpnd(returnValue, 1);
2527 }
2528 return node;
2529 }
2530
SimplifyBlock(BlockNode * node)2531 StmtNode *ConstantFold::SimplifyBlock(BlockNode *node)
2532 {
2533 CHECK_NULL_FATAL(node);
2534 if (node->GetFirst() == nullptr) {
2535 return node;
2536 }
2537 StmtNode *s = node->GetFirst();
2538 StmtNode *prevStmt = nullptr;
2539 do {
2540 StmtNode *returnValue = Simplify(s);
2541 if (returnValue != nullptr) {
2542 if (returnValue->GetOpCode() == OP_block) {
2543 BlockNode *blk = static_cast<BlockNode*>(returnValue);
2544 if (blk->IsEmpty()) {
2545 node->RemoveStmt(s);
2546 } else {
2547 node->ReplaceStmtWithBlock(*s, *blk);
2548 prevStmt = s;
2549 }
2550 } else {
2551 node->ReplaceStmt1WithStmt2(s, returnValue);
2552 prevStmt = s;
2553 }
2554 s = s->GetNext();
2555 } else {
2556 // delete s from block
2557 StmtNode *nextStmt = s->GetNext();
2558 if (s == node->GetFirst()) {
2559 node->SetFirst(nextStmt);
2560 if (nextStmt != nullptr) {
2561 nextStmt->SetPrev(nullptr);
2562 }
2563 } else {
2564 CHECK_NULL_FATAL(prevStmt);
2565 prevStmt->SetNext(nextStmt);
2566 if (nextStmt != nullptr) {
2567 nextStmt->SetPrev(prevStmt);
2568 }
2569 }
2570 if (s == node->GetLast()) {
2571 node->SetLast(prevStmt);
2572 }
2573 s = nextStmt;
2574 }
2575 } while (s != nullptr);
2576 return node;
2577 }
2578
SimplifyAsm(AsmNode * node)2579 StmtNode *ConstantFold::SimplifyAsm(AsmNode *node)
2580 {
2581 CHECK_NULL_FATAL(node);
2582 /* fold constval in input */
2583 for (size_t i = 0; i < node->NumOpnds(); i++) {
2584 const std::string &str = GlobalTables::GetUStrTable().GetStringFromStrIdx(node->inputConstraints[i]);
2585 if (str == "i") {
2586 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(i));
2587 node->SetOpnd(p.first, i);
2588 continue;
2589 }
2590 }
2591 return node;
2592 }
2593
SimplifyIf(IfStmtNode * node)2594 StmtNode *ConstantFold::SimplifyIf(IfStmtNode *node)
2595 {
2596 CHECK_NULL_FATAL(node);
2597 BaseNode *returnValue = nullptr;
2598 (void)Simplify(node->GetThenPart());
2599 if (node->GetElsePart()) {
2600 (void)Simplify(node->GetElsePart());
2601 }
2602 returnValue = Fold(node->Opnd());
2603 if (returnValue != nullptr) {
2604 node->SetOpnd(returnValue, 0);
2605 // do not delete c/c++ dead if-body here
2606 return node;
2607 }
2608 return node;
2609 }
2610
SimplifyWhile(WhileStmtNode * node)2611 StmtNode *ConstantFold::SimplifyWhile(WhileStmtNode *node)
2612 {
2613 CHECK_NULL_FATAL(node);
2614 BaseNode *returnValue = nullptr;
2615 if (node->Opnd(0) == nullptr) {
2616 return node;
2617 }
2618 if (node->GetBody()) {
2619 (void)Simplify(node->GetBody());
2620 }
2621 returnValue = Fold(node->Opnd(0));
2622 if (returnValue != nullptr) {
2623 node->SetOpnd(returnValue, 0);
2624 // do not delete c/c++ dead while-body here
2625 return node;
2626 }
2627 return node;
2628 }
2629
SimplifyNary(NaryStmtNode * node)2630 StmtNode *ConstantFold::SimplifyNary(NaryStmtNode *node)
2631 {
2632 CHECK_NULL_FATAL(node);
2633 BaseNode *returnValue = nullptr;
2634 for (size_t i = 0; i < node->NumOpnds(); i++) {
2635 returnValue = Fold(node->GetNopndAt(i));
2636 if (returnValue != nullptr) {
2637 node->SetNOpndAt(i, returnValue);
2638 }
2639 }
2640 return node;
2641 }
2642
SimplifyIcall(IcallNode * node)2643 StmtNode *ConstantFold::SimplifyIcall(IcallNode *node)
2644 {
2645 CHECK_NULL_FATAL(node);
2646 BaseNode *returnValue = nullptr;
2647 for (size_t i = 0; i < node->NumOpnds(); i++) {
2648 returnValue = Fold(node->GetNopndAt(i));
2649 if (returnValue != nullptr) {
2650 node->SetNOpndAt(i, returnValue);
2651 }
2652 }
2653 // icall node transform to call node
2654 CHECK_FATAL(!node->GetNopnd().empty(), "container check");
2655 switch (node->GetNopndAt(0)->GetOpCode()) {
2656 case OP_addroffunc: {
2657 AddroffuncNode *addrofNode = static_cast<AddroffuncNode*>(node->GetNopndAt(0));
2658 CallNode *callNode = mirModule->CurFuncCodeMemPool()->New<CallNode>(
2659 *mirModule,
2660 (node->GetOpCode() == OP_icall || node->GetOpCode() == OP_icallproto) ? OP_call : OP_callassigned);
2661 if (node->GetOpCode() == OP_icallassigned || node->GetOpCode() == OP_icallprotoassigned) {
2662 callNode->SetReturnVec(node->GetReturnVec());
2663 }
2664 callNode->SetPUIdx(addrofNode->GetPUIdx());
2665 for (size_t i = 1; i < node->GetNopndSize(); i++) {
2666 callNode->GetNopnd().push_back(node->GetNopndAt(i));
2667 }
2668 callNode->SetNumOpnds(callNode->GetNopndSize());
2669 // reuse stmtID to skip update stmtFreqs when profileUse is on
2670 callNode->SetStmtID(node->GetStmtID());
2671 return callNode;
2672 }
2673 default:
2674 break;
2675 }
2676 return node;
2677 }
2678
ProcessFunc(MIRFunction * func)2679 void ConstantFold::ProcessFunc(MIRFunction *func)
2680 {
2681 if (func->IsEmpty()) {
2682 return;
2683 }
2684 mirModule->SetCurFunction(func);
2685 (void)Simplify(func->GetBody());
2686 }
2687
2688 } // namespace maple
2689