• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "triple.h"
17 #include "simplify.h"
18 #include <functional>
19 #include <initializer_list>
20 #include <iostream>
21 #include <algorithm>
22 #include "constantfold.h"
23 #include "mpl_logging.h"
24 
25 namespace maple {
26 
27 namespace {
28 
29 constexpr char kClassNameOfMath[] = "Ljava_2Flang_2FMath_3B";
30 constexpr char kFuncNamePrefixOfMathSqrt[] = "Ljava_2Flang_2FMath_3B_7Csqrt_7C_28D_29D";
31 constexpr char kFuncNamePrefixOfMathAbs[] = "Ljava_2Flang_2FMath_3B_7Cabs_7C";
32 constexpr char kFuncNamePrefixOfMathMax[] = "Ljava_2Flang_2FMath_3B_7Cmax_7C";
33 constexpr char kFuncNamePrefixOfMathMin[] = "Ljava_2Flang_2FMath_3B_7Cmin_7C";
34 constexpr char kFuncNameOfMathAbs[] = "abs";
35 constexpr char kFuncNameOfMemset[] = "memset";
36 constexpr char kFuncNameOfMemcpy[] = "memcpy";
37 constexpr char kFuncNameOfMemsetS[] = "memset_s";
38 constexpr char kFuncNameOfMemcpyS[] = "memcpy_s";
39 constexpr uint64_t kSecurecMemMaxLen = 0x7fffffffUL;
40 static constexpr int32 kProbUnlikely = 1000;
41 constexpr uint32_t kMemsetDstOpndIdx = 0;
42 constexpr uint32_t kMemsetSDstSizeOpndIdx = 1;
43 constexpr uint32_t kMemsetSSrcOpndIdx = 2;
44 constexpr uint32_t kMemsetSSrcSizeOpndIdx = 3;
45 
46 // Truncate the constant field of 'union' if it's written as scalar type (e.g. int),
47 // but accessed as bit-field type with smaller size.
48 //
49 // Return the truncated constant or nullptr if the constant doesn't need to be truncated.
TruncateUnionConstant(const MIRStructType & unionType,MIRConst * fieldCst,const MIRType & unionFieldType)50 MIRConst *TruncateUnionConstant(const MIRStructType &unionType, MIRConst *fieldCst, const MIRType &unionFieldType)
51 {
52     if (unionType.GetKind() != kTypeUnion) {
53         return nullptr;
54     }
55 
56     auto *bitFieldType = safe_cast<MIRBitFieldType>(unionFieldType);
57     auto *intCst = safe_cast<MIRIntConst>(fieldCst);
58 
59     if (!bitFieldType || !intCst) {
60         return nullptr;
61     }
62 
63     bool isBigEndian = Triple::GetTriple().IsBigEndian();
64 
65     IntVal val = intCst->GetValue();
66     uint8 bitSize = bitFieldType->GetFieldSize();
67 
68     if (bitSize >= val.GetBitWidth()) {
69         return nullptr;
70     }
71 
72     if (isBigEndian) {
73         val = val.LShr(val.GetBitWidth() - bitSize);
74     } else {
75         val = val & ((uint64(1) << bitSize) - 1);
76     }
77 
78     return GlobalTables::GetIntConstTable().GetOrCreateIntConst(val, fieldCst->GetType());
79 }
80 
81 }  // namespace
82 
83 // If size (in byte) is bigger than this threshold, we won't expand memop
84 const uint32 SimplifyMemOp::thresholdMemsetExpand = 512;
85 const uint32 SimplifyMemOp::thresholdMemcpyExpand = 512;
86 const uint32 SimplifyMemOp::thresholdMemsetSExpand = 1024;
87 const uint32 SimplifyMemOp::thresholdMemcpySExpand = 1024;
88 static const uint32 kMaxMemoryBlockSizeToAssign = 8;  // in byte
89 
MayPrintLog(bool debug,bool success,MemOpKind memOpKind,const char * str)90 static void MayPrintLog(bool debug, bool success, MemOpKind memOpKind, const char *str)
91 {
92     if (!debug) {
93         return;
94     }
95     const char *memop = "";
96     if (memOpKind == MEM_OP_memset) {
97         memop = "memset";
98     } else if (memOpKind == MEM_OP_memcpy) {
99         memop = "memcpy";
100     } else if (memOpKind == MEM_OP_memset_s) {
101         memop = "memset_s";
102     } else if (memOpKind == MEM_OP_memcpy_s) {
103         memop = "memcpy_s";
104     }
105     LogInfo::MapleLogger() << memop << " expand " << (success ? "success: " : "failure: ") << str << std::endl;
106 }
107 
IsMathSqrt(const std::string funcName)108 bool Simplify::IsMathSqrt(const std::string funcName)
109 {
110     return (mirMod.IsJavaModule() && (strcmp(funcName.c_str(), kFuncNamePrefixOfMathSqrt) == 0));
111 }
112 
IsMathAbs(const std::string funcName)113 bool Simplify::IsMathAbs(const std::string funcName)
114 {
115     return (mirMod.IsCModule() && (strcmp(funcName.c_str(), kFuncNameOfMathAbs) == 0)) ||
116            (mirMod.IsJavaModule() && (strcmp(funcName.c_str(), kFuncNamePrefixOfMathAbs) == 0));
117 }
118 
IsMathMax(const std::string funcName)119 bool Simplify::IsMathMax(const std::string funcName)
120 {
121     return (mirMod.IsJavaModule() && (strcmp(funcName.c_str(), kFuncNamePrefixOfMathMax) == 0));
122 }
123 
IsMathMin(const std::string funcName)124 bool Simplify::IsMathMin(const std::string funcName)
125 {
126     return (mirMod.IsJavaModule() && (strcmp(funcName.c_str(), kFuncNamePrefixOfMathMin) == 0));
127 }
128 
SimplifyMathMethod(const StmtNode & stmt,BlockNode & block)129 bool Simplify::SimplifyMathMethod(const StmtNode &stmt, BlockNode &block)
130 {
131     if (stmt.GetOpCode() != OP_callassigned) {
132         return false;
133     }
134     auto &cnode = static_cast<const CallNode &>(stmt);
135     MIRFunction *calleeFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(cnode.GetPUIdx());
136     DEBUG_ASSERT(calleeFunc != nullptr, "null ptr check");
137     const std::string &funcName = calleeFunc->GetName();
138     if (funcName.empty()) {
139         return false;
140     }
141     if (!mirMod.IsCModule() && !mirMod.IsJavaModule()) {
142         return false;
143     }
144     if (mirMod.IsJavaModule() && funcName.find(kClassNameOfMath) == std::string::npos) {
145         return false;
146     }
147     if (cnode.GetNumOpnds() == 0 || cnode.GetReturnVec().empty()) {
148         return false;
149     }
150 
151     auto *opnd0 = cnode.Opnd(0);
152     DEBUG_ASSERT(opnd0 != nullptr, "null ptr check");
153     auto *type = GlobalTables::GetTypeTable().GetTypeFromTyIdx(opnd0->GetPrimType());
154 
155     BaseNode *opExpr = nullptr;
156     if (IsMathSqrt(funcName) && !IsPrimitiveFloat(opnd0->GetPrimType())) {
157         opExpr = builder->CreateExprUnary(OP_sqrt, *type, opnd0);
158     } else if (IsMathAbs(funcName)) {
159         opExpr = builder->CreateExprUnary(OP_abs, *type, opnd0);
160     } else if (IsMathMax(funcName)) {
161         opExpr = builder->CreateExprBinary(OP_max, *type, opnd0, cnode.Opnd(1));
162     } else if (IsMathMin(funcName)) {
163         opExpr = builder->CreateExprBinary(OP_min, *type, opnd0, cnode.Opnd(1));
164     }
165     if (opExpr != nullptr) {
166         auto stIdx = cnode.GetNthReturnVec(0).first;
167         auto *dassign = builder->CreateStmtDassign(stIdx, 0, opExpr);
168         block.ReplaceStmt1WithStmt2(&stmt, dassign);
169         return true;
170     }
171     return false;
172 }
173 
SimplifyCallAssigned(StmtNode & stmt,BlockNode & block)174 void Simplify::SimplifyCallAssigned(StmtNode &stmt, BlockNode &block)
175 {
176     if (SimplifyMathMethod(stmt, block)) {
177         return;
178     }
179     simplifyMemOp.SetDebug(dump);
180     simplifyMemOp.SetFunction(currFunc);
181     if (simplifyMemOp.AutoSimplify(stmt, block, false)) {
182         return;
183     }
184 }
185 
186 constexpr uint32 kUpperLimitOfFieldNum = 10;
GetDassignedStructType(const DassignNode * dassign,MIRFunction * func)187 static MIRStructType *GetDassignedStructType(const DassignNode *dassign, MIRFunction *func)
188 {
189     const auto &lhsStIdx = dassign->GetStIdx();
190     auto lhsSymbol = func->GetLocalOrGlobalSymbol(lhsStIdx);
191     auto lhsAggType = lhsSymbol->GetType();
192     if (!lhsAggType->IsStructType()) {
193         return nullptr;
194     }
195     if (lhsAggType->GetKind() == kTypeUnion) {  // no need to split union's field
196         return nullptr;
197     }
198     auto lhsFieldID = dassign->GetFieldID();
199     if (lhsFieldID != 0) {
200         CHECK_FATAL(lhsAggType->IsStructType(), "only struct has non-zero fieldID");
201         lhsAggType = static_cast<MIRStructType *>(lhsAggType)->GetFieldType(lhsFieldID);
202         if (!lhsAggType->IsStructType()) {
203             return nullptr;
204         }
205         if (lhsAggType->GetKind() == kTypeUnion) {  // no need to split union's field
206             return nullptr;
207         }
208     }
209     if (static_cast<MIRStructType *>(lhsAggType)->NumberOfFieldIDs() > kUpperLimitOfFieldNum) {
210         return nullptr;
211     }
212     return static_cast<MIRStructType *>(lhsAggType);
213 }
214 
GetIassignedStructType(const IassignNode * iassign)215 static MIRStructType *GetIassignedStructType(const IassignNode *iassign)
216 {
217     auto ptrTyIdx = iassign->GetTyIdx();
218     auto *ptrType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(ptrTyIdx);
219     CHECK_FATAL(ptrType->IsMIRPtrType(), "must be pointer type");
220     auto aggTyIdx = static_cast<MIRPtrType *>(ptrType)->GetPointedTyIdxWithFieldID(iassign->GetFieldID());
221     auto *lhsAggType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(aggTyIdx);
222     if (!lhsAggType->IsStructType()) {
223         return nullptr;
224     }
225     if (lhsAggType->GetKind() == kTypeUnion) {
226         return nullptr;
227     }
228     if (static_cast<MIRStructType *>(lhsAggType)->NumberOfFieldIDs() > kUpperLimitOfFieldNum) {
229         return nullptr;
230     }
231     return static_cast<MIRStructType *>(lhsAggType);
232 }
233 
GetReadedStructureType(const DreadNode * dread,const MIRFunction * func)234 static MIRStructType *GetReadedStructureType(const DreadNode *dread, const MIRFunction *func)
235 {
236     const auto &rhsStIdx = dread->GetStIdx();
237     auto rhsSymbol = func->GetLocalOrGlobalSymbol(rhsStIdx);
238     auto rhsAggType = rhsSymbol->GetType();
239     auto rhsFieldID = dread->GetFieldID();
240     if (rhsFieldID != 0) {
241         CHECK_FATAL(rhsAggType->IsStructType(), "only struct has non-zero fieldID");
242         rhsAggType = static_cast<MIRStructType *>(rhsAggType)->GetFieldType(rhsFieldID);
243     }
244     if (!rhsAggType->IsStructType()) {
245         return nullptr;
246     }
247     return static_cast<MIRStructType *>(rhsAggType);
248 }
249 
GetReadedStructureType(const IreadNode * iread,const MIRFunction *)250 static MIRStructType *GetReadedStructureType(const IreadNode *iread, const MIRFunction *)
251 {
252     auto rhsPtrType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx());
253     CHECK_FATAL(rhsPtrType->IsMIRPtrType(), "must be pointer type");
254     auto rhsAggType = static_cast<MIRPtrType *>(rhsPtrType)->GetPointedType();
255     auto rhsFieldID = iread->GetFieldID();
256     if (rhsFieldID != 0) {
257         CHECK_FATAL(rhsAggType->IsStructType(), "only struct has non-zero fieldID");
258         rhsAggType = static_cast<MIRStructType *>(rhsAggType)->GetFieldType(rhsFieldID);
259     }
260     if (!rhsAggType->IsStructType()) {
261         return nullptr;
262     }
263     return static_cast<MIRStructType *>(rhsAggType);
264 }
265 
266 template <class RhsType, class AssignType>
SplitAggCopy(const AssignType * assignNode,MIRStructType * structureType,BlockNode * block,MIRFunction * func)267 static StmtNode *SplitAggCopy(const AssignType *assignNode, MIRStructType *structureType, BlockNode *block,
268                               MIRFunction *func)
269 {
270     auto *readNode = static_cast<RhsType *>(assignNode->GetRHS());
271     auto rhsFieldID = readNode->GetFieldID();
272     auto *rhsAggType = GetReadedStructureType(readNode, func);
273     if (structureType != rhsAggType) {
274         return nullptr;
275     }
276 
277     for (FieldID id = 1; id <= static_cast<FieldID>(structureType->NumberOfFieldIDs()); ++id) {
278         MIRType *fieldType = structureType->GetFieldType(id);
279         if (fieldType->GetSize() == 0) {
280             continue;  // field size is zero for empty struct/union;
281         }
282         if (fieldType->GetKind() == kTypeBitField && static_cast<MIRBitFieldType *>(fieldType)->GetFieldSize() == 0) {
283             continue;  // bitfield size is zero
284         }
285         auto *newDassign = assignNode->CloneTree(func->GetCodeMemPoolAllocator());
286         newDassign->SetFieldID(assignNode->GetFieldID() + id);
287         auto *newRHS = static_cast<RhsType *>(newDassign->GetRHS());
288         newRHS->SetFieldID(rhsFieldID + id);
289         newRHS->SetPrimType(fieldType->GetPrimType());
290         block->InsertAfter(assignNode, newDassign);
291         if (fieldType->IsMIRUnionType()) {
292             id += fieldType->NumberOfFieldIDs();
293         }
294     }
295     auto newAssign = assignNode->GetNext();
296     block->RemoveStmt(assignNode);
297     return newAssign;
298 }
299 
SplitDassignAggCopy(DassignNode * dassign,BlockNode * block,MIRFunction * func)300 static StmtNode *SplitDassignAggCopy(DassignNode *dassign, BlockNode *block, MIRFunction *func)
301 {
302     auto *rhs = dassign->GetRHS();
303     if (rhs->GetPrimType() != PTY_agg) {
304         return nullptr;
305     }
306 
307     auto *lhsAggType = GetDassignedStructType(dassign, func);
308     if (lhsAggType == nullptr) {
309         return nullptr;
310     }
311 
312     if (rhs->GetOpCode() == OP_dread) {
313         auto *lhsSymbol = func->GetLocalOrGlobalSymbol(dassign->GetStIdx());
314         auto *rhsSymbol = func->GetLocalOrGlobalSymbol(static_cast<DreadNode *>(rhs)->GetStIdx());
315         if (!lhsSymbol->IsLocal() && !rhsSymbol->IsLocal()) {
316             return nullptr;
317         }
318 
319         return SplitAggCopy<DreadNode>(dassign, lhsAggType, block, func);
320     } else if (rhs->GetOpCode() == OP_iread) {
321         return SplitAggCopy<IreadNode>(dassign, lhsAggType, block, func);
322     }
323     return nullptr;
324 }
325 
SplitIassignAggCopy(IassignNode * iassign,BlockNode * block,MIRFunction * func)326 static StmtNode *SplitIassignAggCopy(IassignNode *iassign, BlockNode *block, MIRFunction *func)
327 {
328     auto rhs = iassign->GetRHS();
329     if (rhs->GetPrimType() != PTY_agg) {
330         return nullptr;
331     }
332 
333     auto *lhsAggType = GetIassignedStructType(iassign);
334     if (lhsAggType == nullptr) {
335         return nullptr;
336     }
337 
338     if (rhs->GetOpCode() == OP_dread) {
339         return SplitAggCopy<DreadNode>(iassign, lhsAggType, block, func);
340     } else if (rhs->GetOpCode() == OP_iread) {
341         return SplitAggCopy<IreadNode>(iassign, lhsAggType, block, func);
342     }
343     return nullptr;
344 }
345 
UseGlobalVar(const BaseNode * expr)346 bool UseGlobalVar(const BaseNode *expr)
347 {
348     if (expr->GetOpCode() == OP_addrof || expr->GetOpCode() == OP_dread) {
349         StIdx stIdx = static_cast<const AddrofNode *>(expr)->GetStIdx();
350         if (stIdx.IsGlobal()) {
351             return true;
352         }
353     }
354     for (size_t i = 0; i < expr->GetNumOpnds(); ++i) {
355         if (UseGlobalVar(expr->Opnd(i))) {
356             return true;
357         }
358     }
359     return false;
360 }
361 
SimplifyToSelect(MIRFunction * func,IfStmtNode * ifNode,BlockNode * block)362 StmtNode *Simplify::SimplifyToSelect(MIRFunction *func, IfStmtNode *ifNode, BlockNode *block)
363 {
364     // Example: if (condition) {
365     //   Example: res = trueRes
366     // Example: }
367     // Example: else {
368     //   Example: res = falseRes
369     // Example: }
370     // =================
371     // res = select condition ? trueRes : falseRes
372     if (ifNode->GetPrev() != nullptr && ifNode->GetPrev()->GetOpCode() == OP_label) {
373         // simplify shortCircuit will stop opt in cfg_opt, and generate extra compare
374         auto *labelNode = static_cast<LabelNode *>(ifNode->GetPrev());
375         const std::string &labelName = func->GetLabelTabItem(labelNode->GetLabelIdx());
376         if (labelName.find("shortCircuit") != std::string::npos) {
377             return nullptr;
378         }
379     }
380     if (ifNode->GetThenPart() == nullptr || ifNode->GetElsePart() == nullptr) {
381         return nullptr;
382     }
383     StmtNode *thenFirst = ifNode->GetThenPart()->GetFirst();
384     StmtNode *elseFirst = ifNode->GetElsePart()->GetFirst();
385     if (thenFirst == nullptr || elseFirst == nullptr) {
386         return nullptr;
387     }
388     // thenpart and elsepart has only one stmt
389     if (thenFirst->GetNext() != nullptr || elseFirst->GetNext() != nullptr) {
390         return nullptr;
391     }
392     if (thenFirst->GetOpCode() != OP_dassign || elseFirst->GetOpCode() != OP_dassign) {
393         return nullptr;
394     }
395     auto *thenDass = static_cast<DassignNode *>(thenFirst);
396     auto *elseDass = static_cast<DassignNode *>(elseFirst);
397     if (thenDass->GetStIdx() != elseDass->GetStIdx() || thenDass->GetFieldID() != elseDass->GetFieldID()) {
398         return nullptr;
399     }
400     // iread has sideeffect : may cause deref error
401     if (HasIreadExpr(thenDass->GetRHS()) || HasIreadExpr(elseDass->GetRHS())) {
402         return nullptr;
403     }
404     // Check if the operand of the select node is complex enough
405     // we should not simplify it to if-then-else for either functionality or performance reason
406     if (thenDass->GetRHS()->GetPrimType() == PTY_agg || elseDass->GetRHS()->GetPrimType() == PTY_agg) {
407         return nullptr;
408     }
409     constexpr size_t maxDepth = 3;
410     if (MaxDepth(thenDass->GetRHS()) > maxDepth || MaxDepth(elseDass->GetRHS()) > maxDepth) {
411         return nullptr;
412     }
413     if (UseGlobalVar(thenDass->GetRHS()) || UseGlobalVar(elseDass->GetRHS())) {
414         return nullptr;
415     }
416     MIRBuilder *mirBuiler = func->GetModule()->GetMIRBuilder();
417     MIRType *type = GlobalTables::GetTypeTable().GetPrimType(thenDass->GetRHS()->GetPrimType());
418     auto *selectExpr =
419         mirBuiler->CreateExprTernary(OP_select, *type, ifNode->Opnd(0), thenDass->GetRHS(), elseDass->GetRHS());
420     auto *newDassign = mirBuiler->CreateStmtDassign(thenDass->GetStIdx(), thenDass->GetFieldID(), selectExpr);
421     newDassign->SetSrcPos(ifNode->GetSrcPos());
422     block->InsertBefore(ifNode, newDassign);
423     block->RemoveStmt(ifNode);
424     return newDassign;
425 }
426 
ProcessStmt(StmtNode & stmt)427 void Simplify::ProcessStmt(StmtNode &stmt)
428 {
429     switch (stmt.GetOpCode()) {
430         case OP_callassigned: {
431             SimplifyCallAssigned(stmt, *currBlock);
432             break;
433         }
434         case OP_intrinsiccall: {
435             simplifyMemOp.SetDebug(dump);
436             simplifyMemOp.SetFunction(currFunc);
437             (void)simplifyMemOp.AutoSimplify(stmt, *currBlock, false);
438             break;
439         }
440         case OP_dassign: {
441             auto *newStmt = SplitDassignAggCopy(static_cast<DassignNode *>(&stmt), currBlock, currFunc);
442             if (newStmt) {
443                 ProcessBlock(*newStmt);
444             }
445             break;
446         }
447         case OP_iassign: {
448             auto *newStmt = SplitIassignAggCopy(static_cast<IassignNode *>(&stmt), currBlock, currFunc);
449             if (newStmt) {
450                 ProcessBlock(*newStmt);
451             }
452             break;
453         }
454         case OP_if:
455         case OP_while:
456         case OP_dowhile: {
457             auto unaryStmt = static_cast<UnaryStmtNode &>(stmt);
458             unaryStmt.SetRHS(SimplifyExpr(*unaryStmt.GetRHS()));
459             return;
460         }
461         default: {
462             break;
463         }
464     }
465     for (size_t i = 0; i < stmt.NumOpnds(); ++i) {
466         if (stmt.Opnd(i)) {
467             stmt.SetOpnd(SimplifyExpr(*stmt.Opnd(i)), i);
468         }
469     }
470 }
471 
SimplifyExpr(BaseNode & expr)472 BaseNode *Simplify::SimplifyExpr(BaseNode &expr)
473 {
474     switch (expr.GetOpCode()) {
475         case OP_dread: {
476             auto &dread = static_cast<DreadNode &>(expr);
477             return ReplaceExprWithConst(dread);
478         }
479         default: {
480             for (auto i = 0; i < expr.GetNumOpnds(); i++) {
481                 if (expr.Opnd(i)) {
482                     expr.SetOpnd(SimplifyExpr(*expr.Opnd(i)), i);
483                 }
484             }
485             break;
486         }
487     }
488     return &expr;
489 }
490 
ReplaceExprWithConst(DreadNode & dread)491 BaseNode *Simplify::ReplaceExprWithConst(DreadNode &dread)
492 {
493     auto stIdx = dread.GetStIdx();
494     auto fieldId = dread.GetFieldID();
495     auto *symbol = currFunc->GetLocalOrGlobalSymbol(stIdx);
496     auto *symbolConst = symbol->GetKonst();
497     if (!currFunc->GetModule()->IsCModule() || !symbolConst || !stIdx.IsGlobal() ||
498         !IsSymbolReplaceableWithConst(*symbol)) {
499         return &dread;
500     }
501     if (fieldId != 0) {
502         symbolConst = GetElementConstFromFieldId(fieldId, symbolConst);
503     }
504     if (!symbolConst || !IsConstRepalceable(*symbolConst)) {
505         return &dread;
506     }
507     return currFunc->GetModule()->GetMIRBuilder()->CreateConstval(symbolConst);
508 }
509 
IsSymbolReplaceableWithConst(const MIRSymbol & symbol) const510 bool Simplify::IsSymbolReplaceableWithConst(const MIRSymbol &symbol) const
511 {
512     return (symbol.GetStorageClass() == kScFstatic && !symbol.HasPotentialAssignment()) ||
513            symbol.GetAttrs().GetAttr(ATTR_const);
514 }
515 
IsConstRepalceable(const MIRConst & mirConst) const516 bool Simplify::IsConstRepalceable(const MIRConst &mirConst) const
517 {
518     switch (mirConst.GetKind()) {
519         case kConstInt:
520         case kConstFloatConst:
521         case kConstDoubleConst:
522         case kConstFloat128Const:
523         case kConstLblConst:
524             return true;
525         default:
526             return false;
527     }
528 }
529 
GetElementConstFromFieldId(FieldID fieldId,MIRConst * mirConst)530 MIRConst *Simplify::GetElementConstFromFieldId(FieldID fieldId, MIRConst *mirConst)
531 {
532     FieldID currFieldId = 1;
533     MIRConst *resultConst = nullptr;
534     auto originAggConst = static_cast<MIRAggConst *>(mirConst);
535     auto originAggType = static_cast<MIRStructType &>(originAggConst->GetType());
536     bool hasReached = false;
537     std::function<void(MIRConst *)> traverseAgg = [&](MIRConst *currConst) {
538         auto *currAggConst = safe_cast<MIRAggConst>(currConst);
539         ASSERT_NOT_NULL(currAggConst);
540         auto *currAggType = safe_cast<MIRStructType>(currAggConst->GetType());
541         ASSERT_NOT_NULL(currAggType);
542         for (size_t iter = 0; iter < currAggType->GetFieldsSize() && !hasReached; ++iter) {
543             size_t constIdx = currAggType->GetKind() == kTypeUnion ? 1 : iter + 1;
544             auto *fieldConst = currAggConst->GetAggConstElement(constIdx);
545             auto *fieldType = originAggType.GetFieldType(currFieldId);
546 
547             if (currFieldId == fieldId) {
548                 if (auto *truncCst = TruncateUnionConstant(*currAggType, fieldConst, *fieldType)) {
549                     resultConst = truncCst;
550                 } else {
551                     resultConst = fieldConst;
552                 }
553 
554                 hasReached = true;
555                 return;
556             }
557 
558             ++currFieldId;
559             if (fieldType->GetKind() == kTypeUnion || fieldType->GetKind() == kTypeStruct) {
560                 traverseAgg(fieldConst);
561             }
562         }
563     };
564     traverseAgg(mirConst);
565     CHECK_FATAL(hasReached, "const not found");
566     return resultConst;
567 }
568 
Finish()569 void Simplify::Finish() {}
570 
571 // Join `num` `byte`s into a number
572 // Example:
573 //   byte   num                output
574 //   0x0a    2                 0x0a0a
575 //   0x12    4             0x12121212
576 //   0xff    8     0xffffffffffffffff
JoinBytes(int byte,uint32 num)577 static uint64 JoinBytes(int byte, uint32 num)
578 {
579     CHECK_FATAL(num <= 8, "not support"); // just support num less or equal 8, see comment above
580     uint64 realByte = static_cast<uint64>(byte % 256);
581     if (realByte == 0) {
582         return 0;
583     }
584     uint64 result = 0;
585     for (uint32 i = 0; i < num; ++i) {
586         result += (realByte << (i * k8BitSize));
587     }
588     return result;
589 }
590 
591 // Return Fold result expr, does not always return a constant expr
592 // Attention: Fold may modify the input expr, if foldExpr is not a nullptr, we should always replace expr with foldExpr
FoldIntConst(BaseNode * expr,uint64 & out,bool & isIntConst)593 static BaseNode *FoldIntConst(BaseNode *expr, uint64 &out, bool &isIntConst)
594 {
595     if (expr->GetOpCode() == OP_constval) {
596         MIRConst *mirConst = static_cast<ConstvalNode *>(expr)->GetConstVal();
597         if (mirConst->GetKind() == kConstInt) {
598             out = static_cast<MIRIntConst *>(mirConst)->GetExtValue();
599             isIntConst = true;
600         }
601         return nullptr;
602     }
603     BaseNode *foldExpr = nullptr;
604     static ConstantFold cf(*theMIRModule);
605     foldExpr = cf.Fold(expr);
606     if (foldExpr != nullptr && foldExpr->GetOpCode() == OP_constval) {
607         MIRConst *mirConst = static_cast<ConstvalNode *>(foldExpr)->GetConstVal();
608         if (mirConst->GetKind() == kConstInt) {
609             out = static_cast<MIRIntConst *>(mirConst)->GetExtValue();
610             isIntConst = true;
611         }
612     }
613     return foldExpr;
614 }
615 
ConstructConstvalNode(uint64 val,PrimType primType,MIRBuilder & mirBuilder)616 static BaseNode *ConstructConstvalNode(uint64 val, PrimType primType, MIRBuilder &mirBuilder)
617 {
618     PrimType constPrimType = primType;
619     if (IsPrimitiveFloat(primType)) {
620         constPrimType = GetIntegerPrimTypeBySizeAndSign(GetPrimTypeBitSize(primType), false);
621     }
622     MIRType *constType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(constPrimType));
623     MIRConst *mirConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(val, *constType);
624     BaseNode *ret = mirBuilder.CreateConstval(mirConst);
625     if (IsPrimitiveFloat(primType)) {
626         MIRType *floatType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(primType));
627         ret = mirBuilder.CreateExprRetype(*floatType, constPrimType, ret);
628     }
629     return ret;
630 }
631 
ConstructConstvalNode(int64 byte,uint64 num,PrimType primType,MIRBuilder & mirBuilder)632 static BaseNode *ConstructConstvalNode(int64 byte, uint64 num, PrimType primType, MIRBuilder &mirBuilder)
633 {
634     auto val = JoinBytes(byte, static_cast<uint32>(num));
635     return ConstructConstvalNode(val, primType, mirBuilder);
636 }
637 
638 // Input total size of memory, split the memory into several blocks, the max block size is 8 bytes
639 // Example:
640 //   input        output
641 //     40     [ 8, 8, 8, 8, 8 ]
642 //     31     [ 8, 8, 8, 4, 2, 1 ]
SplitMemoryIntoBlocks(size_t totalMemorySize,std::vector<uint32> & blocks)643 static void SplitMemoryIntoBlocks(size_t totalMemorySize, std::vector<uint32> &blocks)
644 {
645     size_t leftSize = totalMemorySize;
646     size_t curBlockSize = kMaxMemoryBlockSizeToAssign;  // max block size in byte
647     while (curBlockSize > 0) {
648         size_t n = leftSize / curBlockSize;
649         blocks.insert(blocks.end(), n, curBlockSize);
650         leftSize -= (n * curBlockSize);
651         curBlockSize = curBlockSize >> 1;
652     }
653 }
654 
IsComplexExpr(const BaseNode * expr,MIRFunction & func)655 static bool IsComplexExpr(const BaseNode *expr, MIRFunction &func)
656 {
657     Opcode op = expr->GetOpCode();
658     if (op == OP_regread) {
659         return false;
660     }
661     if (op == OP_dread) {
662         auto *symbol = func.GetLocalOrGlobalSymbol(static_cast<const DreadNode *>(expr)->GetStIdx());
663         if (symbol->IsGlobal() || symbol->GetStorageClass() == kScPstatic) {
664             return true;  // dread global/static var is complex expr because it will be lowered to adrp + add
665         } else {
666             return false;
667         }
668     }
669     if (op == OP_addrof) {
670         auto *symbol = func.GetLocalOrGlobalSymbol(static_cast<const AddrofNode *>(expr)->GetStIdx());
671         if (symbol->IsGlobal() || symbol->GetStorageClass() == kScPstatic) {
672             return true;  // addrof global/static var is complex expr because it will be lowered to adrp + add
673         } else {
674             return false;
675         }
676     }
677     return true;
678 }
679 
680 // Input a address expr, output a memEntry to abstract this expr
ComputeMemEntry(BaseNode & expr,MIRFunction & func,MemEntry & memEntry,bool isLowLevel)681 bool MemEntry::ComputeMemEntry(BaseNode &expr, MIRFunction &func, MemEntry &memEntry, bool isLowLevel)
682 {
683     Opcode op = expr.GetOpCode();
684     MIRType *memType = nullptr;
685     switch (op) {
686         case OP_dread: {
687             const auto &concreteExpr = static_cast<const DreadNode &>(expr);
688             auto *symbol = func.GetLocalOrGlobalSymbol(concreteExpr.GetStIdx());
689             MIRType *curType = symbol->GetType();
690             if (concreteExpr.GetFieldID() != 0) {
691                 curType = static_cast<MIRStructType *>(curType)->GetFieldType(concreteExpr.GetFieldID());
692             }
693             // Support kTypeScalar ptr if possible
694             if (curType->GetKind() == kTypePointer) {
695                 memType = static_cast<MIRPtrType *>(curType)->GetPointedType();
696             }
697             break;
698         }
699         case OP_addrof: {
700             const auto &concreteExpr = static_cast<const AddrofNode &>(expr);
701             auto *symbol = func.GetLocalOrGlobalSymbol(concreteExpr.GetStIdx());
702             MIRType *curType = symbol->GetType();
703             if (concreteExpr.GetFieldID() != 0) {
704                 curType = static_cast<MIRStructType *>(curType)->GetFieldType(concreteExpr.GetFieldID());
705             }
706             memType = curType;
707             break;
708         }
709         case OP_iread: {
710             const auto &concreteExpr = static_cast<const IreadNode &>(expr);
711             memType = concreteExpr.GetType();
712             break;
713         }
714         case OP_iaddrof: {  // Do NOT call GetType because it is for OP_iread
715             const auto &concreteExpr = static_cast<const IaddrofNode &>(expr);
716             MIRType *curType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(concreteExpr.GetTyIdx());
717             CHECK_FATAL(curType->IsMIRPtrType(), "must be MIRPtrType");
718             curType = static_cast<MIRPtrType *>(curType)->GetPointedType();
719             CHECK_FATAL(curType->IsStructType(), "must be MIRStructType");
720             memType = static_cast<MIRStructType *>(curType)->GetFieldType(concreteExpr.GetFieldID());
721             break;
722         }
723         case OP_regread: {
724             if (isLowLevel && IsPrimitivePoint(expr.GetPrimType())) {
725                 memEntry.addrExpr = &expr;
726                 memEntry.memType =
727                     nullptr;  // we cannot infer high level memory type, this is allowed for low level expand
728                 return true;
729             }
730             const auto &concreteExpr = static_cast<const RegreadNode &>(expr);
731             MIRPreg *preg = func.GetPregItem(concreteExpr.GetRegIdx());
732             bool isFromDread = (preg->GetOp() == OP_dread);
733             bool isFromAddrof = (preg->GetOp() == OP_addrof);
734             if (isFromDread || isFromAddrof) {
735                 auto *symbol = preg->rematInfo.sym;
736                 auto fieldId = preg->fieldID;
737                 MIRType *curType = symbol->GetType();
738                 if (fieldId != 0) {
739                     curType = static_cast<MIRStructType *>(symbol->GetType())->GetFieldType(fieldId);
740                 }
741                 if (isFromDread && curType->GetKind() == kTypePointer) {
742                     curType = static_cast<MIRPtrType *>(curType)->GetPointedType();
743                 }
744                 memType = curType;
745             }
746             break;
747         }
748         default: {
749             if (isLowLevel && IsPrimitivePoint(expr.GetPrimType())) {
750                 memEntry.addrExpr = &expr;
751                 memEntry.memType =
752                     nullptr;  // we cannot infer high level memory type, this is allowed for low level expand
753                 return true;
754             }
755             break;
756         }
757     }
758     if (memType == nullptr) {
759         return false;
760     }
761     memEntry.addrExpr = &expr;
762     memEntry.memType = memType;
763     return true;
764 }
765 
BuildAsRhsExpr(MIRFunction & func) const766 BaseNode *MemEntry::BuildAsRhsExpr(MIRFunction &func) const
767 {
768     BaseNode *expr = nullptr;
769     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
770     if (addrExpr->GetOpCode() == OP_addrof) {
771         // We prefer dread to iread
772         // consider iaddrof if possible
773         auto *addrof = static_cast<AddrofNode *>(addrExpr);
774         auto *symbol = func.GetLocalOrGlobalSymbol(addrof->GetStIdx());
775         expr = mirBuilder->CreateExprDread(*memType, addrof->GetFieldID(), *symbol);
776     } else {
777         MIRType *structPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*memType);
778         expr = mirBuilder->CreateExprIread(*memType, *structPtrType, 0, addrExpr);
779     }
780     return expr;
781 }
782 
InsertAndMayPrintStmt(BlockNode & block,const StmtNode & anchor,bool debug,StmtNode * stmt)783 static void InsertAndMayPrintStmt(BlockNode &block, const StmtNode &anchor, bool debug, StmtNode *stmt)
784 {
785     if (stmt == nullptr) {
786         return;
787     }
788     block.InsertBefore(&anchor, stmt);
789     if (debug) {
790         stmt->Dump(0);
791     }
792 }
793 
InsertBeforeAndMayPrintStmtList(BlockNode & block,const StmtNode & anchor,bool debug,std::initializer_list<StmtNode * > stmtList)794 static void InsertBeforeAndMayPrintStmtList(BlockNode &block, const StmtNode &anchor, bool debug,
795                                             std::initializer_list<StmtNode *> stmtList)
796 {
797     for (StmtNode *stmt : stmtList) {
798         if (stmt == nullptr) {
799             continue;
800         }
801         block.InsertBefore(&anchor, stmt);
802         if (debug) {
803             stmt->Dump(0);
804         }
805     }
806 }
807 
NeedCheck(MemOpKind memOpKind)808 static bool NeedCheck(MemOpKind memOpKind)
809 {
810     if (memOpKind == MEM_OP_memset_s || memOpKind == MEM_OP_memcpy_s) {
811         return true;
812     }
813     return false;
814 }
815 
816 // Create maple IR to check whether `expr` is a null pointer, IR is as follows:
817 //   brfalse @@n1 (ne u8 ptr (regread ptr %1, constval u64 0))
CreateNullptrCheckStmt(BaseNode & expr,MIRFunction & func,MIRBuilder * mirBuilder,const MIRType & cmpResType,const MIRType & cmpOpndType)818 static CondGotoNode *CreateNullptrCheckStmt(BaseNode &expr, MIRFunction &func, MIRBuilder *mirBuilder,
819                                             const MIRType &cmpResType, const MIRType &cmpOpndType)
820 {
821     LabelIdx nullLabIdx = func.GetLabelTab()->CreateLabelWithPrefix('n');  // 'n' means nullptr
822     auto *checkExpr = mirBuilder->CreateExprCompare(OP_ne, cmpResType, cmpOpndType, &expr,
823                                                     ConstructConstvalNode(0, PTY_u64, *mirBuilder));
824     auto *checkStmt = mirBuilder->CreateStmtCondGoto(checkExpr, OP_brfalse, nullLabIdx);
825     return checkStmt;
826 }
827 
828 // Create maple IR to check whether `expr1` and `expr2` are equal
829 // brfalse @@a1 (ne u8 ptr (regread ptr %1, regread ptr %2))
CreateAddressEqualCheckStmt(BaseNode & expr1,BaseNode & expr2,MIRFunction & func,MIRBuilder * mirBuilder,const MIRType & cmpResType,const MIRType & cmpOpndType)830 static CondGotoNode *CreateAddressEqualCheckStmt(BaseNode &expr1, BaseNode &expr2, MIRFunction &func,
831                                                  MIRBuilder *mirBuilder, const MIRType &cmpResType,
832                                                  const MIRType &cmpOpndType)
833 {
834     LabelIdx equalLabIdx = func.GetLabelTab()->CreateLabelWithPrefix('a');  // 'a' means address equal
835     auto *checkExpr = mirBuilder->CreateExprCompare(OP_ne, cmpResType, cmpOpndType, &expr1, &expr2);
836     auto *checkStmt = mirBuilder->CreateStmtCondGoto(checkExpr, OP_brfalse, equalLabIdx);
837     return checkStmt;
838 }
839 
840 // Create maple IR to check whether `expr1` and `expr2` are overlapped
841 // brfalse @@o1 (ge u8 ptr (
842 //   abs ptr (sub ptr (regread ptr %1, regread ptr %2)),
843 //   constval u64 xxx))
CreateOverlapCheckStmt(BaseNode & expr1,BaseNode & expr2,uint32 size,MIRFunction & func,MIRBuilder * mirBuilder,const MIRType & cmpResType,const MIRType & cmpOpndType)844 static CondGotoNode *CreateOverlapCheckStmt(BaseNode &expr1, BaseNode &expr2, uint32 size, MIRFunction &func,
845                                             MIRBuilder *mirBuilder, const MIRType &cmpResType,
846                                             const MIRType &cmpOpndType)
847 {
848     LabelIdx overlapLabIdx = func.GetLabelTab()->CreateLabelWithPrefix('o');  // 'n' means overlap
849     auto *checkExpr = mirBuilder->CreateExprCompare(
850         OP_ge, cmpResType, cmpOpndType,
851         mirBuilder->CreateExprUnary(OP_abs, cmpOpndType,
852                                     mirBuilder->CreateExprBinary(OP_sub, cmpOpndType, &expr1, &expr2)),
853         ConstructConstvalNode(size, PTY_u64, *mirBuilder));
854     auto *checkStmt = mirBuilder->CreateStmtCondGoto(checkExpr, OP_brfalse, overlapLabIdx);
855     return checkStmt;
856 }
857 
858 // Generate IR to handle nullptr, IR is as follows:
859 //   @curLabel
860 //   regassign i32 %1 (constval i32 errNum)
861 //   goto @finalLabel
AddNullptrHandlerIR(const StmtNode & stmt,MIRBuilder * mirBuilder,BlockNode & block,StmtNode * retAssign,LabelIdx curLabIdx,LabelIdx finalLabIdx,bool debug)862 static void AddNullptrHandlerIR(const StmtNode &stmt, MIRBuilder *mirBuilder, BlockNode &block, StmtNode *retAssign,
863                                 LabelIdx curLabIdx, LabelIdx finalLabIdx, bool debug)
864 {
865     auto *curLabelNode = mirBuilder->CreateStmtLabel(curLabIdx);
866     auto *gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
867     InsertBeforeAndMayPrintStmtList(block, stmt, debug, {curLabelNode, retAssign, gotoFinal});
868 }
869 
AddMemsetCallStmt(const StmtNode & stmt,MIRFunction & func,BlockNode & block,BaseNode * addrExpr)870 static void AddMemsetCallStmt(const StmtNode &stmt, MIRFunction &func, BlockNode &block, BaseNode *addrExpr)
871 {
872     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
873     // call memset for dst memory when detecting overlapping
874     MapleVector<BaseNode *> args(mirBuilder->GetCurrentFuncCodeMpAllocator()->Adapter());
875     args.push_back(addrExpr);
876     args.push_back(ConstructConstvalNode(0, PTY_i32, *mirBuilder));
877     args.push_back(stmt.Opnd(1));
878     auto *callMemset = mirBuilder->CreateStmtCall("memset", args);
879     block.InsertBefore(&stmt, callMemset);
880 }
881 
882 // Generate IR to handle errors that should be reset with memset, IR is as follows:
883 //   @curLabel
884 //   regassign i32 %1 (constval i32 errNum)
885 //   call memset  # new genrated memset will be expanded if possible
886 //   goto @finalLabel
AddResetHandlerIR(const StmtNode & stmt,MIRFunction & func,BlockNode & block,StmtNode * retAssign,LabelIdx curLabIdx,LabelIdx finalLabIdx,BaseNode * addrExpr,bool debug)887 static void AddResetHandlerIR(const StmtNode &stmt, MIRFunction &func, BlockNode &block, StmtNode *retAssign,
888                               LabelIdx curLabIdx, LabelIdx finalLabIdx, BaseNode *addrExpr, bool debug)
889 {
890     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
891     auto *curLabelNode = mirBuilder->CreateStmtLabel(curLabIdx);
892     InsertBeforeAndMayPrintStmtList(block, stmt, debug, {curLabelNode, retAssign});
893     AddMemsetCallStmt(stmt, func, block, addrExpr);
894     auto *gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
895     InsertAndMayPrintStmt(block, stmt, debug, gotoFinal);
896 }
897 
TryToExtractComplexExpr(BaseNode * expr,MIRFunction & func,BlockNode & block,const StmtNode & anchor,bool debug)898 static BaseNode *TryToExtractComplexExpr(BaseNode *expr, MIRFunction &func, BlockNode &block, const StmtNode &anchor,
899                                          bool debug)
900 {
901     if (!IsComplexExpr(expr, func)) {
902         return expr;
903     }
904     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
905     auto pregIdx = func.GetPregTab()->CreatePreg(PTY_ptr);
906     StmtNode *regassign = mirBuilder->CreateStmtRegassign(PTY_ptr, pregIdx, expr);
907     InsertAndMayPrintStmt(block, anchor, debug, regassign);
908     auto *extractedExpr = mirBuilder->CreateExprRegread(PTY_ptr, pregIdx);
909     return extractedExpr;
910 }
911 
InsertCheckFailedBranch(MIRFunction & func,StmtNode & stmt,BlockNode & block,LabelIdx branchLabIdx,LabelIdx finalLabIdx,ErrorNumber errNumber,MemOpKind memOpKind,bool debug)912 static void InsertCheckFailedBranch(MIRFunction &func, StmtNode &stmt, BlockNode &block, LabelIdx branchLabIdx,
913                                     LabelIdx finalLabIdx, ErrorNumber errNumber, MemOpKind memOpKind, bool debug)
914 {
915     auto mirBuilder = func.GetModule()->GetMIRBuilder();
916     auto gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
917     auto branchLabNode = mirBuilder->CreateStmtLabel(branchLabIdx);
918     auto errnoAssign = MemEntry::GenMemopRetAssign(stmt, func, true, memOpKind, errNumber);
919     InsertBeforeAndMayPrintStmtList(block, stmt, debug, {branchLabNode, errnoAssign, gotoFinal});
920 }
921 
InsertMemsetCallStmt(const MapleVector<BaseNode * > & args,MIRFunction & func,StmtNode & stmt,BlockNode & block,LabelIdx finalLabIdx,ErrorNumber errorNumber,bool debug)922 static StmtNode *InsertMemsetCallStmt(const MapleVector<BaseNode *> &args, MIRFunction &func, StmtNode &stmt,
923                                       BlockNode &block, LabelIdx finalLabIdx, ErrorNumber errorNumber, bool debug)
924 {
925     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
926     auto *gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
927     auto memsetFunc = mirBuilder->GetOrCreateFunction(kFuncNameOfMemset, TyIdx(PTY_void));
928     auto memsetCallStmt = mirBuilder->CreateStmtCallAssigned(memsetFunc->GetPuidx(), args, nullptr, OP_callassigned);
929     memsetCallStmt->SetSrcPos(stmt.GetSrcPos());
930     auto *errnoAssign = MemEntry::GenMemopRetAssign(stmt, func, true, MEM_OP_memset_s, errorNumber);
931     InsertBeforeAndMayPrintStmtList(block, stmt, debug, {memsetCallStmt, errnoAssign, gotoFinal});
932     return memsetCallStmt;
933 }
934 
CreateAndInsertCheckStmt(Opcode op,BaseNode * lhs,BaseNode * rhs,LabelIdx label,StmtNode & stmt,BlockNode & block,MIRFunction & func,bool debug)935 static void CreateAndInsertCheckStmt(Opcode op, BaseNode *lhs, BaseNode *rhs, LabelIdx label, StmtNode &stmt,
936                                      BlockNode &block, MIRFunction &func, bool debug)
937 {
938     auto cmpResType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(PTY_u8));
939     auto cmpU64Type = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(PTY_u64));
940     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
941     auto cmpStmt = mirBuilder->CreateExprCompare(op, *cmpResType, *cmpU64Type, lhs, rhs);
942     auto checkStmt = mirBuilder->CreateStmtCondGoto(cmpStmt, OP_brtrue, label);
943     checkStmt->SetBranchProb(kProbUnlikely);
944     checkStmt->SetSrcPos(stmt.GetSrcPos());
945     InsertAndMayPrintStmt(block, stmt, debug, checkStmt);
946 }
947 
ExpandOnSrcSizeGtDstSize(StmtNode & stmt,BlockNode & block,int64 srcSize,LabelIdx finalLabIdx,LabelIdx nullPtrLabIdx,MIRFunction & func,bool debug)948 static StmtNode *ExpandOnSrcSizeGtDstSize(StmtNode &stmt, BlockNode &block, int64 srcSize, LabelIdx finalLabIdx,
949                                           LabelIdx nullPtrLabIdx, MIRFunction &func, bool debug)
950 {
951     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
952     MapleVector<BaseNode *> args(func.GetCodeMempoolAllocator().Adapter());
953     args.push_back(stmt.Opnd(kMemsetDstOpndIdx));
954     args.push_back(stmt.Opnd(kMemsetSSrcOpndIdx));
955     args.push_back(ConstructConstvalNode(srcSize, stmt.Opnd(kMemsetSSrcSizeOpndIdx)->GetPrimType(), *mirBuilder));
956     auto memsetFunc = mirBuilder->GetOrCreateFunction(kFuncNameOfMemset, TyIdx(PTY_void));
957     auto callStmt = mirBuilder->CreateStmtCallAssigned(memsetFunc->GetPuidx(), args, nullptr, OP_callassigned);
958     callStmt->SetSrcPos(stmt.GetSrcPos());
959     InsertAndMayPrintStmt(block, stmt, debug, callStmt);
960     auto gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
961     auto errnoAssign = MemEntry::GenMemopRetAssign(stmt, func, true, MEM_OP_memset_s, ERRNO_RANGE_AND_RESET);
962     InsertBeforeAndMayPrintStmtList(block, stmt, debug, {errnoAssign, gotoFinal});
963     InsertCheckFailedBranch(func, stmt, block, nullPtrLabIdx, finalLabIdx, ERRNO_INVAL, MEM_OP_memset_s, debug);
964     auto *finalLabelNode = mirBuilder->CreateStmtLabel(finalLabIdx);
965     InsertAndMayPrintStmt(block, stmt, debug, finalLabelNode);
966     block.RemoveStmt(&stmt);
967     return callStmt;
968 }
969 
HandleZeroValueOfDstSize(StmtNode & stmt,BlockNode & block,int64 srcSize,int64 dstSize,LabelIdx finalLabIdx,LabelIdx dstSizeCheckLabIdx,MIRFunction & func,bool isDstSizeConst,bool debug)970 static void HandleZeroValueOfDstSize(StmtNode &stmt, BlockNode &block, int64 srcSize, int64 dstSize,
971                                      LabelIdx finalLabIdx, LabelIdx dstSizeCheckLabIdx, MIRFunction &func,
972                                      bool isDstSizeConst, bool debug)
973 {
974     uint32 dstSizeOpndIdx = 1;  // only used by memset_s
975     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
976     if (!isDstSizeConst) {
977         CreateAndInsertCheckStmt(OP_eq, stmt.Opnd(dstSizeOpndIdx), ConstructConstvalNode(0, PTY_u64, *mirBuilder),
978                                  dstSizeCheckLabIdx, stmt, block, func, debug);
979     } else if (dstSize == 0) {
980         auto gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
981         auto errnoAssign = MemEntry::GenMemopRetAssign(stmt, func, true, MEM_OP_memset_s, ERRNO_RANGE);
982         InsertBeforeAndMayPrintStmtList(block, stmt, debug, {errnoAssign, gotoFinal});
983     }
984 }
985 
ExpandMemsetLowLevel(int64 byte,uint64 size,MIRFunction & func,StmtNode & stmt,BlockNode & block,MemOpKind memOpKind,bool debug,ErrorNumber errorNumber) const986 void MemEntry::ExpandMemsetLowLevel(int64 byte, uint64 size, MIRFunction &func, StmtNode &stmt, BlockNode &block,
987                                     MemOpKind memOpKind, bool debug, ErrorNumber errorNumber) const
988 {
989     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
990     std::vector<uint32> blocks;
991     SplitMemoryIntoBlocks(size, blocks);
992     int32 offset = 0;
993     // If blocks.size() > 1 and `dst` is not a leaf node,
994     // we should extract common expr to avoid redundant expression
995     BaseNode *realDstExpr = addrExpr;
996     if (blocks.size() > 1) {
997         realDstExpr = TryToExtractComplexExpr(addrExpr, func, block, stmt, debug);
998     }
999     BaseNode *readConst = nullptr;
1000     // rhs const is big, extract it to avoid redundant expression
1001     bool shouldExtractRhs = blocks.size() > 1 && (byte & 0xff) != 0;
1002     for (auto curSize : blocks) {
1003         // low level memset expand result:
1004         //   iassignoff <prim-type> <offset> (dstAddrExpr, constval <prim-type> xx)
1005         PrimType constType = GetIntegerPrimTypeBySizeAndSign(curSize * 8, false);
1006         BaseNode *rhsExpr = ConstructConstvalNode(byte, curSize, constType, *mirBuilder);
1007         if (shouldExtractRhs) {
1008             // we only need to extract u64 const once
1009             PregIdx pregIdx = func.GetPregTab()->CreatePreg(constType);
1010             auto *constAssign = mirBuilder->CreateStmtRegassign(constType, pregIdx, rhsExpr);
1011             InsertAndMayPrintStmt(block, stmt, debug, constAssign);
1012             readConst = mirBuilder->CreateExprRegread(constType, pregIdx);
1013             shouldExtractRhs = false;
1014         }
1015         if (readConst != nullptr && curSize == kMaxMemoryBlockSizeToAssign) {
1016             rhsExpr = readConst;
1017         }
1018         auto *iassignoff = mirBuilder->CreateStmtIassignoff(constType, offset, realDstExpr, rhsExpr);
1019         InsertAndMayPrintStmt(block, stmt, debug, iassignoff);
1020         if (debug) {
1021             iassignoff->Dump(0);
1022         }
1023         offset += static_cast<int32>(curSize);
1024     }
1025     // handle memset return val
1026     auto *retAssign = GenMemopRetAssign(stmt, func, true, memOpKind, errorNumber);
1027     InsertAndMayPrintStmt(block, stmt, debug, retAssign);
1028     // return ERRNO_INVAL if memset_s dest is NULL
1029     block.RemoveStmt(&stmt);
1030 }
1031 
1032 // Lower memset(MemEntry, byte, size) into a series of assign stmts and replace callStmt in the block
1033 // with these assign stmts
ExpandMemset(int64 byte,uint64 size,MIRFunction & func,StmtNode & stmt,BlockNode & block,bool isLowLevel,bool debug,ErrorNumber errorNumber) const1034 bool MemEntry::ExpandMemset(int64 byte, uint64 size, MIRFunction &func, StmtNode &stmt, BlockNode &block,
1035                             bool isLowLevel, bool debug, ErrorNumber errorNumber) const
1036 {
1037     MemOpKind memOpKind = SimplifyMemOp::ComputeMemOpKind(stmt);
1038     MemEntryKind memKind = GetKind();
1039     // we don't check size equality in the low level expand
1040     if (!isLowLevel) {
1041         if (memKind == kMemEntryUnknown) {
1042             MayPrintLog(debug, false, memOpKind, "unsupported dst memory type, is it a bitfield?");
1043             return false;
1044         }
1045         if (memType->GetSize() != size) {
1046             MayPrintLog(debug, false, memOpKind, "dst size and size arg are not equal");
1047             return false;
1048         }
1049     }
1050     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
1051 
1052     if (isLowLevel) {  // For cglower, replace memset with a series of low-level iassignoff
1053         ExpandMemsetLowLevel(byte, size, func, stmt, block, memOpKind, debug, errorNumber);
1054         return true;
1055     }
1056 
1057     if (memKind == kMemEntryPrimitive) {
1058         BaseNode *rhsExpr = ConstructConstvalNode(byte, size, memType->GetPrimType(), *mirBuilder);
1059         StmtNode *newAssign = nullptr;
1060         if (addrExpr->GetOpCode() == OP_addrof) {  // We prefer dassign to iassign
1061             auto *addrof = static_cast<AddrofNode *>(addrExpr);
1062             auto *symbol = func.GetLocalOrGlobalSymbol(addrof->GetStIdx());
1063             newAssign = mirBuilder->CreateStmtDassign(*symbol, addrof->GetFieldID(), rhsExpr);
1064         } else {
1065             MIRType *memPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*memType);
1066             newAssign = mirBuilder->CreateStmtIassign(*memPtrType, 0, addrExpr, rhsExpr);
1067         }
1068         InsertAndMayPrintStmt(block, stmt, debug, newAssign);
1069     } else if (memKind == kMemEntryStruct) {
1070         auto *structType = static_cast<MIRStructType *>(memType);
1071         // struct size should be small enough, struct field size should be big enough
1072         constexpr uint32 maxStructSize = 64;  // in byte
1073         constexpr uint32 minFieldSize = 4;    // in byte
1074         size_t structSize = structType->GetSize();
1075         size_t numFields = structType->NumberOfFieldIDs();
1076         // Relax restrictions when store-merge is powerful enough
1077         bool expandIt =
1078             (structSize <= maxStructSize && (structSize / numFields >= minFieldSize) && !structType->HasPadding());
1079         if (!expandIt) {
1080             // We only expand memset for no-padding struct, because only in this case, element-wise and byte-wise
1081             // are equivalent
1082             MayPrintLog(debug, false, memOpKind,
1083                         "struct type has padding, or struct sum size is too big, or filed size is too small");
1084             return false;
1085         }
1086         bool hasArrayField = false;
1087         for (uint32 id = 1; id <= numFields; ++id) {
1088             auto *fieldType = structType->GetFieldType(id);
1089             if (fieldType->GetKind() == kTypeArray) {
1090                 hasArrayField = true;
1091                 break;
1092             }
1093         }
1094         if (hasArrayField) {
1095             // struct with array fields is not supported to expand for now, enhance it when needed
1096             MayPrintLog(debug, false, memOpKind, "struct with array fields is not supported to expand");
1097             return false;
1098         }
1099 
1100         // Build assign for each fields in the struct type
1101         // We should skip union fields
1102         for (FieldID id = 1; static_cast<size_t>(id) <= numFields; ++id) {
1103             MIRType *fieldType = structType->GetFieldType(id);
1104             // We only consider leaf field with valid type size
1105             if (fieldType->GetSize() == 0 || fieldType->GetPrimType() == PTY_agg) {
1106                 continue;
1107             }
1108             if (fieldType->GetKind() == kTypeBitField &&
1109                 static_cast<MIRBitFieldType *>(fieldType)->GetFieldSize() == 0) {
1110                 continue;
1111             }
1112             // now the fieldType is primitive type
1113             BaseNode *rhsExpr =
1114                 ConstructConstvalNode(byte, fieldType->GetSize(), fieldType->GetPrimType(), *mirBuilder);
1115             StmtNode *fieldAssign = nullptr;
1116             if (addrExpr->GetOpCode() == OP_addrof) {
1117                 auto *addrof = static_cast<AddrofNode *>(addrExpr);
1118                 auto *symbol = func.GetLocalOrGlobalSymbol(addrof->GetStIdx());
1119                 fieldAssign = mirBuilder->CreateStmtDassign(*symbol, addrof->GetFieldID() + id, rhsExpr);
1120             } else {
1121                 MIRType *memPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*memType);
1122                 fieldAssign = mirBuilder->CreateStmtIassign(*memPtrType, id, addrExpr, rhsExpr);
1123             }
1124             InsertAndMayPrintStmt(block, stmt, debug, fieldAssign);
1125         }
1126     } else if (memKind == kMemEntryArray) {
1127         // We only consider array with dim == 1 now, and element type must be primitive type
1128         auto *arrayType = static_cast<MIRArrayType *>(memType);
1129         if (arrayType->GetDim() != 1 || (arrayType->GetElemType()->GetKind() != kTypeScalar &&
1130                                          arrayType->GetElemType()->GetKind() != kTypePointer)) {
1131             MayPrintLog(debug, false, memOpKind, "array dim != 1 or array elements are not primtive type");
1132             return false;
1133         }
1134         MIRType *elemType = arrayType->GetElemType();
1135         if (elemType->GetSize() < k4BitSize) {
1136             MayPrintLog(debug, false, memOpKind,
1137                         "element size < 4, don't expand it to  avoid to genearte lots of strb/strh");
1138             return false;
1139         }
1140         uint64 elemCnt = static_cast<uint64>(arrayType->GetSizeArrayItem(0));
1141         if (elemType->GetSize() * elemCnt != size) {
1142             MayPrintLog(debug, false, memOpKind, "array size not equal");
1143             return false;
1144         }
1145         for (size_t i = 0; i < elemCnt; ++i) {
1146             BaseNode *indexExpr = ConstructConstvalNode(i, PTY_u32, *mirBuilder);
1147             auto *arrayExpr = mirBuilder->CreateExprArray(*arrayType, addrExpr, indexExpr);
1148             auto *newValOpnd = ConstructConstvalNode(byte, elemType->GetSize(), elemType->GetPrimType(), *mirBuilder);
1149             MIRType *elemPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*elemType);
1150             auto *arrayElementAssign = mirBuilder->CreateStmtIassign(*elemPtrType, 0, arrayExpr, newValOpnd);
1151             InsertAndMayPrintStmt(block, stmt, debug, arrayElementAssign);
1152         }
1153     } else {
1154         CHECK_FATAL(false, "impossible");
1155     }
1156 
1157     // handle memset return val
1158     auto *retAssign = GenMemopRetAssign(stmt, func, isLowLevel, memOpKind, errorNumber);
1159     InsertAndMayPrintStmt(block, stmt, debug, retAssign);
1160     block.RemoveStmt(&stmt);
1161     return true;
1162 }
1163 
GenerateMemoryCopyPair(MIRBuilder * mirBuilder,BaseNode * rhs,BaseNode * lhs,uint32 offset,uint32 curSize,PregIdx tmpRegIdx)1164 static std::pair<StmtNode *, StmtNode *> GenerateMemoryCopyPair(MIRBuilder *mirBuilder, BaseNode *rhs, BaseNode *lhs,
1165                                                                 uint32 offset, uint32 curSize, PregIdx tmpRegIdx)
1166 {
1167     auto *ptrType = GlobalTables::GetTypeTable().GetPtrType();
1168     PrimType constType = GetIntegerPrimTypeBySizeAndSign(curSize * 8, false);
1169     MIRType *constMIRType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(constType));
1170     auto *constMIRPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*constMIRType);
1171     BaseNode *rhsAddrExpr = rhs;
1172     if (offset != 0) {
1173         auto *offsetConstExpr = ConstructConstvalNode(offset, PTY_u64, *mirBuilder);
1174         rhsAddrExpr = mirBuilder->CreateExprBinary(OP_add, *ptrType, rhs, offsetConstExpr);
1175     }
1176     BaseNode *rhsExpr = mirBuilder->CreateExprIread(*constMIRType, *constMIRPtrType, 0, rhsAddrExpr);
1177     auto *regassign = mirBuilder->CreateStmtRegassign(PTY_u64, tmpRegIdx, rhsExpr);
1178     auto *iassignoff = mirBuilder->CreateStmtIassignoff(constType, static_cast<int32>(offset), lhs,
1179                                                         mirBuilder->CreateExprRegread(PTY_u64, tmpRegIdx));
1180     return {regassign, iassignoff};
1181 }
1182 
ExpandMemcpyLowLevel(const MemEntry & srcMem,uint64 copySize,MIRFunction & func,StmtNode & stmt,BlockNode & block,MemOpKind memOpKind,bool debug,ErrorNumber errorNumber) const1183 void MemEntry::ExpandMemcpyLowLevel(const MemEntry &srcMem, uint64 copySize, MIRFunction &func, StmtNode &stmt,
1184                                     BlockNode &block, MemOpKind memOpKind, bool debug, ErrorNumber errorNumber) const
1185 {
1186     if (errorNumber == ERRNO_RANGE) {
1187         auto *retAssign = GenMemopRetAssign(stmt, func, true, memOpKind, errorNumber);
1188         InsertAndMayPrintStmt(block, stmt, debug, retAssign);
1189         block.RemoveStmt(&stmt);
1190         return;
1191     }
1192     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
1193     std::vector<uint32> blocks;
1194     SplitMemoryIntoBlocks(copySize, blocks);
1195     uint32 offset = 0;
1196     // If blocks.size() > 1 and `src` or `dst` is not a leaf node,
1197     // we should extract common expr to avoid redundant expression
1198     BaseNode *realSrcExpr = srcMem.addrExpr;
1199     BaseNode *realDstExpr = addrExpr;
1200     if (blocks.size() > 1) {
1201         realDstExpr = TryToExtractComplexExpr(addrExpr, func, block, stmt, debug);
1202         realSrcExpr = TryToExtractComplexExpr(srcMem.addrExpr, func, block, stmt, debug);
1203     }
1204     auto *ptrType = GlobalTables::GetTypeTable().GetPtrType();
1205     LabelIdx dstNullLabIdx;
1206     LabelIdx srcNullLabIdx;
1207     LabelIdx overlapLabIdx;
1208     LabelIdx addressEqualLabIdx;
1209     if (NeedCheck(memOpKind)) {
1210         auto *cmpResType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(PTY_u8));
1211         auto *cmpOpndType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(PTY_ptr));
1212         // check dst != NULL
1213         auto *checkDstStmt = CreateNullptrCheckStmt(*realDstExpr, func, mirBuilder, *cmpResType, *cmpOpndType);
1214         dstNullLabIdx = checkDstStmt->GetOffset();
1215         // check src != NULL
1216         auto *checkSrcStmt = CreateNullptrCheckStmt(*realSrcExpr, func, mirBuilder, *cmpResType, *cmpOpndType);
1217         srcNullLabIdx = checkSrcStmt->GetOffset();
1218         InsertBeforeAndMayPrintStmtList(block, stmt, debug, {checkDstStmt, checkSrcStmt});
1219         if (errorNumber != ERRNO_RANGE_AND_RESET) {
1220             // check src == dst
1221             auto *checkAddrEqualStmt =
1222                 CreateAddressEqualCheckStmt(*realDstExpr, *realSrcExpr, func, mirBuilder, *cmpResType, *cmpOpndType);
1223             addressEqualLabIdx = checkAddrEqualStmt->GetOffset();
1224             // check overlap
1225             auto *checkOverlapStmt = CreateOverlapCheckStmt(*realDstExpr, *realSrcExpr, static_cast<uint32>(copySize),
1226                                                             func, mirBuilder, *cmpResType, *cmpOpndType);
1227             overlapLabIdx = checkOverlapStmt->GetOffset();
1228             InsertBeforeAndMayPrintStmtList(block, stmt, debug, {checkAddrEqualStmt, checkOverlapStmt});
1229         }
1230     }
1231     if (errorNumber == ERRNO_RANGE_AND_RESET) {
1232         AddMemsetCallStmt(stmt, func, block, addrExpr);
1233     } else {
1234         // memory copy optimization
1235         PregIdx tmpRegIdx1 = 0;
1236         PregIdx tmpRegIdx2 = 0;
1237         for (uint32 i = 0; i < blocks.size(); ++i) {
1238             uint32 curSize = blocks[i];
1239             bool canMergedWithNextSize = (i + 1 < blocks.size()) && blocks[i + 1] == curSize;
1240             if (!canMergedWithNextSize) {
1241                 // low level memcpy expand result:
1242                 // It seems ireadoff has not been supported by cg HandleFunction, so we use iread instead of ireadoff
1243                 // [not support] iassignoff <prim-type> <offset> (dstAddrExpr, ireadoff <prim-type> <offset>
1244                 // (srcAddrExpr)) [ok] iassignoff <prim-type> <offset> (dstAddrExpr, iread <prim-type> <type> (add ptr
1245                 // (srcAddrExpr, offset)))
1246                 PrimType constType = GetIntegerPrimTypeBySizeAndSign(curSize * 8, false);
1247                 MIRType *constMIRType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(constType));
1248                 auto *constMIRPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*constMIRType);
1249                 BaseNode *rhsAddrExpr = realSrcExpr;
1250                 if (offset != 0) {
1251                     auto *offsetConstExpr = ConstructConstvalNode(offset, PTY_u64, *mirBuilder);
1252                     rhsAddrExpr = mirBuilder->CreateExprBinary(OP_add, *ptrType, realSrcExpr, offsetConstExpr);
1253                 }
1254                 BaseNode *rhsExpr = mirBuilder->CreateExprIread(*constMIRType, *constMIRPtrType, 0, rhsAddrExpr);
1255                 auto *iassignoff = mirBuilder->CreateStmtIassignoff(constType, offset, realDstExpr, rhsExpr);
1256                 InsertAndMayPrintStmt(block, stmt, debug, iassignoff);
1257                 offset += curSize;
1258                 continue;
1259             }
1260 
1261             // merge two str/ldr into a stp/ldp
1262             if (tmpRegIdx1 == 0 || tmpRegIdx2 == 0) {
1263                 tmpRegIdx1 = func.GetPregTab()->CreatePreg(PTY_u64);
1264                 tmpRegIdx2 = func.GetPregTab()->CreatePreg(PTY_u64);
1265             }
1266             auto pair1 = GenerateMemoryCopyPair(mirBuilder, realSrcExpr, realDstExpr, offset, curSize, tmpRegIdx1);
1267             auto pair2 =
1268                 GenerateMemoryCopyPair(mirBuilder, realSrcExpr, realDstExpr, offset + curSize, curSize, tmpRegIdx2);
1269             // insert order: regassign1, regassign2, iassignoff1, iassignoff2
1270             InsertBeforeAndMayPrintStmtList(block, stmt, debug, {pair1.first, pair2.first, pair1.second, pair2.second});
1271             offset += (curSize << 1);
1272             ++i;
1273         }
1274     }
1275     // handle memcpy return val
1276     auto *retAssign = GenMemopRetAssign(stmt, func, true, memOpKind, errorNumber);
1277     InsertAndMayPrintStmt(block, stmt, debug, retAssign);
1278     if (NeedCheck(memOpKind)) {
1279         LabelIdx finalLabIdx = func.GetLabelTab()->CreateLabelWithPrefix('f');
1280         auto *finalLabelNode = mirBuilder->CreateStmtLabel(finalLabIdx);
1281         // Add goto final stmt for expanded body
1282         auto *gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
1283         InsertAndMayPrintStmt(block, stmt, debug, gotoFinal);
1284         // Add handler IR if dst == NULL
1285         auto *dstErrAssign = GenMemopRetAssign(stmt, func, true, memOpKind, ERRNO_INVAL);
1286         AddNullptrHandlerIR(stmt, mirBuilder, block, dstErrAssign, dstNullLabIdx, finalLabIdx, debug);
1287         // Add handler IR if src == NULL
1288         auto *srcErrAssign = GenMemopRetAssign(stmt, func, true, memOpKind, ERRNO_INVAL_AND_RESET);
1289         AddResetHandlerIR(stmt, func, block, srcErrAssign, srcNullLabIdx, finalLabIdx, addrExpr, debug);
1290         if (errorNumber != ERRNO_RANGE_AND_RESET) {
1291             // Add handler IR if dst == src
1292             auto *addrEqualAssign = GenMemopRetAssign(stmt, func, true, memOpKind, ERRNO_OK);
1293             AddNullptrHandlerIR(stmt, mirBuilder, block, addrEqualAssign, addressEqualLabIdx, finalLabIdx, debug);
1294             // Add handler IR if dst and src are overlapped
1295             auto *overlapErrAssign = GenMemopRetAssign(stmt, func, true, memOpKind, ERRNO_OVERLAP_AND_RESET);
1296             AddResetHandlerIR(stmt, func, block, overlapErrAssign, overlapLabIdx, finalLabIdx, addrExpr, debug);
1297         }
1298         InsertAndMayPrintStmt(block, stmt, debug, finalLabelNode);
1299     }
1300     block.RemoveStmt(&stmt);
1301 }
1302 
ExpandMemcpy(const MemEntry & srcMem,uint64 copySize,MIRFunction & func,StmtNode & stmt,BlockNode & block,bool isLowLevel,bool debug,ErrorNumber errorNumber) const1303 bool MemEntry::ExpandMemcpy(const MemEntry &srcMem, uint64 copySize, MIRFunction &func, StmtNode &stmt,
1304                             BlockNode &block, bool isLowLevel, bool debug, ErrorNumber errorNumber) const
1305 {
1306     MemOpKind memOpKind = SimplifyMemOp::ComputeMemOpKind(stmt);
1307     MemEntryKind memKind = GetKind();
1308     if (!isLowLevel) {  // check type consistency and memKind only for high level expand
1309         if (memOpKind == MEM_OP_memcpy_s) {
1310             MayPrintLog(debug, false, memOpKind, "all memcpy_s will be handed by cglower");
1311             return false;
1312         }
1313         if (memType != srcMem.memType) {
1314             return false;
1315         }
1316         CHECK_FATAL(memKind != kMemEntryUnknown, "invalid memKind");
1317     }
1318     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
1319     StmtNode *newAssign = nullptr;
1320     if (isLowLevel) {  // For cglower, replace memcpy with a series of low-level iassignoff
1321         ExpandMemcpyLowLevel(srcMem, copySize, func, stmt, block, memOpKind, debug, errorNumber);
1322         return true;
1323     }
1324 
1325     if (memKind == kMemEntryPrimitive || memKind == kMemEntryStruct) {
1326         // Do low level expand for all struct memcpy for now
1327         if (memKind == kMemEntryStruct) {
1328             MayPrintLog(debug, false, memOpKind, "Do low level expand for all struct memcpy for now");
1329             return false;
1330         }
1331         if (addrExpr->GetOpCode() == OP_addrof) {  // We prefer dassign to iassign
1332             auto *addrof = static_cast<AddrofNode *>(addrExpr);
1333             auto *symbol = func.GetLocalOrGlobalSymbol(addrof->GetStIdx());
1334             newAssign = mirBuilder->CreateStmtDassign(*symbol, addrof->GetFieldID(), srcMem.BuildAsRhsExpr(func));
1335         } else {
1336             MIRType *memPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*memType);
1337             newAssign = mirBuilder->CreateStmtIassign(*memPtrType, 0, addrExpr, srcMem.BuildAsRhsExpr(func));
1338         }
1339         InsertAndMayPrintStmt(block, stmt, debug, newAssign);
1340         // split struct agg copy
1341         if (memKind == kMemEntryStruct) {
1342             if (newAssign->GetOpCode() == OP_dassign) {
1343                 (void)SplitDassignAggCopy(static_cast<DassignNode *>(newAssign), &block, &func);
1344             } else if (newAssign->GetOpCode() == OP_iassign) {
1345                 (void)SplitIassignAggCopy(static_cast<IassignNode *>(newAssign), &block, &func);
1346             } else {
1347                 CHECK_FATAL(false, "impossible");
1348             }
1349         }
1350     } else if (memKind == kMemEntryArray) {
1351         // We only consider array with dim == 1 now, and element type must be primitive type
1352         auto *arrayType = static_cast<MIRArrayType *>(memType);
1353         if (arrayType->GetDim() != 1 || (arrayType->GetElemType()->GetKind() != kTypeScalar &&
1354                                          arrayType->GetElemType()->GetKind() != kTypePointer)) {
1355             MayPrintLog(debug, false, memOpKind, "array dim != 1 or array elements are not primtive type");
1356             return false;
1357         }
1358         MIRType *elemType = arrayType->GetElemType();
1359         if (elemType->GetSize() < k4ByteSize) {
1360             MayPrintLog(debug, false, memOpKind,
1361                         "element size < 4, don't expand it to avoid to genearte lots of strb/strh");
1362             return false;
1363         }
1364         size_t elemCnt = arrayType->GetSizeArrayItem(0);
1365         if (elemType->GetSize() * elemCnt != copySize) {
1366             MayPrintLog(debug, false, memOpKind, "array size not equal");
1367             return false;
1368         }
1369         // if srcExpr is too complex (for example: addrof expr of global/static array), let cg expand it
1370         if (elemCnt > 1 && IsComplexExpr(srcMem.addrExpr, func)) {
1371             MayPrintLog(debug, false, memOpKind, "srcExpr is too complex, let cg expand it to avoid redundant inst");
1372             return false;
1373         }
1374         MIRType *elemPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*elemType);
1375         MIRType *u32Type = GlobalTables::GetTypeTable().GetUInt32();
1376         for (size_t i = 0; i < elemCnt; ++i) {
1377             ConstvalNode *indexExpr = mirBuilder->CreateConstval(
1378                 GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(i), *u32Type));
1379             auto *arrayExpr = mirBuilder->CreateExprArray(*arrayType, addrExpr, indexExpr);
1380             auto *rhsArrayExpr = mirBuilder->CreateExprArray(*arrayType, srcMem.addrExpr, indexExpr);
1381             auto *rhsIreadExpr = mirBuilder->CreateExprIread(*elemType, *elemPtrType, 0, rhsArrayExpr);
1382             auto *arrayElemAssign = mirBuilder->CreateStmtIassign(*elemPtrType, 0, arrayExpr, rhsIreadExpr);
1383             InsertAndMayPrintStmt(block, stmt, debug, arrayElemAssign);
1384         }
1385     } else {
1386         CHECK_FATAL(false, "impossible");
1387     }
1388 
1389     // handle memcpy return val
1390     auto *retAssign = GenMemopRetAssign(stmt, func, isLowLevel, memOpKind, errorNumber);
1391     InsertAndMayPrintStmt(block, stmt, debug, retAssign);
1392     block.RemoveStmt(&stmt);
1393     return true;
1394 }
1395 
1396 // handle memset, memcpy return val
GenMemopRetAssign(StmtNode & stmt,MIRFunction & func,bool isLowLevel,MemOpKind memOpKind,ErrorNumber errorNumber)1397 StmtNode *MemEntry::GenMemopRetAssign(StmtNode &stmt, MIRFunction &func, bool isLowLevel, MemOpKind memOpKind,
1398                                       ErrorNumber errorNumber)
1399 {
1400     if (stmt.GetOpCode() != OP_call && stmt.GetOpCode() != OP_callassigned) {
1401         return nullptr;
1402     }
1403     auto &callStmt = static_cast<CallNode &>(stmt);
1404     const auto &retVec = callStmt.GetReturnVec();
1405     if (retVec.empty()) {
1406         return nullptr;
1407     }
1408     MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
1409     BaseNode *rhs = callStmt.Opnd(0);  // for memset, memcpy
1410     if (memOpKind == MEM_OP_memset_s || memOpKind == MEM_OP_memcpy_s) {
1411         // memset_s and memcpy_s must return an errorNumber
1412         MIRType *constType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(PTY_i32));
1413         MIRConst *mirConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(errorNumber, *constType);
1414         rhs = mirBuilder->CreateConstval(mirConst);
1415     }
1416     if (!retVec[0].second.IsReg()) {
1417         auto *retAssign = mirBuilder->CreateStmtDassign(retVec[0].first, 0, rhs);
1418         return retAssign;
1419     } else {
1420         PregIdx pregIdx = retVec[0].second.GetPregIdx();
1421         auto pregType = func.GetPregTab()->GetPregTableItem(static_cast<uint32>(pregIdx))->GetPrimType();
1422         auto *retAssign = mirBuilder->CreateStmtRegassign(pregType, pregIdx, rhs);
1423         if (isLowLevel) {
1424             retAssign->GetRHS()->SetPrimType(pregType);
1425         }
1426         return retAssign;
1427     }
1428 }
1429 
ComputeMemOpKind(StmtNode & stmt)1430 MemOpKind SimplifyMemOp::ComputeMemOpKind(StmtNode &stmt)
1431 {
1432     if (stmt.GetOpCode() == OP_intrinsiccall) {
1433         auto intrinsicID = static_cast<IntrinsiccallNode &>(stmt).GetIntrinsic();
1434         if (intrinsicID == INTRN_C_memset) {
1435             return MEM_OP_memset;
1436         } else if (intrinsicID == INTRN_C_memcpy) {
1437             return MEM_OP_memcpy;
1438         }
1439     }
1440     // lowered memop function (such as memset) may be a call, not callassigned
1441     if (stmt.GetOpCode() != OP_callassigned && stmt.GetOpCode() != OP_call) {
1442         return MEM_OP_unknown;
1443     }
1444     auto &callStmt = static_cast<CallNode &>(stmt);
1445     MIRFunction *func = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(callStmt.GetPUIdx());
1446     const char *funcName = func->GetName().c_str();
1447     if (strcmp(funcName, kFuncNameOfMemset) == 0) {
1448         return MEM_OP_memset;
1449     }
1450     if (strcmp(funcName, kFuncNameOfMemcpy) == 0) {
1451         return MEM_OP_memcpy;
1452     }
1453     if (strcmp(funcName, kFuncNameOfMemsetS) == 0) {
1454         return MEM_OP_memset_s;
1455     }
1456     if (strcmp(funcName, kFuncNameOfMemcpyS) == 0) {
1457         return MEM_OP_memcpy_s;
1458     }
1459     return MEM_OP_unknown;
1460 }
1461 
AutoSimplify(StmtNode & stmt,BlockNode & block,bool isLowLevel)1462 bool SimplifyMemOp::AutoSimplify(StmtNode &stmt, BlockNode &block, bool isLowLevel)
1463 {
1464     MemOpKind memOpKind = ComputeMemOpKind(stmt);
1465     switch (memOpKind) {
1466         case MEM_OP_memset:
1467         case MEM_OP_memset_s: {
1468             return SimplifyMemset(stmt, block, isLowLevel);
1469         }
1470         case MEM_OP_memcpy:
1471         case MEM_OP_memcpy_s: {
1472             return SimplifyMemcpy(stmt, block, isLowLevel);
1473         }
1474         default:
1475             break;
1476     }
1477     return false;
1478 }
1479 
1480 // expand memset_s call statement, return pointer of memset call statement node to be expanded in the next step, return
1481 // nullptr if memset_s is expanded completely.
PartiallyExpandMemsetS(StmtNode & stmt,BlockNode & block)1482 StmtNode *SimplifyMemOp::PartiallyExpandMemsetS(StmtNode &stmt, BlockNode &block)
1483 {
1484     ErrorNumber errNum = ERRNO_OK;
1485 
1486     uint64 srcSize = 0;
1487     bool isSrcSizeConst = false;
1488     BaseNode *foldSrcSizeExpr = FoldIntConst(stmt.Opnd(kMemsetSSrcSizeOpndIdx), srcSize, isSrcSizeConst);
1489     if (foldSrcSizeExpr != nullptr) {
1490         stmt.SetOpnd(foldSrcSizeExpr, kMemsetSDstSizeOpndIdx);
1491     }
1492 
1493     uint64 dstSize = 0;
1494     bool isDstSizeConst = false;
1495     BaseNode *foldDstSizeExpr = FoldIntConst(stmt.Opnd(kMemsetSDstSizeOpndIdx), dstSize, isDstSizeConst);
1496     if (foldDstSizeExpr != nullptr) {
1497         stmt.SetOpnd(foldDstSizeExpr, kMemsetSDstSizeOpndIdx);
1498     }
1499     if (isDstSizeConst) {
1500         if ((srcSize > dstSize && dstSize == 0) || dstSize > kSecurecMemMaxLen) {
1501             errNum = ERRNO_RANGE;
1502         }
1503     }
1504 
1505     MIRBuilder *mirBuilder = func->GetModule()->GetMIRBuilder();
1506     LabelIdx finalLabIdx = func->GetLabelTab()->CreateLabelWithPrefix('f');
1507     if (errNum != ERRNO_OK) {
1508         auto errnoAssign = MemEntry::GenMemopRetAssign(stmt, *func, true, MEM_OP_memset_s, errNum);
1509         InsertAndMayPrintStmt(block, stmt, debug, errnoAssign);
1510         block.RemoveStmt(&stmt);
1511         return nullptr;
1512     } else {
1513         LabelIdx dstSizeCheckLabIdx, srcSizeCheckLabIdx, nullPtrLabIdx;
1514         if (!isDstSizeConst) {
1515             // check if dst size is greater than maxlen
1516             dstSizeCheckLabIdx = func->GetLabelTab()->CreateLabelWithPrefix('n');  // 'n' means nullptr
1517             CreateAndInsertCheckStmt(OP_gt, stmt.Opnd(kMemsetSDstSizeOpndIdx),
1518                                      ConstructConstvalNode(kSecurecMemMaxLen, PTY_u64, *mirBuilder), dstSizeCheckLabIdx,
1519                                      stmt, block, *func, debug);
1520         }
1521 
1522         // check if dst is nullptr
1523         nullPtrLabIdx = func->GetLabelTab()->CreateLabelWithPrefix('n');  // 'n' means nullptr
1524         CreateAndInsertCheckStmt(OP_eq, stmt.Opnd(kMemsetDstOpndIdx), ConstructConstvalNode(0, PTY_u64, *mirBuilder),
1525                                  nullPtrLabIdx, stmt, block, *func, debug);
1526 
1527         if (isDstSizeConst && isSrcSizeConst) {
1528             if (srcSize > dstSize) {
1529                 srcSize = dstSize;
1530                 return ExpandOnSrcSizeGtDstSize(stmt, block, srcSize, finalLabIdx, nullPtrLabIdx, *func, debug);
1531             }
1532         } else {
1533             // check if src size is greater than dst size
1534             srcSizeCheckLabIdx = func->GetLabelTab()->CreateLabelWithPrefix('n');  // 'n' means nullptr
1535             CreateAndInsertCheckStmt(OP_gt, stmt.Opnd(kMemsetSSrcSizeOpndIdx), stmt.Opnd(kMemsetSDstSizeOpndIdx),
1536                                      srcSizeCheckLabIdx, stmt, block, *func, debug);
1537         }
1538 
1539         MapleVector<BaseNode *> args(func->GetCodeMempoolAllocator().Adapter());
1540         args.push_back(stmt.Opnd(kMemsetDstOpndIdx));
1541         args.push_back(stmt.Opnd(kMemsetSSrcOpndIdx));
1542         args.push_back(stmt.Opnd(kMemsetSSrcSizeOpndIdx));
1543         auto memsetCallStmt = InsertMemsetCallStmt(args, *func, stmt, block, finalLabIdx, errNum, debug);
1544 
1545         if (!isSrcSizeConst || !isDstSizeConst) {
1546             // handle src size error
1547             auto branchLabNode = mirBuilder->CreateStmtLabel(srcSizeCheckLabIdx);
1548             InsertAndMayPrintStmt(block, stmt, debug, branchLabNode);
1549             HandleZeroValueOfDstSize(stmt, block, srcSize, dstSize, finalLabIdx, dstSizeCheckLabIdx, *func,
1550                                      isDstSizeConst, debug);
1551             args.pop_back();
1552             args.push_back(stmt.Opnd(kMemsetSDstSizeOpndIdx));
1553             (void)InsertMemsetCallStmt(args, *func, stmt, block, finalLabIdx, ERRNO_RANGE_AND_RESET, debug);
1554         }
1555 
1556         // handle dst nullptr error
1557         auto nullptrLabNode = mirBuilder->CreateStmtLabel(nullPtrLabIdx);
1558         InsertAndMayPrintStmt(block, stmt, debug, nullptrLabNode);
1559         HandleZeroValueOfDstSize(stmt, block, srcSize, dstSize, finalLabIdx, dstSizeCheckLabIdx, *func, isDstSizeConst,
1560                                  debug);
1561         auto gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
1562         auto errnoAssign = MemEntry::GenMemopRetAssign(stmt, *func, true, MEM_OP_memset_s, ERRNO_INVAL);
1563         InsertBeforeAndMayPrintStmtList(block, stmt, debug, {errnoAssign, gotoFinal});
1564 
1565         if (!isDstSizeConst) {
1566             // handle dst size error
1567             InsertCheckFailedBranch(*func, stmt, block, dstSizeCheckLabIdx, finalLabIdx, ERRNO_RANGE, MEM_OP_memset_s,
1568                                     debug);
1569         }
1570         auto *finalLabelNode = mirBuilder->CreateStmtLabel(finalLabIdx);
1571         InsertAndMayPrintStmt(block, stmt, debug, finalLabelNode);
1572         block.RemoveStmt(&stmt);
1573         return memsetCallStmt;
1574     }
1575 }
1576 
1577 // Try to replace the call to memset with a series of assign operations (including dassign, iassign, iassignoff), which
1578 // is usually profitable for small memory size.
1579 // This function is called in two places, one in mpl2mpl simplify, another in cglower:
1580 // (1) mpl2mpl memset expand (isLowLevel == false)
1581 //   for primitive type, array type with element size < 4 bytes and struct type without padding
1582 // (2) cglower memset expand
1583 //   for array type with element size >= 4 bytes and struct type with paddings
SimplifyMemset(StmtNode & stmt,BlockNode & block,bool isLowLevel)1584 bool SimplifyMemOp::SimplifyMemset(StmtNode &stmt, BlockNode &block, bool isLowLevel)
1585 {
1586     MemOpKind memOpKind = ComputeMemOpKind(stmt);
1587     if (memOpKind != MEM_OP_memset && memOpKind != MEM_OP_memset_s) {
1588         return false;
1589     }
1590     uint32 dstOpndIdx = 0;
1591     uint32 srcOpndIdx = 1;
1592     uint32 srcSizeOpndIdx = 2;
1593     bool isSafeVersion = memOpKind == MEM_OP_memset_s;
1594     if (debug) {
1595         LogInfo::MapleLogger() << "[funcName] " << func->GetName() << std::endl;
1596         stmt.Dump(0);
1597     }
1598 
1599     StmtNode *memsetCallStmt = &stmt;
1600     if (memOpKind == MEM_OP_memset_s && !isLowLevel) {
1601         memsetCallStmt = PartiallyExpandMemsetS(stmt, block);
1602         if (!memsetCallStmt) {
1603             return true;  // Expand memset_s completely, no extra memset is generated, so just return true
1604         }
1605     }
1606 
1607     uint64 srcSize = 0;
1608     bool isSrcSizeConst = false;
1609     BaseNode *foldSrcSizeExpr = FoldIntConst(memsetCallStmt->Opnd(srcSizeOpndIdx), srcSize, isSrcSizeConst);
1610     if (foldSrcSizeExpr != nullptr) {
1611         memsetCallStmt->SetOpnd(foldSrcSizeExpr, srcSizeOpndIdx);
1612     }
1613 
1614     if (isSrcSizeConst) {
1615         // If the size is too big, we won't expand it
1616         uint32 thresholdExpand = (isSafeVersion ? thresholdMemsetSExpand : thresholdMemsetExpand);
1617         if (srcSize > thresholdExpand) {
1618             MayPrintLog(debug, false, memOpKind, "size is too big");
1619             return false;
1620         }
1621         if (srcSize == 0) {
1622             if (memOpKind == MEM_OP_memset) {
1623                 auto *retAssign = MemEntry::GenMemopRetAssign(*memsetCallStmt, *func, isLowLevel, memOpKind);
1624                 InsertAndMayPrintStmt(block, *memsetCallStmt, debug, retAssign);
1625             }
1626             block.RemoveStmt(memsetCallStmt);
1627             return true;
1628         }
1629     }
1630 
1631     // memset's 'src size' must be a const value, otherwise we can not expand it
1632     if (!isSrcSizeConst) {
1633         MayPrintLog(debug, false, memOpKind, "size is not int const");
1634         return false;
1635     }
1636 
1637     ErrorNumber errNum = ERRNO_OK;
1638 
1639     uint64 val = 0;
1640     bool isIntConst = false;
1641     BaseNode *foldValExpr = FoldIntConst(memsetCallStmt->Opnd(srcOpndIdx), val, isIntConst);
1642     if (foldValExpr != nullptr) {
1643         memsetCallStmt->SetOpnd(foldValExpr, srcOpndIdx);
1644     }
1645     // memset's second argument 'val' should also be a const value
1646     if (!isIntConst) {
1647         MayPrintLog(debug, false, memOpKind, "val is not int const");
1648         return false;
1649     }
1650 
1651     MemEntry dstMemEntry;
1652     bool valid = MemEntry::ComputeMemEntry(*(memsetCallStmt->Opnd(dstOpndIdx)), *func, dstMemEntry, isLowLevel);
1653     if (!valid) {
1654         MayPrintLog(debug, false, memOpKind, "dstMemEntry is invalid");
1655         return false;
1656     }
1657     bool ret = false;
1658     if (srcSize != 0) {
1659         ret = dstMemEntry.ExpandMemset(val, static_cast<uint64>(srcSize), *func, *memsetCallStmt, block, isLowLevel,
1660                                        debug, errNum);
1661     } else {
1662         // if size == 0, no need to set memory, just return error nummber
1663         auto *retAssign = MemEntry::GenMemopRetAssign(*memsetCallStmt, *func, isLowLevel, memOpKind, errNum);
1664         InsertAndMayPrintStmt(block, *memsetCallStmt, debug, retAssign);
1665         block.RemoveStmt(memsetCallStmt);
1666         ret = true;
1667     }
1668     if (ret) {
1669         MayPrintLog(debug, true, memOpKind, "well done");
1670     }
1671     return ret;
1672 }
1673 
SimplifyMemcpy(StmtNode & stmt,BlockNode & block,bool isLowLevel)1674 bool SimplifyMemOp::SimplifyMemcpy(StmtNode &stmt, BlockNode &block, bool isLowLevel)
1675 {
1676     MemOpKind memOpKind = ComputeMemOpKind(stmt);
1677     if (memOpKind != MEM_OP_memcpy && memOpKind != MEM_OP_memcpy_s) {
1678         return false;
1679     }
1680     uint32 dstOpndIdx = 0;
1681     uint32 dstSizeOpndIdx = kFirstOpnd;  // only used by memcpy_s
1682     uint32 srcOpndIdx = kSecondOpnd;
1683     uint32 srcSizeOpndIdx = kThirdOpnd;
1684     bool isSafeVersion = memOpKind == MEM_OP_memcpy_s;
1685     if (isSafeVersion) {
1686         dstSizeOpndIdx = kSecondOpnd;
1687         srcOpndIdx = kThirdOpnd;
1688         srcSizeOpndIdx = kFourthOpnd;
1689     }
1690     if (debug) {
1691         LogInfo::MapleLogger() << "[funcName] " << func->GetName() << std::endl;
1692         stmt.Dump(0);
1693     }
1694 
1695     uint64 srcSize = 0;
1696     bool isIntConst = false;
1697     BaseNode *foldCopySizeExpr = FoldIntConst(stmt.Opnd(srcSizeOpndIdx), srcSize, isIntConst);
1698     if (foldCopySizeExpr != nullptr) {
1699         stmt.SetOpnd(foldCopySizeExpr, srcSizeOpndIdx);
1700     }
1701     if (!isIntConst) {
1702         MayPrintLog(debug, false, memOpKind, "src size is not an int const");
1703         return false;
1704     }
1705     uint32 thresholdExpand = (isSafeVersion ? thresholdMemcpySExpand : thresholdMemcpyExpand);
1706     if (srcSize > thresholdExpand) {
1707         MayPrintLog(debug, false, memOpKind, "size is too big");
1708         return false;
1709     }
1710     if (srcSize == 0) {
1711         MayPrintLog(debug, false, memOpKind, "memcpy with src size 0");
1712         return false;
1713     }
1714     uint64 copySize = srcSize;
1715     ErrorNumber errNum = ERRNO_OK;
1716     if (isSafeVersion) {
1717         uint64 dstSize = 0;
1718         bool isDstSizeConst = false;
1719         BaseNode *foldDstSizeExpr = FoldIntConst(stmt.Opnd(dstSizeOpndIdx), dstSize, isDstSizeConst);
1720         if (foldDstSizeExpr != nullptr) {
1721             stmt.SetOpnd(foldDstSizeExpr, dstSizeOpndIdx);
1722         }
1723         if (!isDstSizeConst) {
1724             MayPrintLog(debug, false, memOpKind, "dst size is not int const");
1725             return false;
1726         }
1727         if (dstSize == 0 || dstSize > kSecurecMemMaxLen) {
1728             copySize = 0;
1729             errNum = ERRNO_RANGE;
1730         } else if (srcSize > dstSize) {
1731             copySize = dstSize;
1732             errNum = ERRNO_RANGE_AND_RESET;
1733         }
1734     }
1735 
1736     MemEntry dstMemEntry;
1737     bool valid = MemEntry::ComputeMemEntry(*stmt.Opnd(dstOpndIdx), *func, dstMemEntry, isLowLevel);
1738     if (!valid) {
1739         MayPrintLog(debug, false, memOpKind, "dstMemEntry is invalid");
1740         return false;
1741     }
1742     MemEntry srcMemEntry;
1743     valid = MemEntry::ComputeMemEntry(*stmt.Opnd(srcOpndIdx), *func, srcMemEntry, isLowLevel);
1744     if (!valid) {
1745         MayPrintLog(debug, false, memOpKind, "srcMemEntry is invalid");
1746         return false;
1747     }
1748     // We don't check type consistency when doing low level expand
1749     if (!isLowLevel) {
1750         if (dstMemEntry.memType != srcMemEntry.memType) {
1751             MayPrintLog(debug, false, memOpKind, "dst and src have different type");
1752             return false;  // entryType must be identical
1753         }
1754         if (dstMemEntry.memType->GetSize() != static_cast<uint64>(srcSize)) {
1755             MayPrintLog(debug, false, memOpKind, "copy size != dst memory size");
1756             return false;  // copy size should equal to dst memory size, we maybe allow smaller copy size later
1757         }
1758     }
1759     bool ret = false;
1760     if (copySize != 0) {
1761         ret = dstMemEntry.ExpandMemcpy(srcMemEntry, copySize, *func, stmt, block, isLowLevel, debug, errNum);
1762     } else {
1763         // if copySize == 0, no need to copy memory, just return error number
1764         auto *retAssign = MemEntry::GenMemopRetAssign(stmt, *func, isLowLevel, memOpKind, errNum);
1765         InsertAndMayPrintStmt(block, stmt, debug, retAssign);
1766         block.RemoveStmt(&stmt);
1767         ret = true;
1768     }
1769     if (ret) {
1770         MayPrintLog(debug, true, memOpKind, "well done");
1771     }
1772     return ret;
1773 }
1774 
1775 }  // namespace maple
1776