• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "triple.h"
17 #include "simplify.h"
18 #include <functional>
19 #include <initializer_list>
20 #include <iostream>
21 #include <algorithm>
22 #include "constantfold.h"
23 #include "mpl_logging.h"
24 
25 namespace maple {
26 
27 namespace {
28 
29 constexpr char kFuncNameOfMemset[] = "memset";
30 constexpr char kFuncNameOfMemcpy[] = "memcpy";
31 constexpr char kFuncNameOfMemsetS[] = "memset_s";
32 constexpr char kFuncNameOfMemcpyS[] = "memcpy_s";
33 
34 // Truncate the constant field of 'union' if it's written as scalar type (e.g. int),
35 // but accessed as bit-field type with smaller size.
36 //
37 // Return the truncated constant or nullptr if the constant doesn't need to be truncated.
TruncateUnionConstant(const MIRStructType & unionType,MIRConst * fieldCst,const MIRType & unionFieldType)38 MIRConst *TruncateUnionConstant(const MIRStructType &unionType, MIRConst *fieldCst, const MIRType &unionFieldType)
39 {
40     if (unionType.GetKind() != kTypeUnion) {
41         return nullptr;
42     }
43 
44     auto *bitFieldType = safe_cast<MIRBitFieldType>(unionFieldType);
45     auto *intCst = safe_cast<MIRIntConst>(fieldCst);
46     if (!bitFieldType || !intCst) {
47         return nullptr;
48     }
49 
50     bool isBigEndian = Triple::GetTriple().IsBigEndian();
51     IntVal val = intCst->GetValue();
52     uint8 bitSize = bitFieldType->GetFieldSize();
53     if (bitSize >= val.GetBitWidth()) {
54         return nullptr;
55     }
56 
57     if (isBigEndian) {
58         val = val.LShr(val.GetBitWidth() - bitSize);
59     } else {
60         val = val & ((uint64(1) << bitSize) - 1);
61     }
62 
63     return GlobalTables::GetIntConstTable().GetOrCreateIntConst(val, fieldCst->GetType());
64 }
65 
66 }  // namespace
67 
68 // If size (in byte) is bigger than this threshold, we won't expand memop
69 const uint32 SimplifyMemOp::thresholdMemsetExpand = 512;
70 const uint32 SimplifyMemOp::thresholdMemcpyExpand = 512;
71 const uint32 SimplifyMemOp::thresholdMemsetSExpand = 1024;
72 const uint32 SimplifyMemOp::thresholdMemcpySExpand = 1024;
73 static const uint32 kMaxMemoryBlockSizeToAssign = 8;  // in byte
74 
IsMathSqrt(const std::string funcName)75 bool Simplify::IsMathSqrt(const std::string funcName)
76 {
77     return false;
78 }
79 
IsMathAbs(const std::string funcName)80 bool Simplify::IsMathAbs(const std::string funcName)
81 {
82     return false;
83 }
84 
IsMathMax(const std::string funcName)85 bool Simplify::IsMathMax(const std::string funcName)
86 {
87     return false;
88 }
89 
IsMathMin(const std::string funcName)90 bool Simplify::IsMathMin(const std::string funcName)
91 {
92     return false;
93 }
94 
SimplifyMathMethod(const StmtNode & stmt,BlockNode & block)95 bool Simplify::SimplifyMathMethod(const StmtNode &stmt, BlockNode &block)
96 {
97     if (stmt.GetOpCode() != OP_callassigned) {
98         return false;
99     }
100     auto &cnode = static_cast<const CallNode &>(stmt);
101     MIRFunction *calleeFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(cnode.GetPUIdx());
102     DEBUG_ASSERT(calleeFunc != nullptr, "null ptr check");
103     const std::string &funcName = calleeFunc->GetName();
104     if (funcName.empty()) {
105         return false;
106     }
107     if (!mirMod.IsCModule()) {
108         return false;
109     }
110     if (cnode.GetNumOpnds() == 0 || cnode.GetReturnVec().empty()) {
111         return false;
112     }
113 
114     auto *opnd0 = cnode.Opnd(0);
115     DEBUG_ASSERT(opnd0 != nullptr, "null ptr check");
116     auto *type = GlobalTables::GetTypeTable().GetTypeFromTyIdx(opnd0->GetPrimType());
117 
118     BaseNode *opExpr = nullptr;
119     if (IsMathSqrt(funcName) && !IsPrimitiveFloat(opnd0->GetPrimType())) {
120         opExpr = builder->CreateExprUnary(OP_sqrt, *type, opnd0);
121     } else if (IsMathAbs(funcName)) {
122         opExpr = builder->CreateExprUnary(OP_abs, *type, opnd0);
123     } else if (IsMathMax(funcName)) {
124         opExpr = builder->CreateExprBinary(OP_max, *type, opnd0, cnode.Opnd(1));
125     } else if (IsMathMin(funcName)) {
126         opExpr = builder->CreateExprBinary(OP_min, *type, opnd0, cnode.Opnd(1));
127     }
128     if (opExpr != nullptr) {
129         auto stIdx = cnode.GetNthReturnVec(0).first;
130         auto *dassign = builder->CreateStmtDassign(stIdx, 0, opExpr);
131         block.ReplaceStmt1WithStmt2(&stmt, dassign);
132         return true;
133     }
134     return false;
135 }
136 
SimplifyCallAssigned(StmtNode & stmt,BlockNode & block)137 void Simplify::SimplifyCallAssigned(StmtNode &stmt, BlockNode &block)
138 {
139     if (SimplifyMathMethod(stmt, block)) {
140         return;
141     }
142     simplifyMemOp.SetDebug(dump);
143     simplifyMemOp.SetFunction(currFunc);
144     if (simplifyMemOp.AutoSimplify(stmt, block, false)) {
145         return;
146     }
147 }
148 
149 constexpr uint32 kUpperLimitOfFieldNum = 10;
GetDassignedStructType(const DassignNode * dassign,MIRFunction * func)150 static MIRStructType *GetDassignedStructType(const DassignNode *dassign, MIRFunction *func)
151 {
152     const auto &lhsStIdx = dassign->GetStIdx();
153     auto lhsSymbol = func->GetLocalOrGlobalSymbol(lhsStIdx);
154     DEBUG_ASSERT(lhsSymbol != nullptr, "lhsSymbol should not be nullptr");
155     auto lhsAggType = lhsSymbol->GetType();
156     if (!lhsAggType->IsStructType()) {
157         return nullptr;
158     }
159     if (lhsAggType->GetKind() == kTypeUnion) {  // no need to split union's field
160         return nullptr;
161     }
162     auto lhsFieldID = dassign->GetFieldID();
163     if (lhsFieldID != 0) {
164         CHECK_FATAL(lhsAggType->IsStructType(), "only struct has non-zero fieldID");
165         lhsAggType = static_cast<MIRStructType *>(lhsAggType)->GetFieldType(lhsFieldID);
166         if (!lhsAggType->IsStructType()) {
167             return nullptr;
168         }
169         if (lhsAggType->GetKind() == kTypeUnion) {  // no need to split union's field
170             return nullptr;
171         }
172     }
173     if (static_cast<MIRStructType *>(lhsAggType)->NumberOfFieldIDs() > kUpperLimitOfFieldNum) {
174         return nullptr;
175     }
176     return static_cast<MIRStructType *>(lhsAggType);
177 }
178 
GetIassignedStructType(const IassignNode * iassign)179 static MIRStructType *GetIassignedStructType(const IassignNode *iassign)
180 {
181     auto ptrTyIdx = iassign->GetTyIdx();
182     auto *ptrType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(ptrTyIdx);
183     CHECK_FATAL(ptrType->IsMIRPtrType(), "must be pointer type");
184     auto aggTyIdx = static_cast<MIRPtrType *>(ptrType)->GetPointedTyIdxWithFieldID(iassign->GetFieldID());
185     auto *lhsAggType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(aggTyIdx);
186     if (!lhsAggType->IsStructType()) {
187         return nullptr;
188     }
189     if (lhsAggType->GetKind() == kTypeUnion) {
190         return nullptr;
191     }
192     if (static_cast<MIRStructType *>(lhsAggType)->NumberOfFieldIDs() > kUpperLimitOfFieldNum) {
193         return nullptr;
194     }
195     return static_cast<MIRStructType *>(lhsAggType);
196 }
197 
GetReadedStructureType(const DreadNode * dread,const MIRFunction * func)198 static MIRStructType *GetReadedStructureType(const DreadNode *dread, const MIRFunction *func)
199 {
200     const auto &rhsStIdx = dread->GetStIdx();
201     auto rhsSymbol = func->GetLocalOrGlobalSymbol(rhsStIdx);
202     DEBUG_ASSERT(rhsSymbol != nullptr, "rhsSymbol should not be nullptr");
203     auto rhsAggType = rhsSymbol->GetType();
204     auto rhsFieldID = dread->GetFieldID();
205     if (rhsFieldID != 0) {
206         CHECK_FATAL(rhsAggType->IsStructType(), "only struct has non-zero fieldID");
207         rhsAggType = static_cast<MIRStructType *>(rhsAggType)->GetFieldType(rhsFieldID);
208     }
209     if (!rhsAggType->IsStructType()) {
210         return nullptr;
211     }
212     return static_cast<MIRStructType *>(rhsAggType);
213 }
214 
GetReadedStructureType(const IreadNode * iread,const MIRFunction *)215 static MIRStructType *GetReadedStructureType(const IreadNode *iread, const MIRFunction *)
216 {
217     auto rhsPtrType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx());
218     CHECK_FATAL(rhsPtrType->IsMIRPtrType(), "must be pointer type");
219     auto rhsAggType = static_cast<MIRPtrType *>(rhsPtrType)->GetPointedType();
220     auto rhsFieldID = iread->GetFieldID();
221     if (rhsFieldID != 0) {
222         CHECK_FATAL(rhsAggType->IsStructType(), "only struct has non-zero fieldID");
223         rhsAggType = static_cast<MIRStructType *>(rhsAggType)->GetFieldType(rhsFieldID);
224     }
225     if (!rhsAggType->IsStructType()) {
226         return nullptr;
227     }
228     return static_cast<MIRStructType *>(rhsAggType);
229 }
230 
231 template <class RhsType, class AssignType>
SplitAggCopy(const AssignType * assignNode,MIRStructType * structureType,BlockNode * block,MIRFunction * func)232 static StmtNode *SplitAggCopy(const AssignType *assignNode, MIRStructType *structureType, BlockNode *block,
233                               MIRFunction *func)
234 {
235     auto *readNode = static_cast<RhsType *>(assignNode->GetRHS());
236     auto rhsFieldID = readNode->GetFieldID();
237     auto *rhsAggType = GetReadedStructureType(readNode, func);
238     if (structureType != rhsAggType) {
239         return nullptr;
240     }
241 
242     for (FieldID id = 1; id <= static_cast<FieldID>(structureType->NumberOfFieldIDs()); ++id) {
243         MIRType *fieldType = structureType->GetFieldType(id);
244         if (fieldType->GetSize() == 0) {
245             continue;  // field size is zero for empty struct/union;
246         }
247         if (fieldType->GetKind() == kTypeBitField && static_cast<MIRBitFieldType *>(fieldType)->GetFieldSize() == 0) {
248             continue;  // bitfield size is zero
249         }
250         auto *newDassign = assignNode->CloneTree(func->GetCodeMemPoolAllocator());
251         newDassign->SetFieldID(assignNode->GetFieldID() + id);
252         auto *newRHS = static_cast<RhsType *>(newDassign->GetRHS());
253         newRHS->SetFieldID(rhsFieldID + id);
254         newRHS->SetPrimType(fieldType->GetPrimType());
255         block->InsertAfter(assignNode, newDassign);
256         if (fieldType->IsMIRUnionType()) {
257             id += static_cast<FieldID>(fieldType->NumberOfFieldIDs());
258         }
259     }
260     auto newAssign = assignNode->GetNext();
261     block->RemoveStmt(assignNode);
262     return newAssign;
263 }
264 
SplitDassignAggCopy(DassignNode * dassign,BlockNode * block,MIRFunction * func)265 static StmtNode *SplitDassignAggCopy(DassignNode *dassign, BlockNode *block, MIRFunction *func)
266 {
267     auto *rhs = dassign->GetRHS();
268     if (rhs->GetPrimType() != PTY_agg) {
269         return nullptr;
270     }
271 
272     auto *lhsAggType = GetDassignedStructType(dassign, func);
273     if (lhsAggType == nullptr) {
274         return nullptr;
275     }
276 
277     if (rhs->GetOpCode() == OP_dread) {
278         auto *lhsSymbol = func->GetLocalOrGlobalSymbol(dassign->GetStIdx());
279         auto *rhsSymbol = func->GetLocalOrGlobalSymbol(static_cast<DreadNode *>(rhs)->GetStIdx());
280         if (!lhsSymbol->IsLocal() && !rhsSymbol->IsLocal()) {
281             return nullptr;
282         }
283 
284         return SplitAggCopy<DreadNode>(dassign, lhsAggType, block, func);
285     } else if (rhs->GetOpCode() == OP_iread) {
286         return SplitAggCopy<IreadNode>(dassign, lhsAggType, block, func);
287     }
288     return nullptr;
289 }
290 
SplitIassignAggCopy(IassignNode * iassign,BlockNode * block,MIRFunction * func)291 static StmtNode *SplitIassignAggCopy(IassignNode *iassign, BlockNode *block, MIRFunction *func)
292 {
293     auto rhs = iassign->GetRHS();
294     if (rhs->GetPrimType() != PTY_agg) {
295         return nullptr;
296     }
297 
298     auto *lhsAggType = GetIassignedStructType(iassign);
299     if (lhsAggType == nullptr) {
300         return nullptr;
301     }
302 
303     if (rhs->GetOpCode() == OP_dread) {
304         return SplitAggCopy<DreadNode>(iassign, lhsAggType, block, func);
305     } else if (rhs->GetOpCode() == OP_iread) {
306         return SplitAggCopy<IreadNode>(iassign, lhsAggType, block, func);
307     }
308     return nullptr;
309 }
310 
UseGlobalVar(const BaseNode * expr)311 bool UseGlobalVar(const BaseNode *expr)
312 {
313     if (expr->GetOpCode() == OP_addrof || expr->GetOpCode() == OP_dread) {
314         StIdx stIdx = static_cast<const AddrofNode *>(expr)->GetStIdx();
315         if (stIdx.IsGlobal()) {
316             return true;
317         }
318     }
319     for (size_t i = 0; i < expr->GetNumOpnds(); ++i) {
320         if (UseGlobalVar(expr->Opnd(i))) {
321             return true;
322         }
323     }
324     return false;
325 }
326 
ProcessStmt(StmtNode & stmt)327 void Simplify::ProcessStmt(StmtNode &stmt)
328 {
329     switch (stmt.GetOpCode()) {
330         case OP_callassigned: {
331             SimplifyCallAssigned(stmt, *currBlock);
332             break;
333         }
334         case OP_intrinsiccall: {
335             simplifyMemOp.SetDebug(dump);
336             simplifyMemOp.SetFunction(currFunc);
337             (void)simplifyMemOp.AutoSimplify(stmt, *currBlock, false);
338             break;
339         }
340         case OP_dassign: {
341             auto *newStmt = SplitDassignAggCopy(static_cast<DassignNode *>(&stmt), currBlock, currFunc);
342             if (newStmt) {
343                 ProcessBlock(*newStmt);
344             }
345             break;
346         }
347         case OP_iassign: {
348             auto *newStmt = SplitIassignAggCopy(static_cast<IassignNode *>(&stmt), currBlock, currFunc);
349             if (newStmt) {
350                 ProcessBlock(*newStmt);
351             }
352             break;
353         }
354         case OP_if:
355         case OP_while:
356         case OP_dowhile: {
357             auto unaryStmt = static_cast<UnaryStmtNode &>(stmt);
358             unaryStmt.SetRHS(SimplifyExpr(*unaryStmt.GetRHS()));
359             return;
360         }
361         default: {
362             break;
363         }
364     }
365     for (size_t i = 0; i < stmt.NumOpnds(); ++i) {
366         if (stmt.Opnd(i)) {
367             stmt.SetOpnd(SimplifyExpr(*stmt.Opnd(i)), i);
368         }
369     }
370 }
371 
SimplifyExpr(BaseNode & expr)372 BaseNode *Simplify::SimplifyExpr(BaseNode &expr)
373 {
374     switch (expr.GetOpCode()) {
375         case OP_dread: {
376             auto &dread = static_cast<DreadNode &>(expr);
377             return ReplaceExprWithConst(dread);
378         }
379         default: {
380             for (auto i = 0; i < expr.GetNumOpnds(); i++) {
381                 if (expr.Opnd(i)) {
382                     expr.SetOpnd(SimplifyExpr(*expr.Opnd(i)), i);
383                 }
384             }
385             break;
386         }
387     }
388     return &expr;
389 }
390 
ReplaceExprWithConst(DreadNode & dread)391 BaseNode *Simplify::ReplaceExprWithConst(DreadNode &dread)
392 {
393     auto stIdx = dread.GetStIdx();
394     auto fieldId = dread.GetFieldID();
395     auto *symbol = currFunc->GetLocalOrGlobalSymbol(stIdx);
396     DEBUG_ASSERT(symbol != nullptr, "nullptr check");
397     auto *symbolConst = symbol->GetKonst();
398     if (!currFunc->GetModule()->IsCModule() || !symbolConst || !stIdx.IsGlobal() ||
399         !IsSymbolReplaceableWithConst(*symbol)) {
400         return &dread;
401     }
402     if (fieldId != 0) {
403         symbolConst = GetElementConstFromFieldId(fieldId, symbolConst);
404     }
405     if (!symbolConst || !IsConstRepalceable(*symbolConst)) {
406         return &dread;
407     }
408     return currFunc->GetModule()->GetMIRBuilder()->CreateConstval(symbolConst);
409 }
410 
IsSymbolReplaceableWithConst(const MIRSymbol & symbol) const411 bool Simplify::IsSymbolReplaceableWithConst(const MIRSymbol &symbol) const
412 {
413     return (symbol.GetStorageClass() == kScFstatic && !symbol.HasPotentialAssignment()) ||
414            symbol.GetAttrs().GetAttr(ATTR_const);
415 }
416 
IsConstRepalceable(const MIRConst & mirConst) const417 bool Simplify::IsConstRepalceable(const MIRConst &mirConst) const
418 {
419     switch (mirConst.GetKind()) {
420         case kConstInt:
421         case kConstFloatConst:
422         case kConstDoubleConst:
423         case kConstFloat128Const:
424         case kConstLblConst:
425             return true;
426         default:
427             return false;
428     }
429 }
430 
GetElementConstFromFieldId(FieldID fieldId,MIRConst * mirConst)431 MIRConst *Simplify::GetElementConstFromFieldId(FieldID fieldId, MIRConst *mirConst)
432 {
433     FieldID currFieldId = 1;
434     MIRConst *resultConst = nullptr;
435     auto originAggConst = static_cast<MIRAggConst *>(mirConst);
436     auto originAggType = static_cast<MIRStructType &>(originAggConst->GetType());
437     bool hasReached = false;
438     std::function<void(MIRConst *)> traverseAgg = [&](MIRConst *currConst) {
439         auto *currAggConst = safe_cast<MIRAggConst>(currConst);
440         ASSERT_NOT_NULL(currAggConst);
441         auto *currAggType = safe_cast<MIRStructType>(currAggConst->GetType());
442         ASSERT_NOT_NULL(currAggType);
443         for (size_t iter = 0; iter < currAggType->GetFieldsSize() && !hasReached; ++iter) {
444             size_t constIdx = currAggType->GetKind() == kTypeUnion ? 1 : iter + 1;
445             auto *fieldConst = currAggConst->GetAggConstElement(constIdx);
446             auto *fieldType = originAggType.GetFieldType(currFieldId);
447 
448             if (currFieldId == fieldId) {
449                 if (auto *truncCst = TruncateUnionConstant(*currAggType, fieldConst, *fieldType)) {
450                     resultConst = truncCst;
451                 } else {
452                     resultConst = fieldConst;
453                 }
454 
455                 hasReached = true;
456                 return;
457             }
458 
459             ++currFieldId;
460             if (fieldType->GetKind() == kTypeUnion || fieldType->GetKind() == kTypeStruct) {
461                 traverseAgg(fieldConst);
462             }
463         }
464     };
465     traverseAgg(mirConst);
466     CHECK_FATAL(hasReached, "const not found");
467     return resultConst;
468 }
469 
Finish()470 void Simplify::Finish() {}
471 
472 // Join `num` `byte`s into a number
473 // Example:
474 //   byte   num                output
475 //   0x0a    2                 0x0a0a
476 //   0x12    4             0x12121212
477 //   0xff    8     0xffffffffffffffff
JoinBytes(int byte,uint32 num)478 static uint64 JoinBytes(int byte, uint32 num)
479 {
480     CHECK_FATAL(num <= 8, "not support"); // just support num less or equal 8, see comment above
481     uint64 realByte = static_cast<uint64>(byte % 256);
482     if (realByte == 0) {
483         return 0;
484     }
485     uint64 result = 0;
486     for (uint32 i = 0; i < num; ++i) {
487         result += (realByte << (i * k8BitSize));
488     }
489     return result;
490 }
491 
ConstructConstvalNode(uint64 val,PrimType primType,MIRBuilder & mirBuilder)492 static BaseNode *ConstructConstvalNode(uint64 val, PrimType primType, MIRBuilder &mirBuilder)
493 {
494     PrimType constPrimType = primType;
495     if (IsPrimitiveFloat(primType)) {
496         constPrimType = GetIntegerPrimTypeBySizeAndSign(GetPrimTypeBitSize(primType), false);
497     }
498     MIRType *constType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(constPrimType));
499     MIRConst *mirConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(val, *constType);
500     BaseNode *ret = mirBuilder.CreateConstval(mirConst);
501     if (IsPrimitiveFloat(primType)) {
502         MIRType *floatType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(primType));
503         ret = mirBuilder.CreateExprRetype(*floatType, constPrimType, ret);
504     }
505     return ret;
506 }
507 
ConstructConstvalNode(int64 byte,uint64 num,PrimType primType,MIRBuilder & mirBuilder)508 static BaseNode *ConstructConstvalNode(int64 byte, uint64 num, PrimType primType, MIRBuilder &mirBuilder)
509 {
510     auto val = JoinBytes(byte, static_cast<uint32>(num));
511     return ConstructConstvalNode(val, primType, mirBuilder);
512 }
513 
514 // Input total size of memory, split the memory into several blocks, the max block size is 8 bytes
515 // Example:
516 //   input        output
517 //     40     [ 8, 8, 8, 8, 8 ]
518 //     31     [ 8, 8, 8, 4, 2, 1 ]
SplitMemoryIntoBlocks(size_t totalMemorySize,std::vector<uint32> & blocks)519 static void SplitMemoryIntoBlocks(size_t totalMemorySize, std::vector<uint32> &blocks)
520 {
521     size_t leftSize = totalMemorySize;
522     size_t curBlockSize = kMaxMemoryBlockSizeToAssign;  // max block size in byte
523     while (curBlockSize > 0) {
524         size_t n = leftSize / curBlockSize;
525         blocks.insert(blocks.end(), n, curBlockSize);
526         leftSize -= (n * curBlockSize);
527         curBlockSize = curBlockSize >> 1;
528     }
529 }
530 
IsComplexExpr(const BaseNode * expr,MIRFunction & func)531 static bool IsComplexExpr(const BaseNode *expr, MIRFunction &func)
532 {
533     Opcode op = expr->GetOpCode();
534     if (op == OP_regread) {
535         return false;
536     }
537     if (op == OP_dread) {
538         auto *symbol = func.GetLocalOrGlobalSymbol(static_cast<const DreadNode *>(expr)->GetStIdx());
539         DEBUG_ASSERT(symbol != nullptr, "nullptr check");
540         if (symbol->IsGlobal() || symbol->GetStorageClass() == kScPstatic) {
541             return true;  // dread global/static var is complex expr because it will be lowered to adrp + add
542         } else {
543             return false;
544         }
545     }
546     if (op == OP_addrof) {
547         auto *symbol = func.GetLocalOrGlobalSymbol(static_cast<const AddrofNode *>(expr)->GetStIdx());
548         if (symbol->IsGlobal() || symbol->GetStorageClass() == kScPstatic) {
549             return true;  // addrof global/static var is complex expr because it will be lowered to adrp + add
550         } else {
551             return false;
552         }
553     }
554     return true;
555 }
556 
BuildAsRhsExpr(MIRFunction & func) const557 BaseNode *MemEntry::BuildAsRhsExpr(MIRFunction &func) const
558 {
559     BaseNode *expr = nullptr;
560     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
561     if (addrExpr->GetOpCode() == OP_addrof) {
562         // We prefer dread to iread
563         // consider iaddrof if possible
564         auto *addrof = static_cast<AddrofNode *>(addrExpr);
565         auto *symbol = func.GetLocalOrGlobalSymbol(addrof->GetStIdx());
566         expr = mirBuilder->CreateExprDread(*memType, addrof->GetFieldID(), *symbol);
567     } else {
568         MIRType *structPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*memType);
569         expr = mirBuilder->CreateExprIread(*memType, *structPtrType, 0, addrExpr);
570     }
571     return expr;
572 }
573 
InsertAndMayPrintStmt(BlockNode & block,const StmtNode & anchor,bool debug,StmtNode * stmt)574 static void InsertAndMayPrintStmt(BlockNode &block, const StmtNode &anchor, bool debug, StmtNode *stmt)
575 {
576     if (stmt == nullptr) {
577         return;
578     }
579     block.InsertBefore(&anchor, stmt);
580     if (debug) {
581         stmt->Dump(0);
582     }
583 }
584 
TryToExtractComplexExpr(BaseNode * expr,MIRFunction & func,BlockNode & block,const StmtNode & anchor,bool debug)585 static BaseNode *TryToExtractComplexExpr(BaseNode *expr, MIRFunction &func, BlockNode &block, const StmtNode &anchor,
586                                          bool debug)
587 {
588     if (!IsComplexExpr(expr, func)) {
589         return expr;
590     }
591     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
592     auto pregIdx = func.GetPregTab()->CreatePreg(PTY_ptr);
593     StmtNode *regassign = mirBuilder->CreateStmtRegassign(PTY_ptr, pregIdx, expr);
594     InsertAndMayPrintStmt(block, anchor, debug, regassign);
595     auto *extractedExpr = mirBuilder->CreateExprRegread(PTY_ptr, pregIdx);
596     return extractedExpr;
597 }
598 
ExpandMemsetLowLevel(int64 byte,uint64 size,MIRFunction & func,StmtNode & stmt,BlockNode & block,MemOpKind memOpKind,bool debug,ErrorNumber errorNumber) const599 void MemEntry::ExpandMemsetLowLevel(int64 byte, uint64 size, MIRFunction &func, StmtNode &stmt, BlockNode &block,
600                                     MemOpKind memOpKind, bool debug, ErrorNumber errorNumber) const
601 {
602     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
603     std::vector<uint32> blocks;
604     SplitMemoryIntoBlocks(size, blocks);
605     int32 offset = 0;
606     // If blocks.size() > 1 and `dst` is not a leaf node,
607     // we should extract common expr to avoid redundant expression
608     BaseNode *realDstExpr = addrExpr;
609     if (blocks.size() > 1) {
610         realDstExpr = TryToExtractComplexExpr(addrExpr, func, block, stmt, debug);
611     }
612     BaseNode *readConst = nullptr;
613     // rhs const is big, extract it to avoid redundant expression
614     bool shouldExtractRhs = blocks.size() > 1 && (byte & 0xff) != 0;
615     for (auto curSize : blocks) {
616         // low level memset expand result:
617         //   iassignoff <prim-type> <offset> (dstAddrExpr, constval <prim-type> xx)
618         PrimType constType = GetIntegerPrimTypeBySizeAndSign(curSize * 8, false);
619         BaseNode *rhsExpr = ConstructConstvalNode(byte, curSize, constType, *mirBuilder);
620         if (shouldExtractRhs) {
621             // we only need to extract u64 const once
622             PregIdx pregIdx = func.GetPregTab()->CreatePreg(constType);
623             auto *constAssign = mirBuilder->CreateStmtRegassign(constType, pregIdx, rhsExpr);
624             InsertAndMayPrintStmt(block, stmt, debug, constAssign);
625             readConst = mirBuilder->CreateExprRegread(constType, pregIdx);
626             shouldExtractRhs = false;
627         }
628         if (readConst != nullptr && curSize == kMaxMemoryBlockSizeToAssign) {
629             rhsExpr = readConst;
630         }
631         auto *iassignoff = mirBuilder->CreateStmtIassignoff(constType, offset, realDstExpr, rhsExpr);
632         InsertAndMayPrintStmt(block, stmt, debug, iassignoff);
633         if (debug) {
634             ASSERT_NOT_NULL(iassignoff);
635             iassignoff->Dump(0);
636         }
637         offset += static_cast<int32>(curSize);
638     }
639     // handle memset return val
640     auto *retAssign = GenMemopRetAssign(stmt, func, true, memOpKind, errorNumber);
641     InsertAndMayPrintStmt(block, stmt, debug, retAssign);
642     // return ERRNO_INVAL if memset_s dest is NULL
643     block.RemoveStmt(&stmt);
644 }
645 
646 // handle memset, memcpy return val
GenMemopRetAssign(StmtNode & stmt,MIRFunction & func,bool isLowLevel,MemOpKind memOpKind,ErrorNumber errorNumber)647 StmtNode *MemEntry::GenMemopRetAssign(StmtNode &stmt, MIRFunction &func, bool isLowLevel, MemOpKind memOpKind,
648                                       ErrorNumber errorNumber)
649 {
650     if (stmt.GetOpCode() != OP_call && stmt.GetOpCode() != OP_callassigned) {
651         return nullptr;
652     }
653     auto &callStmt = static_cast<CallNode &>(stmt);
654     const auto &retVec = callStmt.GetReturnVec();
655     if (retVec.empty()) {
656         return nullptr;
657     }
658     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
659     BaseNode *rhs = callStmt.Opnd(0);  // for memset, memcpy
660     if (memOpKind == MEM_OP_memset_s || memOpKind == MEM_OP_memcpy_s) {
661         // memset_s and memcpy_s must return an errorNumber
662         MIRType *constType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(PTY_i32));
663         MIRConst *mirConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(errorNumber, *constType);
664         rhs = mirBuilder->CreateConstval(mirConst);
665     }
666     if (!retVec[0].second.IsReg()) {
667         auto *retAssign = mirBuilder->CreateStmtDassign(retVec[0].first, 0, rhs);
668         return retAssign;
669     } else {
670         PregIdx pregIdx = retVec[0].second.GetPregIdx();
671         auto pregType = func.GetPregTab()->GetPregTableItem(static_cast<uint32>(pregIdx))->GetPrimType();
672         auto *retAssign = mirBuilder->CreateStmtRegassign(pregType, pregIdx, rhs);
673         if (isLowLevel) {
674             retAssign->GetRHS()->SetPrimType(pregType);
675         }
676         return retAssign;
677     }
678 }
679 
ComputeMemOpKind(StmtNode & stmt)680 MemOpKind SimplifyMemOp::ComputeMemOpKind(StmtNode &stmt)
681 {
682     if (stmt.GetOpCode() == OP_intrinsiccall) {
683         auto intrinsicID = static_cast<IntrinsiccallNode &>(stmt).GetIntrinsic();
684         if (intrinsicID == INTRN_C_memset) {
685             return MEM_OP_memset;
686         } else if (intrinsicID == INTRN_C_memcpy) {
687             return MEM_OP_memcpy;
688         }
689     }
690     // lowered memop function (such as memset) may be a call, not callassigned
691     if (stmt.GetOpCode() != OP_callassigned && stmt.GetOpCode() != OP_call) {
692         return MEM_OP_unknown;
693     }
694     auto &callStmt = static_cast<CallNode &>(stmt);
695     MIRFunction *func = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(callStmt.GetPUIdx());
696     const char *funcName = func->GetName().c_str();
697     if (strcmp(funcName, kFuncNameOfMemset) == 0) {
698         return MEM_OP_memset;
699     }
700     if (strcmp(funcName, kFuncNameOfMemcpy) == 0) {
701         return MEM_OP_memcpy;
702     }
703     if (strcmp(funcName, kFuncNameOfMemsetS) == 0) {
704         return MEM_OP_memset_s;
705     }
706     if (strcmp(funcName, kFuncNameOfMemcpyS) == 0) {
707         return MEM_OP_memcpy_s;
708     }
709     return MEM_OP_unknown;
710 }
711 
AutoSimplify(StmtNode & stmt,BlockNode & block,bool isLowLevel)712 bool SimplifyMemOp::AutoSimplify(StmtNode &stmt, BlockNode &block, bool isLowLevel)
713 {
714     return false;
715 }
716 }  // namespace maple
717