1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "triple.h"
17 #include "simplify.h"
18 #include <functional>
19 #include <initializer_list>
20 #include <iostream>
21 #include <algorithm>
22 #include "constantfold.h"
23 #include "mpl_logging.h"
24
25 namespace maple {
26
27 namespace {
28
// Mangled java.lang.Math method names recognized by the math-call simplifier.
constexpr char kClassNameOfMath[] = "Ljava_2Flang_2FMath_3B";
constexpr char kFuncNamePrefixOfMathSqrt[] = "Ljava_2Flang_2FMath_3B_7Csqrt_7C_28D_29D";
constexpr char kFuncNamePrefixOfMathAbs[] = "Ljava_2Flang_2FMath_3B_7Cabs_7C";
constexpr char kFuncNamePrefixOfMathMax[] = "Ljava_2Flang_2FMath_3B_7Cmax_7C";
constexpr char kFuncNamePrefixOfMathMin[] = "Ljava_2Flang_2FMath_3B_7Cmin_7C";
// C library function names recognized by the simplifiers.
constexpr char kFuncNameOfMathAbs[] = "abs";
constexpr char kFuncNameOfMemset[] = "memset";
constexpr char kFuncNameOfMemcpy[] = "memcpy";
constexpr char kFuncNameOfMemsetS[] = "memset_s";
constexpr char kFuncNameOfMemcpyS[] = "memcpy_s";
// Max byte count accepted by the secure mem functions — presumably SECUREC_MEM_MAX_LEN; confirm against securec.
constexpr uint64_t kSecurecMemMaxLen = 0x7fffffffUL;
static constexpr int32 kProbUnlikely = 1000;  // branch weight used to mark error paths as unlikely
// Operand indices for memset_s(dst, dstSize, src, srcSize) call statements.
constexpr uint32_t kMemsetDstOpndIdx = 0;
constexpr uint32_t kMemsetSDstSizeOpndIdx = 1;
constexpr uint32_t kMemsetSSrcOpndIdx = 2;
constexpr uint32_t kMemsetSSrcSizeOpndIdx = 3;
45
46 // Truncate the constant field of 'union' if it's written as scalar type (e.g. int),
47 // but accessed as bit-field type with smaller size.
48 //
49 // Return the truncated constant or nullptr if the constant doesn't need to be truncated.
MIRConst *TruncateUnionConstant(const MIRStructType &unionType, MIRConst *fieldCst, const MIRType &unionFieldType)
{
    // Only unions alias a scalar write with a narrower bit-field read.
    if (unionType.GetKind() != kTypeUnion) {
        return nullptr;
    }

    auto *bitFieldType = safe_cast<MIRBitFieldType>(unionFieldType);
    auto *intCst = safe_cast<MIRIntConst>(fieldCst);

    // Truncation only applies to an integer constant read back through a bit-field member.
    if (!bitFieldType || !intCst) {
        return nullptr;
    }

    bool isBigEndian = Triple::GetTriple().IsBigEndian();

    IntVal val = intCst->GetValue();
    uint8 bitSize = bitFieldType->GetFieldSize();

    // The bit-field is at least as wide as the stored constant: nothing to truncate.
    if (bitSize >= val.GetBitWidth()) {
        return nullptr;
    }

    if (isBigEndian) {
        // Big-endian: the bit-field occupies the most-significant bits — shift them down.
        val = val.LShr(val.GetBitWidth() - bitSize);
    } else {
        // Little-endian: the bit-field occupies the least-significant bits — mask them out.
        val = val & ((uint64(1) << bitSize) - 1);
    }

    return GlobalTables::GetIntConstTable().GetOrCreateIntConst(val, fieldCst->GetType());
}
80
81 } // namespace
82
83 // If size (in byte) is bigger than this threshold, we won't expand memop
const uint32 SimplifyMemOp::thresholdMemsetExpand = 512;
const uint32 SimplifyMemOp::thresholdMemcpyExpand = 512;
// The checked variants (memset_s/memcpy_s) tolerate a larger expansion threshold.
const uint32 SimplifyMemOp::thresholdMemsetSExpand = 1024;
const uint32 SimplifyMemOp::thresholdMemcpySExpand = 1024;
static const uint32 kMaxMemoryBlockSizeToAssign = 8;  // in byte
89
MayPrintLog(bool debug,bool success,MemOpKind memOpKind,const char * str)90 static void MayPrintLog(bool debug, bool success, MemOpKind memOpKind, const char *str)
91 {
92 if (!debug) {
93 return;
94 }
95 const char *memop = "";
96 if (memOpKind == MEM_OP_memset) {
97 memop = "memset";
98 } else if (memOpKind == MEM_OP_memcpy) {
99 memop = "memcpy";
100 } else if (memOpKind == MEM_OP_memset_s) {
101 memop = "memset_s";
102 } else if (memOpKind == MEM_OP_memcpy_s) {
103 memop = "memcpy_s";
104 }
105 LogInfo::MapleLogger() << memop << " expand " << (success ? "success: " : "failure: ") << str << std::endl;
106 }
107
IsMathSqrt(const std::string funcName)108 bool Simplify::IsMathSqrt(const std::string funcName)
109 {
110 return (mirMod.IsJavaModule() && (strcmp(funcName.c_str(), kFuncNamePrefixOfMathSqrt) == 0));
111 }
112
IsMathAbs(const std::string funcName)113 bool Simplify::IsMathAbs(const std::string funcName)
114 {
115 return (mirMod.IsCModule() && (strcmp(funcName.c_str(), kFuncNameOfMathAbs) == 0)) ||
116 (mirMod.IsJavaModule() && (strcmp(funcName.c_str(), kFuncNamePrefixOfMathAbs) == 0));
117 }
118
IsMathMax(const std::string funcName)119 bool Simplify::IsMathMax(const std::string funcName)
120 {
121 return (mirMod.IsJavaModule() && (strcmp(funcName.c_str(), kFuncNamePrefixOfMathMax) == 0));
122 }
123
IsMathMin(const std::string funcName)124 bool Simplify::IsMathMin(const std::string funcName)
125 {
126 return (mirMod.IsJavaModule() && (strcmp(funcName.c_str(), kFuncNamePrefixOfMathMin) == 0));
127 }
128
// Replace a `callassigned` to a known math routine (sqrt/abs/max/min) with a
// plain dassign of the equivalent maple operator expression.
// Returns true iff the statement was replaced in `block`.
bool Simplify::SimplifyMathMethod(const StmtNode &stmt, BlockNode &block)
{
    if (stmt.GetOpCode() != OP_callassigned) {
        return false;
    }
    auto &cnode = static_cast<const CallNode &>(stmt);
    MIRFunction *calleeFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(cnode.GetPUIdx());
    DEBUG_ASSERT(calleeFunc != nullptr, "null ptr check");
    const std::string &funcName = calleeFunc->GetName();
    if (funcName.empty()) {
        return false;
    }
    // Only C and Java modules are handled.
    if (!mirMod.IsCModule() && !mirMod.IsJavaModule()) {
        return false;
    }
    // For Java, only methods of java.lang.Math are candidates.
    if (mirMod.IsJavaModule() && funcName.find(kClassNameOfMath) == std::string::npos) {
        return false;
    }
    // Need at least one argument and a return target to rewrite into a dassign.
    if (cnode.GetNumOpnds() == 0 || cnode.GetReturnVec().empty()) {
        return false;
    }

    auto *opnd0 = cnode.Opnd(0);
    DEBUG_ASSERT(opnd0 != nullptr, "null ptr check");
    // NOTE(review): passes a PrimType where a TyIdx is expected — relies on
    // primitive types' TyIdx matching their PrimType value; confirm.
    auto *type = GlobalTables::GetTypeTable().GetTypeFromTyIdx(opnd0->GetPrimType());

    BaseNode *opExpr = nullptr;
    // NOTE(review): sqrt is folded only when the operand is NOT a float type,
    // yet the matched Java signature is (D)D — confirm the '!' is intentional.
    if (IsMathSqrt(funcName) && !IsPrimitiveFloat(opnd0->GetPrimType())) {
        opExpr = builder->CreateExprUnary(OP_sqrt, *type, opnd0);
    } else if (IsMathAbs(funcName)) {
        opExpr = builder->CreateExprUnary(OP_abs, *type, opnd0);
    } else if (IsMathMax(funcName)) {
        opExpr = builder->CreateExprBinary(OP_max, *type, opnd0, cnode.Opnd(1));
    } else if (IsMathMin(funcName)) {
        opExpr = builder->CreateExprBinary(OP_min, *type, opnd0, cnode.Opnd(1));
    }
    if (opExpr != nullptr) {
        // Assign the folded expression to the call's first return symbol.
        auto stIdx = cnode.GetNthReturnVec(0).first;
        auto *dassign = builder->CreateStmtDassign(stIdx, 0, opExpr);
        block.ReplaceStmt1WithStmt2(&stmt, dassign);
        return true;
    }
    return false;
}
173
SimplifyCallAssigned(StmtNode & stmt,BlockNode & block)174 void Simplify::SimplifyCallAssigned(StmtNode &stmt, BlockNode &block)
175 {
176 if (SimplifyMathMethod(stmt, block)) {
177 return;
178 }
179 simplifyMemOp.SetDebug(dump);
180 simplifyMemOp.SetFunction(currFunc);
181 if (simplifyMemOp.AutoSimplify(stmt, block, false)) {
182 return;
183 }
184 }
185
186 constexpr uint32 kUpperLimitOfFieldNum = 10;
GetDassignedStructType(const DassignNode * dassign,MIRFunction * func)187 static MIRStructType *GetDassignedStructType(const DassignNode *dassign, MIRFunction *func)
188 {
189 const auto &lhsStIdx = dassign->GetStIdx();
190 auto lhsSymbol = func->GetLocalOrGlobalSymbol(lhsStIdx);
191 auto lhsAggType = lhsSymbol->GetType();
192 if (!lhsAggType->IsStructType()) {
193 return nullptr;
194 }
195 if (lhsAggType->GetKind() == kTypeUnion) { // no need to split union's field
196 return nullptr;
197 }
198 auto lhsFieldID = dassign->GetFieldID();
199 if (lhsFieldID != 0) {
200 CHECK_FATAL(lhsAggType->IsStructType(), "only struct has non-zero fieldID");
201 lhsAggType = static_cast<MIRStructType *>(lhsAggType)->GetFieldType(lhsFieldID);
202 if (!lhsAggType->IsStructType()) {
203 return nullptr;
204 }
205 if (lhsAggType->GetKind() == kTypeUnion) { // no need to split union's field
206 return nullptr;
207 }
208 }
209 if (static_cast<MIRStructType *>(lhsAggType)->NumberOfFieldIDs() > kUpperLimitOfFieldNum) {
210 return nullptr;
211 }
212 return static_cast<MIRStructType *>(lhsAggType);
213 }
214
GetIassignedStructType(const IassignNode * iassign)215 static MIRStructType *GetIassignedStructType(const IassignNode *iassign)
216 {
217 auto ptrTyIdx = iassign->GetTyIdx();
218 auto *ptrType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(ptrTyIdx);
219 CHECK_FATAL(ptrType->IsMIRPtrType(), "must be pointer type");
220 auto aggTyIdx = static_cast<MIRPtrType *>(ptrType)->GetPointedTyIdxWithFieldID(iassign->GetFieldID());
221 auto *lhsAggType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(aggTyIdx);
222 if (!lhsAggType->IsStructType()) {
223 return nullptr;
224 }
225 if (lhsAggType->GetKind() == kTypeUnion) {
226 return nullptr;
227 }
228 if (static_cast<MIRStructType *>(lhsAggType)->NumberOfFieldIDs() > kUpperLimitOfFieldNum) {
229 return nullptr;
230 }
231 return static_cast<MIRStructType *>(lhsAggType);
232 }
233
GetReadedStructureType(const DreadNode * dread,const MIRFunction * func)234 static MIRStructType *GetReadedStructureType(const DreadNode *dread, const MIRFunction *func)
235 {
236 const auto &rhsStIdx = dread->GetStIdx();
237 auto rhsSymbol = func->GetLocalOrGlobalSymbol(rhsStIdx);
238 auto rhsAggType = rhsSymbol->GetType();
239 auto rhsFieldID = dread->GetFieldID();
240 if (rhsFieldID != 0) {
241 CHECK_FATAL(rhsAggType->IsStructType(), "only struct has non-zero fieldID");
242 rhsAggType = static_cast<MIRStructType *>(rhsAggType)->GetFieldType(rhsFieldID);
243 }
244 if (!rhsAggType->IsStructType()) {
245 return nullptr;
246 }
247 return static_cast<MIRStructType *>(rhsAggType);
248 }
249
GetReadedStructureType(const IreadNode * iread,const MIRFunction *)250 static MIRStructType *GetReadedStructureType(const IreadNode *iread, const MIRFunction *)
251 {
252 auto rhsPtrType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx());
253 CHECK_FATAL(rhsPtrType->IsMIRPtrType(), "must be pointer type");
254 auto rhsAggType = static_cast<MIRPtrType *>(rhsPtrType)->GetPointedType();
255 auto rhsFieldID = iread->GetFieldID();
256 if (rhsFieldID != 0) {
257 CHECK_FATAL(rhsAggType->IsStructType(), "only struct has non-zero fieldID");
258 rhsAggType = static_cast<MIRStructType *>(rhsAggType)->GetFieldType(rhsFieldID);
259 }
260 if (!rhsAggType->IsStructType()) {
261 return nullptr;
262 }
263 return static_cast<MIRStructType *>(rhsAggType);
264 }
265
// Split an aggregate copy (a dassign/iassign whose RHS is a dread/iread of a
// struct of the same type) into one assignment per field, inserted after the
// original statement, which is then removed.
// Returns the statement following the removed original, or nullptr if the
// LHS/RHS struct types differ and nothing was changed.
template <class RhsType, class AssignType>
static StmtNode *SplitAggCopy(const AssignType *assignNode, MIRStructType *structureType, BlockNode *block,
                              MIRFunction *func)
{
    auto *readNode = static_cast<RhsType *>(assignNode->GetRHS());
    auto rhsFieldID = readNode->GetFieldID();
    auto *rhsAggType = GetReadedStructureType(readNode, func);
    if (structureType != rhsAggType) {
        return nullptr;
    }

    for (FieldID id = 1; id <= static_cast<FieldID>(structureType->NumberOfFieldIDs()); ++id) {
        MIRType *fieldType = structureType->GetFieldType(id);
        if (fieldType->GetSize() == 0) {
            continue; // field size is zero for empty struct/union;
        }
        if (fieldType->GetKind() == kTypeBitField && static_cast<MIRBitFieldType *>(fieldType)->GetFieldSize() == 0) {
            continue; // bitfield size is zero
        }
        // Clone the whole assignment, then retarget both sides to field `id`
        // (field IDs on a clone are relative to the original's base field ID).
        auto *newDassign = assignNode->CloneTree(func->GetCodeMemPoolAllocator());
        newDassign->SetFieldID(assignNode->GetFieldID() + id);
        auto *newRHS = static_cast<RhsType *>(newDassign->GetRHS());
        newRHS->SetFieldID(rhsFieldID + id);
        newRHS->SetPrimType(fieldType->GetPrimType());
        block->InsertAfter(assignNode, newDassign);
        if (fieldType->IsMIRUnionType()) {
            // A union field is copied as one unit; skip over its members' IDs.
            id += fieldType->NumberOfFieldIDs();
        }
    }
    auto newAssign = assignNode->GetNext();
    block->RemoveStmt(assignNode);
    return newAssign;
}
299
SplitDassignAggCopy(DassignNode * dassign,BlockNode * block,MIRFunction * func)300 static StmtNode *SplitDassignAggCopy(DassignNode *dassign, BlockNode *block, MIRFunction *func)
301 {
302 auto *rhs = dassign->GetRHS();
303 if (rhs->GetPrimType() != PTY_agg) {
304 return nullptr;
305 }
306
307 auto *lhsAggType = GetDassignedStructType(dassign, func);
308 if (lhsAggType == nullptr) {
309 return nullptr;
310 }
311
312 if (rhs->GetOpCode() == OP_dread) {
313 auto *lhsSymbol = func->GetLocalOrGlobalSymbol(dassign->GetStIdx());
314 auto *rhsSymbol = func->GetLocalOrGlobalSymbol(static_cast<DreadNode *>(rhs)->GetStIdx());
315 if (!lhsSymbol->IsLocal() && !rhsSymbol->IsLocal()) {
316 return nullptr;
317 }
318
319 return SplitAggCopy<DreadNode>(dassign, lhsAggType, block, func);
320 } else if (rhs->GetOpCode() == OP_iread) {
321 return SplitAggCopy<IreadNode>(dassign, lhsAggType, block, func);
322 }
323 return nullptr;
324 }
325
SplitIassignAggCopy(IassignNode * iassign,BlockNode * block,MIRFunction * func)326 static StmtNode *SplitIassignAggCopy(IassignNode *iassign, BlockNode *block, MIRFunction *func)
327 {
328 auto rhs = iassign->GetRHS();
329 if (rhs->GetPrimType() != PTY_agg) {
330 return nullptr;
331 }
332
333 auto *lhsAggType = GetIassignedStructType(iassign);
334 if (lhsAggType == nullptr) {
335 return nullptr;
336 }
337
338 if (rhs->GetOpCode() == OP_dread) {
339 return SplitAggCopy<DreadNode>(iassign, lhsAggType, block, func);
340 } else if (rhs->GetOpCode() == OP_iread) {
341 return SplitAggCopy<IreadNode>(iassign, lhsAggType, block, func);
342 }
343 return nullptr;
344 }
345
UseGlobalVar(const BaseNode * expr)346 bool UseGlobalVar(const BaseNode *expr)
347 {
348 if (expr->GetOpCode() == OP_addrof || expr->GetOpCode() == OP_dread) {
349 StIdx stIdx = static_cast<const AddrofNode *>(expr)->GetStIdx();
350 if (stIdx.IsGlobal()) {
351 return true;
352 }
353 }
354 for (size_t i = 0; i < expr->GetNumOpnds(); ++i) {
355 if (UseGlobalVar(expr->Opnd(i))) {
356 return true;
357 }
358 }
359 return false;
360 }
361
// Flatten `if (cond) { res = a; } else { res = b; }` into a single
// `res = select(cond, a, b)` dassign, when both arms are one dassign to the
// same symbol/field and the RHS expressions are simple and side-effect free.
// Returns the new dassign, or nullptr if the pattern does not match.
StmtNode *Simplify::SimplifyToSelect(MIRFunction *func, IfStmtNode *ifNode, BlockNode *block)
{
    // Example: if (condition) {
    // Example:   res = trueRes
    // Example: }
    // Example: else {
    // Example:   res = falseRes
    // Example: }
    // =================
    // res = select condition ? trueRes : falseRes
    if (ifNode->GetPrev() != nullptr && ifNode->GetPrev()->GetOpCode() == OP_label) {
        // simplify shortCircuit will stop opt in cfg_opt, and generate extra compare
        auto *labelNode = static_cast<LabelNode *>(ifNode->GetPrev());
        const std::string &labelName = func->GetLabelTabItem(labelNode->GetLabelIdx());
        if (labelName.find("shortCircuit") != std::string::npos) {
            return nullptr;
        }
    }
    if (ifNode->GetThenPart() == nullptr || ifNode->GetElsePart() == nullptr) {
        return nullptr;
    }
    StmtNode *thenFirst = ifNode->GetThenPart()->GetFirst();
    StmtNode *elseFirst = ifNode->GetElsePart()->GetFirst();
    if (thenFirst == nullptr || elseFirst == nullptr) {
        return nullptr;
    }
    // thenpart and elsepart has only one stmt
    if (thenFirst->GetNext() != nullptr || elseFirst->GetNext() != nullptr) {
        return nullptr;
    }
    if (thenFirst->GetOpCode() != OP_dassign || elseFirst->GetOpCode() != OP_dassign) {
        return nullptr;
    }
    auto *thenDass = static_cast<DassignNode *>(thenFirst);
    auto *elseDass = static_cast<DassignNode *>(elseFirst);
    // Both arms must assign to the same symbol and field.
    if (thenDass->GetStIdx() != elseDass->GetStIdx() || thenDass->GetFieldID() != elseDass->GetFieldID()) {
        return nullptr;
    }
    // iread has sideeffect : may cause deref error
    if (HasIreadExpr(thenDass->GetRHS()) || HasIreadExpr(elseDass->GetRHS())) {
        return nullptr;
    }
    // Check if the operand of the select node is complex enough
    // we should not simplify it to if-then-else for either functionality or performance reason
    if (thenDass->GetRHS()->GetPrimType() == PTY_agg || elseDass->GetRHS()->GetPrimType() == PTY_agg) {
        return nullptr;
    }
    constexpr size_t maxDepth = 3;
    if (MaxDepth(thenDass->GetRHS()) > maxDepth || MaxDepth(elseDass->GetRHS()) > maxDepth) {
        return nullptr;
    }
    // Global reads lower to adrp+add sequences; folding them into a select
    // would evaluate both arms unconditionally — skip.
    if (UseGlobalVar(thenDass->GetRHS()) || UseGlobalVar(elseDass->GetRHS())) {
        return nullptr;
    }
    MIRBuilder *mirBuiler = func->GetModule()->GetMIRBuilder();
    MIRType *type = GlobalTables::GetTypeTable().GetPrimType(thenDass->GetRHS()->GetPrimType());
    auto *selectExpr =
        mirBuiler->CreateExprTernary(OP_select, *type, ifNode->Opnd(0), thenDass->GetRHS(), elseDass->GetRHS());
    auto *newDassign = mirBuiler->CreateStmtDassign(thenDass->GetStIdx(), thenDass->GetFieldID(), selectExpr);
    newDassign->SetSrcPos(ifNode->GetSrcPos());
    block->InsertBefore(ifNode, newDassign);
    block->RemoveStmt(ifNode);
    return newDassign;
}
426
ProcessStmt(StmtNode & stmt)427 void Simplify::ProcessStmt(StmtNode &stmt)
428 {
429 switch (stmt.GetOpCode()) {
430 case OP_callassigned: {
431 SimplifyCallAssigned(stmt, *currBlock);
432 break;
433 }
434 case OP_intrinsiccall: {
435 simplifyMemOp.SetDebug(dump);
436 simplifyMemOp.SetFunction(currFunc);
437 (void)simplifyMemOp.AutoSimplify(stmt, *currBlock, false);
438 break;
439 }
440 case OP_dassign: {
441 auto *newStmt = SplitDassignAggCopy(static_cast<DassignNode *>(&stmt), currBlock, currFunc);
442 if (newStmt) {
443 ProcessBlock(*newStmt);
444 }
445 break;
446 }
447 case OP_iassign: {
448 auto *newStmt = SplitIassignAggCopy(static_cast<IassignNode *>(&stmt), currBlock, currFunc);
449 if (newStmt) {
450 ProcessBlock(*newStmt);
451 }
452 break;
453 }
454 case OP_if:
455 case OP_while:
456 case OP_dowhile: {
457 auto unaryStmt = static_cast<UnaryStmtNode &>(stmt);
458 unaryStmt.SetRHS(SimplifyExpr(*unaryStmt.GetRHS()));
459 return;
460 }
461 default: {
462 break;
463 }
464 }
465 for (size_t i = 0; i < stmt.NumOpnds(); ++i) {
466 if (stmt.Opnd(i)) {
467 stmt.SetOpnd(SimplifyExpr(*stmt.Opnd(i)), i);
468 }
469 }
470 }
471
SimplifyExpr(BaseNode & expr)472 BaseNode *Simplify::SimplifyExpr(BaseNode &expr)
473 {
474 switch (expr.GetOpCode()) {
475 case OP_dread: {
476 auto &dread = static_cast<DreadNode &>(expr);
477 return ReplaceExprWithConst(dread);
478 }
479 default: {
480 for (auto i = 0; i < expr.GetNumOpnds(); i++) {
481 if (expr.Opnd(i)) {
482 expr.SetOpnd(SimplifyExpr(*expr.Opnd(i)), i);
483 }
484 }
485 break;
486 }
487 }
488 return &expr;
489 }
490
ReplaceExprWithConst(DreadNode & dread)491 BaseNode *Simplify::ReplaceExprWithConst(DreadNode &dread)
492 {
493 auto stIdx = dread.GetStIdx();
494 auto fieldId = dread.GetFieldID();
495 auto *symbol = currFunc->GetLocalOrGlobalSymbol(stIdx);
496 auto *symbolConst = symbol->GetKonst();
497 if (!currFunc->GetModule()->IsCModule() || !symbolConst || !stIdx.IsGlobal() ||
498 !IsSymbolReplaceableWithConst(*symbol)) {
499 return &dread;
500 }
501 if (fieldId != 0) {
502 symbolConst = GetElementConstFromFieldId(fieldId, symbolConst);
503 }
504 if (!symbolConst || !IsConstRepalceable(*symbolConst)) {
505 return &dread;
506 }
507 return currFunc->GetModule()->GetMIRBuilder()->CreateConstval(symbolConst);
508 }
509
IsSymbolReplaceableWithConst(const MIRSymbol & symbol) const510 bool Simplify::IsSymbolReplaceableWithConst(const MIRSymbol &symbol) const
511 {
512 return (symbol.GetStorageClass() == kScFstatic && !symbol.HasPotentialAssignment()) ||
513 symbol.GetAttrs().GetAttr(ATTR_const);
514 }
515
IsConstRepalceable(const MIRConst & mirConst) const516 bool Simplify::IsConstRepalceable(const MIRConst &mirConst) const
517 {
518 switch (mirConst.GetKind()) {
519 case kConstInt:
520 case kConstFloatConst:
521 case kConstDoubleConst:
522 case kConstFloat128Const:
523 case kConstLblConst:
524 return true;
525 default:
526 return false;
527 }
528 }
529
// Walk an aggregate constant to find the initializer of field `fieldId`
// (field IDs number the struct layout in preorder). Union members may be
// truncated to the accessed bit-field width. Fatals if the field is missing.
MIRConst *Simplify::GetElementConstFromFieldId(FieldID fieldId, MIRConst *mirConst)
{
    FieldID currFieldId = 1;
    MIRConst *resultConst = nullptr;
    auto originAggConst = static_cast<MIRAggConst *>(mirConst);
    // NOTE(review): `auto` here copies the MIRStructType out of the cast
    // reference — looks unintended but is behavior-preserving for GetFieldType.
    auto originAggType = static_cast<MIRStructType &>(originAggConst->GetType());
    bool hasReached = false;
    std::function<void(MIRConst *)> traverseAgg = [&](MIRConst *currConst) {
        auto *currAggConst = safe_cast<MIRAggConst>(currConst);
        ASSERT_NOT_NULL(currAggConst);
        auto *currAggType = safe_cast<MIRStructType>(currAggConst->GetType());
        ASSERT_NOT_NULL(currAggType);
        for (size_t iter = 0; iter < currAggType->GetFieldsSize() && !hasReached; ++iter) {
            // A union constant stores a single element, so always read slot 1.
            size_t constIdx = currAggType->GetKind() == kTypeUnion ? 1 : iter + 1;
            auto *fieldConst = currAggConst->GetAggConstElement(constIdx);
            auto *fieldType = originAggType.GetFieldType(currFieldId);

            if (currFieldId == fieldId) {
                // Union member may be narrower than the stored scalar constant;
                // truncate to the accessed bit-field width when necessary.
                if (auto *truncCst = TruncateUnionConstant(*currAggType, fieldConst, *fieldType)) {
                    resultConst = truncCst;
                } else {
                    resultConst = fieldConst;
                }

                hasReached = true;
                return;
            }

            ++currFieldId;
            // Nested aggregates occupy a range of field IDs — recurse into them.
            if (fieldType->GetKind() == kTypeUnion || fieldType->GetKind() == kTypeStruct) {
                traverseAgg(fieldConst);
            }
        }
    };
    traverseAgg(mirConst);
    CHECK_FATAL(hasReached, "const not found");
    return resultConst;
}
568
// Phase teardown hook — nothing to release for this pass.
void Simplify::Finish() {}
570
571 // Join `num` `byte`s into a number
572 // Example:
573 // byte num output
574 // 0x0a 2 0x0a0a
575 // 0x12 4 0x12121212
576 // 0xff 8 0xffffffffffffffff
JoinBytes(int byte,uint32 num)577 static uint64 JoinBytes(int byte, uint32 num)
578 {
579 CHECK_FATAL(num <= 8, "not support"); // just support num less or equal 8, see comment above
580 uint64 realByte = static_cast<uint64>(byte % 256);
581 if (realByte == 0) {
582 return 0;
583 }
584 uint64 result = 0;
585 for (uint32 i = 0; i < num; ++i) {
586 result += (realByte << (i * k8BitSize));
587 }
588 return result;
589 }
590
591 // Return Fold result expr, does not always return a constant expr
592 // Attention: Fold may modify the input expr, if foldExpr is not a nullptr, we should always replace expr with foldExpr
FoldIntConst(BaseNode * expr,uint64 & out,bool & isIntConst)593 static BaseNode *FoldIntConst(BaseNode *expr, uint64 &out, bool &isIntConst)
594 {
595 if (expr->GetOpCode() == OP_constval) {
596 MIRConst *mirConst = static_cast<ConstvalNode *>(expr)->GetConstVal();
597 if (mirConst->GetKind() == kConstInt) {
598 out = static_cast<MIRIntConst *>(mirConst)->GetExtValue();
599 isIntConst = true;
600 }
601 return nullptr;
602 }
603 BaseNode *foldExpr = nullptr;
604 static ConstantFold cf(*theMIRModule);
605 foldExpr = cf.Fold(expr);
606 if (foldExpr != nullptr && foldExpr->GetOpCode() == OP_constval) {
607 MIRConst *mirConst = static_cast<ConstvalNode *>(foldExpr)->GetConstVal();
608 if (mirConst->GetKind() == kConstInt) {
609 out = static_cast<MIRIntConst *>(mirConst)->GetExtValue();
610 isIntConst = true;
611 }
612 }
613 return foldExpr;
614 }
615
ConstructConstvalNode(uint64 val,PrimType primType,MIRBuilder & mirBuilder)616 static BaseNode *ConstructConstvalNode(uint64 val, PrimType primType, MIRBuilder &mirBuilder)
617 {
618 PrimType constPrimType = primType;
619 if (IsPrimitiveFloat(primType)) {
620 constPrimType = GetIntegerPrimTypeBySizeAndSign(GetPrimTypeBitSize(primType), false);
621 }
622 MIRType *constType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(constPrimType));
623 MIRConst *mirConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(val, *constType);
624 BaseNode *ret = mirBuilder.CreateConstval(mirConst);
625 if (IsPrimitiveFloat(primType)) {
626 MIRType *floatType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(primType));
627 ret = mirBuilder.CreateExprRetype(*floatType, constPrimType, ret);
628 }
629 return ret;
630 }
631
ConstructConstvalNode(int64 byte,uint64 num,PrimType primType,MIRBuilder & mirBuilder)632 static BaseNode *ConstructConstvalNode(int64 byte, uint64 num, PrimType primType, MIRBuilder &mirBuilder)
633 {
634 auto val = JoinBytes(byte, static_cast<uint32>(num));
635 return ConstructConstvalNode(val, primType, mirBuilder);
636 }
637
638 // Input total size of memory, split the memory into several blocks, the max block size is 8 bytes
639 // Example:
640 // input output
641 // 40 [ 8, 8, 8, 8, 8 ]
642 // 31 [ 8, 8, 8, 4, 2, 1 ]
SplitMemoryIntoBlocks(size_t totalMemorySize,std::vector<uint32> & blocks)643 static void SplitMemoryIntoBlocks(size_t totalMemorySize, std::vector<uint32> &blocks)
644 {
645 size_t leftSize = totalMemorySize;
646 size_t curBlockSize = kMaxMemoryBlockSizeToAssign; // max block size in byte
647 while (curBlockSize > 0) {
648 size_t n = leftSize / curBlockSize;
649 blocks.insert(blocks.end(), n, curBlockSize);
650 leftSize -= (n * curBlockSize);
651 curBlockSize = curBlockSize >> 1;
652 }
653 }
654
IsComplexExpr(const BaseNode * expr,MIRFunction & func)655 static bool IsComplexExpr(const BaseNode *expr, MIRFunction &func)
656 {
657 Opcode op = expr->GetOpCode();
658 if (op == OP_regread) {
659 return false;
660 }
661 if (op == OP_dread) {
662 auto *symbol = func.GetLocalOrGlobalSymbol(static_cast<const DreadNode *>(expr)->GetStIdx());
663 if (symbol->IsGlobal() || symbol->GetStorageClass() == kScPstatic) {
664 return true; // dread global/static var is complex expr because it will be lowered to adrp + add
665 } else {
666 return false;
667 }
668 }
669 if (op == OP_addrof) {
670 auto *symbol = func.GetLocalOrGlobalSymbol(static_cast<const AddrofNode *>(expr)->GetStIdx());
671 if (symbol->IsGlobal() || symbol->GetStorageClass() == kScPstatic) {
672 return true; // addrof global/static var is complex expr because it will be lowered to adrp + add
673 } else {
674 return false;
675 }
676 }
677 return true;
678 }
679
680 // Input a address expr, output a memEntry to abstract this expr
// Abstract an address expression into a MemEntry {addrExpr, memType}.
// Returns false when no memory type can be inferred (and, for high-level
// expansion, the entry is then unusable); for low-level expansion a pointer
// typed expr is accepted with memType == nullptr.
bool MemEntry::ComputeMemEntry(BaseNode &expr, MIRFunction &func, MemEntry &memEntry, bool isLowLevel)
{
    Opcode op = expr.GetOpCode();
    MIRType *memType = nullptr;
    switch (op) {
        case OP_dread: {
            // dread of a pointer variable: memory type is the pointed-to type.
            const auto &concreteExpr = static_cast<const DreadNode &>(expr);
            auto *symbol = func.GetLocalOrGlobalSymbol(concreteExpr.GetStIdx());
            MIRType *curType = symbol->GetType();
            if (concreteExpr.GetFieldID() != 0) {
                curType = static_cast<MIRStructType *>(curType)->GetFieldType(concreteExpr.GetFieldID());
            }
            // Support kTypeScalar ptr if possible
            if (curType->GetKind() == kTypePointer) {
                memType = static_cast<MIRPtrType *>(curType)->GetPointedType();
            }
            break;
        }
        case OP_addrof: {
            // addrof a variable (or field): memory type is that variable/field type.
            const auto &concreteExpr = static_cast<const AddrofNode &>(expr);
            auto *symbol = func.GetLocalOrGlobalSymbol(concreteExpr.GetStIdx());
            MIRType *curType = symbol->GetType();
            if (concreteExpr.GetFieldID() != 0) {
                curType = static_cast<MIRStructType *>(curType)->GetFieldType(concreteExpr.GetFieldID());
            }
            memType = curType;
            break;
        }
        case OP_iread: {
            const auto &concreteExpr = static_cast<const IreadNode &>(expr);
            memType = concreteExpr.GetType();
            break;
        }
        case OP_iaddrof: { // Do NOT call GetType because it is for OP_iread
            const auto &concreteExpr = static_cast<const IaddrofNode &>(expr);
            MIRType *curType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(concreteExpr.GetTyIdx());
            CHECK_FATAL(curType->IsMIRPtrType(), "must be MIRPtrType");
            curType = static_cast<MIRPtrType *>(curType)->GetPointedType();
            CHECK_FATAL(curType->IsStructType(), "must be MIRStructType");
            memType = static_cast<MIRStructType *>(curType)->GetFieldType(concreteExpr.GetFieldID());
            break;
        }
        case OP_regread: {
            if (isLowLevel && IsPrimitivePoint(expr.GetPrimType())) {
                memEntry.addrExpr = &expr;
                memEntry.memType =
                    nullptr; // we cannot infer high level memory type, this is allowed for low level expand
                return true;
            }
            // High-level: try to recover the type from the preg's remat info.
            const auto &concreteExpr = static_cast<const RegreadNode &>(expr);
            MIRPreg *preg = func.GetPregItem(concreteExpr.GetRegIdx());
            bool isFromDread = (preg->GetOp() == OP_dread);
            bool isFromAddrof = (preg->GetOp() == OP_addrof);
            if (isFromDread || isFromAddrof) {
                auto *symbol = preg->rematInfo.sym;
                auto fieldId = preg->fieldID;
                MIRType *curType = symbol->GetType();
                if (fieldId != 0) {
                    curType = static_cast<MIRStructType *>(symbol->GetType())->GetFieldType(fieldId);
                }
                // A preg from dread holds the pointer's value: deref its type.
                if (isFromDread && curType->GetKind() == kTypePointer) {
                    curType = static_cast<MIRPtrType *>(curType)->GetPointedType();
                }
                memType = curType;
            }
            break;
        }
        default: {
            if (isLowLevel && IsPrimitivePoint(expr.GetPrimType())) {
                memEntry.addrExpr = &expr;
                memEntry.memType =
                    nullptr; // we cannot infer high level memory type, this is allowed for low level expand
                return true;
            }
            break;
        }
    }
    if (memType == nullptr) {
        return false;
    }
    memEntry.addrExpr = &expr;
    memEntry.memType = memType;
    return true;
}
765
BuildAsRhsExpr(MIRFunction & func) const766 BaseNode *MemEntry::BuildAsRhsExpr(MIRFunction &func) const
767 {
768 BaseNode *expr = nullptr;
769 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
770 if (addrExpr->GetOpCode() == OP_addrof) {
771 // We prefer dread to iread
772 // consider iaddrof if possible
773 auto *addrof = static_cast<AddrofNode *>(addrExpr);
774 auto *symbol = func.GetLocalOrGlobalSymbol(addrof->GetStIdx());
775 expr = mirBuilder->CreateExprDread(*memType, addrof->GetFieldID(), *symbol);
776 } else {
777 MIRType *structPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*memType);
778 expr = mirBuilder->CreateExprIread(*memType, *structPtrType, 0, addrExpr);
779 }
780 return expr;
781 }
782
InsertAndMayPrintStmt(BlockNode & block,const StmtNode & anchor,bool debug,StmtNode * stmt)783 static void InsertAndMayPrintStmt(BlockNode &block, const StmtNode &anchor, bool debug, StmtNode *stmt)
784 {
785 if (stmt == nullptr) {
786 return;
787 }
788 block.InsertBefore(&anchor, stmt);
789 if (debug) {
790 stmt->Dump(0);
791 }
792 }
793
InsertBeforeAndMayPrintStmtList(BlockNode & block,const StmtNode & anchor,bool debug,std::initializer_list<StmtNode * > stmtList)794 static void InsertBeforeAndMayPrintStmtList(BlockNode &block, const StmtNode &anchor, bool debug,
795 std::initializer_list<StmtNode *> stmtList)
796 {
797 for (StmtNode *stmt : stmtList) {
798 if (stmt == nullptr) {
799 continue;
800 }
801 block.InsertBefore(&anchor, stmt);
802 if (debug) {
803 stmt->Dump(0);
804 }
805 }
806 }
807
NeedCheck(MemOpKind memOpKind)808 static bool NeedCheck(MemOpKind memOpKind)
809 {
810 if (memOpKind == MEM_OP_memset_s || memOpKind == MEM_OP_memcpy_s) {
811 return true;
812 }
813 return false;
814 }
815
816 // Create maple IR to check whether `expr` is a null pointer, IR is as follows:
817 // brfalse @@n1 (ne u8 ptr (regread ptr %1, constval u64 0))
CreateNullptrCheckStmt(BaseNode & expr,MIRFunction & func,MIRBuilder * mirBuilder,const MIRType & cmpResType,const MIRType & cmpOpndType)818 static CondGotoNode *CreateNullptrCheckStmt(BaseNode &expr, MIRFunction &func, MIRBuilder *mirBuilder,
819 const MIRType &cmpResType, const MIRType &cmpOpndType)
820 {
821 LabelIdx nullLabIdx = func.GetLabelTab()->CreateLabelWithPrefix('n'); // 'n' means nullptr
822 auto *checkExpr = mirBuilder->CreateExprCompare(OP_ne, cmpResType, cmpOpndType, &expr,
823 ConstructConstvalNode(0, PTY_u64, *mirBuilder));
824 auto *checkStmt = mirBuilder->CreateStmtCondGoto(checkExpr, OP_brfalse, nullLabIdx);
825 return checkStmt;
826 }
827
828 // Create maple IR to check whether `expr1` and `expr2` are equal
829 // brfalse @@a1 (ne u8 ptr (regread ptr %1, regread ptr %2))
CreateAddressEqualCheckStmt(BaseNode & expr1,BaseNode & expr2,MIRFunction & func,MIRBuilder * mirBuilder,const MIRType & cmpResType,const MIRType & cmpOpndType)830 static CondGotoNode *CreateAddressEqualCheckStmt(BaseNode &expr1, BaseNode &expr2, MIRFunction &func,
831 MIRBuilder *mirBuilder, const MIRType &cmpResType,
832 const MIRType &cmpOpndType)
833 {
834 LabelIdx equalLabIdx = func.GetLabelTab()->CreateLabelWithPrefix('a'); // 'a' means address equal
835 auto *checkExpr = mirBuilder->CreateExprCompare(OP_ne, cmpResType, cmpOpndType, &expr1, &expr2);
836 auto *checkStmt = mirBuilder->CreateStmtCondGoto(checkExpr, OP_brfalse, equalLabIdx);
837 return checkStmt;
838 }
839
840 // Create maple IR to check whether `expr1` and `expr2` are overlapped
841 // brfalse @@o1 (ge u8 ptr (
842 // abs ptr (sub ptr (regread ptr %1, regread ptr %2)),
843 // constval u64 xxx))
CreateOverlapCheckStmt(BaseNode & expr1,BaseNode & expr2,uint32 size,MIRFunction & func,MIRBuilder * mirBuilder,const MIRType & cmpResType,const MIRType & cmpOpndType)844 static CondGotoNode *CreateOverlapCheckStmt(BaseNode &expr1, BaseNode &expr2, uint32 size, MIRFunction &func,
845 MIRBuilder *mirBuilder, const MIRType &cmpResType,
846 const MIRType &cmpOpndType)
847 {
848 LabelIdx overlapLabIdx = func.GetLabelTab()->CreateLabelWithPrefix('o'); // 'n' means overlap
849 auto *checkExpr = mirBuilder->CreateExprCompare(
850 OP_ge, cmpResType, cmpOpndType,
851 mirBuilder->CreateExprUnary(OP_abs, cmpOpndType,
852 mirBuilder->CreateExprBinary(OP_sub, cmpOpndType, &expr1, &expr2)),
853 ConstructConstvalNode(size, PTY_u64, *mirBuilder));
854 auto *checkStmt = mirBuilder->CreateStmtCondGoto(checkExpr, OP_brfalse, overlapLabIdx);
855 return checkStmt;
856 }
857
858 // Generate IR to handle nullptr, IR is as follows:
859 // @curLabel
860 // regassign i32 %1 (constval i32 errNum)
861 // goto @finalLabel
AddNullptrHandlerIR(const StmtNode & stmt,MIRBuilder * mirBuilder,BlockNode & block,StmtNode * retAssign,LabelIdx curLabIdx,LabelIdx finalLabIdx,bool debug)862 static void AddNullptrHandlerIR(const StmtNode &stmt, MIRBuilder *mirBuilder, BlockNode &block, StmtNode *retAssign,
863 LabelIdx curLabIdx, LabelIdx finalLabIdx, bool debug)
864 {
865 auto *curLabelNode = mirBuilder->CreateStmtLabel(curLabIdx);
866 auto *gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
867 InsertBeforeAndMayPrintStmtList(block, stmt, debug, {curLabelNode, retAssign, gotoFinal});
868 }
869
AddMemsetCallStmt(const StmtNode & stmt,MIRFunction & func,BlockNode & block,BaseNode * addrExpr)870 static void AddMemsetCallStmt(const StmtNode &stmt, MIRFunction &func, BlockNode &block, BaseNode *addrExpr)
871 {
872 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
873 // call memset for dst memory when detecting overlapping
874 MapleVector<BaseNode *> args(mirBuilder->GetCurrentFuncCodeMpAllocator()->Adapter());
875 args.push_back(addrExpr);
876 args.push_back(ConstructConstvalNode(0, PTY_i32, *mirBuilder));
877 args.push_back(stmt.Opnd(1));
878 auto *callMemset = mirBuilder->CreateStmtCall("memset", args);
879 block.InsertBefore(&stmt, callMemset);
880 }
881
882 // Generate IR to handle errors that should be reset with memset, IR is as follows:
883 // @curLabel
884 // regassign i32 %1 (constval i32 errNum)
885 // call memset # new genrated memset will be expanded if possible
886 // goto @finalLabel
AddResetHandlerIR(const StmtNode & stmt,MIRFunction & func,BlockNode & block,StmtNode * retAssign,LabelIdx curLabIdx,LabelIdx finalLabIdx,BaseNode * addrExpr,bool debug)887 static void AddResetHandlerIR(const StmtNode &stmt, MIRFunction &func, BlockNode &block, StmtNode *retAssign,
888 LabelIdx curLabIdx, LabelIdx finalLabIdx, BaseNode *addrExpr, bool debug)
889 {
890 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
891 auto *curLabelNode = mirBuilder->CreateStmtLabel(curLabIdx);
892 InsertBeforeAndMayPrintStmtList(block, stmt, debug, {curLabelNode, retAssign});
893 AddMemsetCallStmt(stmt, func, block, addrExpr);
894 auto *gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
895 InsertAndMayPrintStmt(block, stmt, debug, gotoFinal);
896 }
897
TryToExtractComplexExpr(BaseNode * expr,MIRFunction & func,BlockNode & block,const StmtNode & anchor,bool debug)898 static BaseNode *TryToExtractComplexExpr(BaseNode *expr, MIRFunction &func, BlockNode &block, const StmtNode &anchor,
899 bool debug)
900 {
901 if (!IsComplexExpr(expr, func)) {
902 return expr;
903 }
904 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
905 auto pregIdx = func.GetPregTab()->CreatePreg(PTY_ptr);
906 StmtNode *regassign = mirBuilder->CreateStmtRegassign(PTY_ptr, pregIdx, expr);
907 InsertAndMayPrintStmt(block, anchor, debug, regassign);
908 auto *extractedExpr = mirBuilder->CreateExprRegread(PTY_ptr, pregIdx);
909 return extractedExpr;
910 }
911
InsertCheckFailedBranch(MIRFunction & func,StmtNode & stmt,BlockNode & block,LabelIdx branchLabIdx,LabelIdx finalLabIdx,ErrorNumber errNumber,MemOpKind memOpKind,bool debug)912 static void InsertCheckFailedBranch(MIRFunction &func, StmtNode &stmt, BlockNode &block, LabelIdx branchLabIdx,
913 LabelIdx finalLabIdx, ErrorNumber errNumber, MemOpKind memOpKind, bool debug)
914 {
915 auto mirBuilder = func.GetModule()->GetMIRBuilder();
916 auto gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
917 auto branchLabNode = mirBuilder->CreateStmtLabel(branchLabIdx);
918 auto errnoAssign = MemEntry::GenMemopRetAssign(stmt, func, true, memOpKind, errNumber);
919 InsertBeforeAndMayPrintStmtList(block, stmt, debug, {branchLabNode, errnoAssign, gotoFinal});
920 }
921
InsertMemsetCallStmt(const MapleVector<BaseNode * > & args,MIRFunction & func,StmtNode & stmt,BlockNode & block,LabelIdx finalLabIdx,ErrorNumber errorNumber,bool debug)922 static StmtNode *InsertMemsetCallStmt(const MapleVector<BaseNode *> &args, MIRFunction &func, StmtNode &stmt,
923 BlockNode &block, LabelIdx finalLabIdx, ErrorNumber errorNumber, bool debug)
924 {
925 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
926 auto *gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
927 auto memsetFunc = mirBuilder->GetOrCreateFunction(kFuncNameOfMemset, TyIdx(PTY_void));
928 auto memsetCallStmt = mirBuilder->CreateStmtCallAssigned(memsetFunc->GetPuidx(), args, nullptr, OP_callassigned);
929 memsetCallStmt->SetSrcPos(stmt.GetSrcPos());
930 auto *errnoAssign = MemEntry::GenMemopRetAssign(stmt, func, true, MEM_OP_memset_s, errorNumber);
931 InsertBeforeAndMayPrintStmtList(block, stmt, debug, {memsetCallStmt, errnoAssign, gotoFinal});
932 return memsetCallStmt;
933 }
934
CreateAndInsertCheckStmt(Opcode op,BaseNode * lhs,BaseNode * rhs,LabelIdx label,StmtNode & stmt,BlockNode & block,MIRFunction & func,bool debug)935 static void CreateAndInsertCheckStmt(Opcode op, BaseNode *lhs, BaseNode *rhs, LabelIdx label, StmtNode &stmt,
936 BlockNode &block, MIRFunction &func, bool debug)
937 {
938 auto cmpResType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(PTY_u8));
939 auto cmpU64Type = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(PTY_u64));
940 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
941 auto cmpStmt = mirBuilder->CreateExprCompare(op, *cmpResType, *cmpU64Type, lhs, rhs);
942 auto checkStmt = mirBuilder->CreateStmtCondGoto(cmpStmt, OP_brtrue, label);
943 checkStmt->SetBranchProb(kProbUnlikely);
944 checkStmt->SetSrcPos(stmt.GetSrcPos());
945 InsertAndMayPrintStmt(block, stmt, debug, checkStmt);
946 }
947
ExpandOnSrcSizeGtDstSize(StmtNode & stmt,BlockNode & block,int64 srcSize,LabelIdx finalLabIdx,LabelIdx nullPtrLabIdx,MIRFunction & func,bool debug)948 static StmtNode *ExpandOnSrcSizeGtDstSize(StmtNode &stmt, BlockNode &block, int64 srcSize, LabelIdx finalLabIdx,
949 LabelIdx nullPtrLabIdx, MIRFunction &func, bool debug)
950 {
951 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
952 MapleVector<BaseNode *> args(func.GetCodeMempoolAllocator().Adapter());
953 args.push_back(stmt.Opnd(kMemsetDstOpndIdx));
954 args.push_back(stmt.Opnd(kMemsetSSrcOpndIdx));
955 args.push_back(ConstructConstvalNode(srcSize, stmt.Opnd(kMemsetSSrcSizeOpndIdx)->GetPrimType(), *mirBuilder));
956 auto memsetFunc = mirBuilder->GetOrCreateFunction(kFuncNameOfMemset, TyIdx(PTY_void));
957 auto callStmt = mirBuilder->CreateStmtCallAssigned(memsetFunc->GetPuidx(), args, nullptr, OP_callassigned);
958 callStmt->SetSrcPos(stmt.GetSrcPos());
959 InsertAndMayPrintStmt(block, stmt, debug, callStmt);
960 auto gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
961 auto errnoAssign = MemEntry::GenMemopRetAssign(stmt, func, true, MEM_OP_memset_s, ERRNO_RANGE_AND_RESET);
962 InsertBeforeAndMayPrintStmtList(block, stmt, debug, {errnoAssign, gotoFinal});
963 InsertCheckFailedBranch(func, stmt, block, nullPtrLabIdx, finalLabIdx, ERRNO_INVAL, MEM_OP_memset_s, debug);
964 auto *finalLabelNode = mirBuilder->CreateStmtLabel(finalLabIdx);
965 InsertAndMayPrintStmt(block, stmt, debug, finalLabelNode);
966 block.RemoveStmt(&stmt);
967 return callStmt;
968 }
969
HandleZeroValueOfDstSize(StmtNode & stmt,BlockNode & block,int64 srcSize,int64 dstSize,LabelIdx finalLabIdx,LabelIdx dstSizeCheckLabIdx,MIRFunction & func,bool isDstSizeConst,bool debug)970 static void HandleZeroValueOfDstSize(StmtNode &stmt, BlockNode &block, int64 srcSize, int64 dstSize,
971 LabelIdx finalLabIdx, LabelIdx dstSizeCheckLabIdx, MIRFunction &func,
972 bool isDstSizeConst, bool debug)
973 {
974 uint32 dstSizeOpndIdx = 1; // only used by memset_s
975 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
976 if (!isDstSizeConst) {
977 CreateAndInsertCheckStmt(OP_eq, stmt.Opnd(dstSizeOpndIdx), ConstructConstvalNode(0, PTY_u64, *mirBuilder),
978 dstSizeCheckLabIdx, stmt, block, func, debug);
979 } else if (dstSize == 0) {
980 auto gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
981 auto errnoAssign = MemEntry::GenMemopRetAssign(stmt, func, true, MEM_OP_memset_s, ERRNO_RANGE);
982 InsertBeforeAndMayPrintStmtList(block, stmt, debug, {errnoAssign, gotoFinal});
983 }
984 }
985
// Low-level expansion of memset: split `size` into chunk sizes and emit one
// iassignoff per chunk, each storing an integer constant whose bytes all equal
// `byte`. Afterwards the call's return value is materialized and the original
// call stmt is removed from `block`.
void MemEntry::ExpandMemsetLowLevel(int64 byte, uint64 size, MIRFunction &func, StmtNode &stmt, BlockNode &block,
                                    MemOpKind memOpKind, bool debug, ErrorNumber errorNumber) const
{
    MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
    std::vector<uint32> blocks;
    SplitMemoryIntoBlocks(size, blocks);
    int32 offset = 0;
    // If blocks.size() > 1 and `dst` is not a leaf node,
    // we should extract common expr to avoid redundant expression
    BaseNode *realDstExpr = addrExpr;
    if (blocks.size() > 1) {
        realDstExpr = TryToExtractComplexExpr(addrExpr, func, block, stmt, debug);
    }
    BaseNode *readConst = nullptr;
    // rhs const is big, extract it to avoid redundant expression
    // (only worthwhile when the fill byte is non-zero and there are multiple chunks)
    bool shouldExtractRhs = blocks.size() > 1 && (byte & 0xff) != 0;
    for (auto curSize : blocks) {
        // low level memset expand result:
        // iassignoff <prim-type> <offset> (dstAddrExpr, constval <prim-type> xx)
        PrimType constType = GetIntegerPrimTypeBySizeAndSign(curSize * 8, false);
        BaseNode *rhsExpr = ConstructConstvalNode(byte, curSize, constType, *mirBuilder);
        if (shouldExtractRhs) {
            // we only need to extract u64 const once
            PregIdx pregIdx = func.GetPregTab()->CreatePreg(constType);
            auto *constAssign = mirBuilder->CreateStmtRegassign(constType, pregIdx, rhsExpr);
            InsertAndMayPrintStmt(block, stmt, debug, constAssign);
            readConst = mirBuilder->CreateExprRegread(constType, pregIdx);
            shouldExtractRhs = false;
        }
        // Reuse the extracted constant only for max-size chunks: those share the
        // prim type of the preg the constant was extracted into.
        if (readConst != nullptr && curSize == kMaxMemoryBlockSizeToAssign) {
            rhsExpr = readConst;
        }
        auto *iassignoff = mirBuilder->CreateStmtIassignoff(constType, offset, realDstExpr, rhsExpr);
        InsertAndMayPrintStmt(block, stmt, debug, iassignoff);
        if (debug) {
            iassignoff->Dump(0);
        }
        offset += static_cast<int32>(curSize);
    }
    // handle memset return val
    auto *retAssign = GenMemopRetAssign(stmt, func, true, memOpKind, errorNumber);
    InsertAndMayPrintStmt(block, stmt, debug, retAssign);
    // NOTE(review): the ERRNO_INVAL path for a NULL memset_s dst appears to be
    // handled by the caller's check expansion, not here — confirm.
    block.RemoveStmt(&stmt);
}
1031
1032 // Lower memset(MemEntry, byte, size) into a series of assign stmts and replace callStmt in the block
1033 // with these assign stmts
ExpandMemset(int64 byte,uint64 size,MIRFunction & func,StmtNode & stmt,BlockNode & block,bool isLowLevel,bool debug,ErrorNumber errorNumber) const1034 bool MemEntry::ExpandMemset(int64 byte, uint64 size, MIRFunction &func, StmtNode &stmt, BlockNode &block,
1035 bool isLowLevel, bool debug, ErrorNumber errorNumber) const
1036 {
1037 MemOpKind memOpKind = SimplifyMemOp::ComputeMemOpKind(stmt);
1038 MemEntryKind memKind = GetKind();
1039 // we don't check size equality in the low level expand
1040 if (!isLowLevel) {
1041 if (memKind == kMemEntryUnknown) {
1042 MayPrintLog(debug, false, memOpKind, "unsupported dst memory type, is it a bitfield?");
1043 return false;
1044 }
1045 if (memType->GetSize() != size) {
1046 MayPrintLog(debug, false, memOpKind, "dst size and size arg are not equal");
1047 return false;
1048 }
1049 }
1050 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
1051
1052 if (isLowLevel) { // For cglower, replace memset with a series of low-level iassignoff
1053 ExpandMemsetLowLevel(byte, size, func, stmt, block, memOpKind, debug, errorNumber);
1054 return true;
1055 }
1056
1057 if (memKind == kMemEntryPrimitive) {
1058 BaseNode *rhsExpr = ConstructConstvalNode(byte, size, memType->GetPrimType(), *mirBuilder);
1059 StmtNode *newAssign = nullptr;
1060 if (addrExpr->GetOpCode() == OP_addrof) { // We prefer dassign to iassign
1061 auto *addrof = static_cast<AddrofNode *>(addrExpr);
1062 auto *symbol = func.GetLocalOrGlobalSymbol(addrof->GetStIdx());
1063 newAssign = mirBuilder->CreateStmtDassign(*symbol, addrof->GetFieldID(), rhsExpr);
1064 } else {
1065 MIRType *memPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*memType);
1066 newAssign = mirBuilder->CreateStmtIassign(*memPtrType, 0, addrExpr, rhsExpr);
1067 }
1068 InsertAndMayPrintStmt(block, stmt, debug, newAssign);
1069 } else if (memKind == kMemEntryStruct) {
1070 auto *structType = static_cast<MIRStructType *>(memType);
1071 // struct size should be small enough, struct field size should be big enough
1072 constexpr uint32 maxStructSize = 64; // in byte
1073 constexpr uint32 minFieldSize = 4; // in byte
1074 size_t structSize = structType->GetSize();
1075 size_t numFields = structType->NumberOfFieldIDs();
1076 // Relax restrictions when store-merge is powerful enough
1077 bool expandIt =
1078 (structSize <= maxStructSize && (structSize / numFields >= minFieldSize) && !structType->HasPadding());
1079 if (!expandIt) {
1080 // We only expand memset for no-padding struct, because only in this case, element-wise and byte-wise
1081 // are equivalent
1082 MayPrintLog(debug, false, memOpKind,
1083 "struct type has padding, or struct sum size is too big, or filed size is too small");
1084 return false;
1085 }
1086 bool hasArrayField = false;
1087 for (uint32 id = 1; id <= numFields; ++id) {
1088 auto *fieldType = structType->GetFieldType(id);
1089 if (fieldType->GetKind() == kTypeArray) {
1090 hasArrayField = true;
1091 break;
1092 }
1093 }
1094 if (hasArrayField) {
1095 // struct with array fields is not supported to expand for now, enhance it when needed
1096 MayPrintLog(debug, false, memOpKind, "struct with array fields is not supported to expand");
1097 return false;
1098 }
1099
1100 // Build assign for each fields in the struct type
1101 // We should skip union fields
1102 for (FieldID id = 1; static_cast<size_t>(id) <= numFields; ++id) {
1103 MIRType *fieldType = structType->GetFieldType(id);
1104 // We only consider leaf field with valid type size
1105 if (fieldType->GetSize() == 0 || fieldType->GetPrimType() == PTY_agg) {
1106 continue;
1107 }
1108 if (fieldType->GetKind() == kTypeBitField &&
1109 static_cast<MIRBitFieldType *>(fieldType)->GetFieldSize() == 0) {
1110 continue;
1111 }
1112 // now the fieldType is primitive type
1113 BaseNode *rhsExpr =
1114 ConstructConstvalNode(byte, fieldType->GetSize(), fieldType->GetPrimType(), *mirBuilder);
1115 StmtNode *fieldAssign = nullptr;
1116 if (addrExpr->GetOpCode() == OP_addrof) {
1117 auto *addrof = static_cast<AddrofNode *>(addrExpr);
1118 auto *symbol = func.GetLocalOrGlobalSymbol(addrof->GetStIdx());
1119 fieldAssign = mirBuilder->CreateStmtDassign(*symbol, addrof->GetFieldID() + id, rhsExpr);
1120 } else {
1121 MIRType *memPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*memType);
1122 fieldAssign = mirBuilder->CreateStmtIassign(*memPtrType, id, addrExpr, rhsExpr);
1123 }
1124 InsertAndMayPrintStmt(block, stmt, debug, fieldAssign);
1125 }
1126 } else if (memKind == kMemEntryArray) {
1127 // We only consider array with dim == 1 now, and element type must be primitive type
1128 auto *arrayType = static_cast<MIRArrayType *>(memType);
1129 if (arrayType->GetDim() != 1 || (arrayType->GetElemType()->GetKind() != kTypeScalar &&
1130 arrayType->GetElemType()->GetKind() != kTypePointer)) {
1131 MayPrintLog(debug, false, memOpKind, "array dim != 1 or array elements are not primtive type");
1132 return false;
1133 }
1134 MIRType *elemType = arrayType->GetElemType();
1135 if (elemType->GetSize() < k4BitSize) {
1136 MayPrintLog(debug, false, memOpKind,
1137 "element size < 4, don't expand it to avoid to genearte lots of strb/strh");
1138 return false;
1139 }
1140 uint64 elemCnt = static_cast<uint64>(arrayType->GetSizeArrayItem(0));
1141 if (elemType->GetSize() * elemCnt != size) {
1142 MayPrintLog(debug, false, memOpKind, "array size not equal");
1143 return false;
1144 }
1145 for (size_t i = 0; i < elemCnt; ++i) {
1146 BaseNode *indexExpr = ConstructConstvalNode(i, PTY_u32, *mirBuilder);
1147 auto *arrayExpr = mirBuilder->CreateExprArray(*arrayType, addrExpr, indexExpr);
1148 auto *newValOpnd = ConstructConstvalNode(byte, elemType->GetSize(), elemType->GetPrimType(), *mirBuilder);
1149 MIRType *elemPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*elemType);
1150 auto *arrayElementAssign = mirBuilder->CreateStmtIassign(*elemPtrType, 0, arrayExpr, newValOpnd);
1151 InsertAndMayPrintStmt(block, stmt, debug, arrayElementAssign);
1152 }
1153 } else {
1154 CHECK_FATAL(false, "impossible");
1155 }
1156
1157 // handle memset return val
1158 auto *retAssign = GenMemopRetAssign(stmt, func, isLowLevel, memOpKind, errorNumber);
1159 InsertAndMayPrintStmt(block, stmt, debug, retAssign);
1160 block.RemoveStmt(&stmt);
1161 return true;
1162 }
1163
GenerateMemoryCopyPair(MIRBuilder * mirBuilder,BaseNode * rhs,BaseNode * lhs,uint32 offset,uint32 curSize,PregIdx tmpRegIdx)1164 static std::pair<StmtNode *, StmtNode *> GenerateMemoryCopyPair(MIRBuilder *mirBuilder, BaseNode *rhs, BaseNode *lhs,
1165 uint32 offset, uint32 curSize, PregIdx tmpRegIdx)
1166 {
1167 auto *ptrType = GlobalTables::GetTypeTable().GetPtrType();
1168 PrimType constType = GetIntegerPrimTypeBySizeAndSign(curSize * 8, false);
1169 MIRType *constMIRType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(constType));
1170 auto *constMIRPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*constMIRType);
1171 BaseNode *rhsAddrExpr = rhs;
1172 if (offset != 0) {
1173 auto *offsetConstExpr = ConstructConstvalNode(offset, PTY_u64, *mirBuilder);
1174 rhsAddrExpr = mirBuilder->CreateExprBinary(OP_add, *ptrType, rhs, offsetConstExpr);
1175 }
1176 BaseNode *rhsExpr = mirBuilder->CreateExprIread(*constMIRType, *constMIRPtrType, 0, rhsAddrExpr);
1177 auto *regassign = mirBuilder->CreateStmtRegassign(PTY_u64, tmpRegIdx, rhsExpr);
1178 auto *iassignoff = mirBuilder->CreateStmtIassignoff(constType, static_cast<int32>(offset), lhs,
1179 mirBuilder->CreateExprRegread(PTY_u64, tmpRegIdx));
1180 return {regassign, iassignoff};
1181 }
1182
// Low-level expansion of memcpy/memcpy_s: emit the memcpy_s runtime checks
// (null dst/src, dst == src, overlap) when required, copy the memory
// chunk-by-chunk via iread + iassignoff (pairing equal-sized neighbouring
// chunks so cg can merge them into ldp/stp), materialize the return value,
// append the check-failure handlers, and remove the original call stmt.
void MemEntry::ExpandMemcpyLowLevel(const MemEntry &srcMem, uint64 copySize, MIRFunction &func, StmtNode &stmt,
                                    BlockNode &block, MemOpKind memOpKind, bool debug, ErrorNumber errorNumber) const
{
    // ERRNO_RANGE: nothing is copied, only the errno is produced.
    if (errorNumber == ERRNO_RANGE) {
        auto *retAssign = GenMemopRetAssign(stmt, func, true, memOpKind, errorNumber);
        InsertAndMayPrintStmt(block, stmt, debug, retAssign);
        block.RemoveStmt(&stmt);
        return;
    }
    MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
    std::vector<uint32> blocks;
    SplitMemoryIntoBlocks(copySize, blocks);
    uint32 offset = 0;
    // If blocks.size() > 1 and `src` or `dst` is not a leaf node,
    // we should extract common expr to avoid redundant expression
    BaseNode *realSrcExpr = srcMem.addrExpr;
    BaseNode *realDstExpr = addrExpr;
    if (blocks.size() > 1) {
        realDstExpr = TryToExtractComplexExpr(addrExpr, func, block, stmt, debug);
        realSrcExpr = TryToExtractComplexExpr(srcMem.addrExpr, func, block, stmt, debug);
    }
    auto *ptrType = GlobalTables::GetTypeTable().GetPtrType();
    // The label indices below are only read under the same conditions that set
    // them (NeedCheck, and errorNumber != ERRNO_RANGE_AND_RESET for the latter
    // two), so they are never consumed uninitialized.
    LabelIdx dstNullLabIdx;
    LabelIdx srcNullLabIdx;
    LabelIdx overlapLabIdx;
    LabelIdx addressEqualLabIdx;
    if (NeedCheck(memOpKind)) {
        auto *cmpResType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(PTY_u8));
        auto *cmpOpndType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(PTY_ptr));
        // check dst != NULL
        auto *checkDstStmt = CreateNullptrCheckStmt(*realDstExpr, func, mirBuilder, *cmpResType, *cmpOpndType);
        dstNullLabIdx = checkDstStmt->GetOffset();
        // check src != NULL
        auto *checkSrcStmt = CreateNullptrCheckStmt(*realSrcExpr, func, mirBuilder, *cmpResType, *cmpOpndType);
        srcNullLabIdx = checkSrcStmt->GetOffset();
        InsertBeforeAndMayPrintStmtList(block, stmt, debug, {checkDstStmt, checkSrcStmt});
        if (errorNumber != ERRNO_RANGE_AND_RESET) {
            // check src == dst
            auto *checkAddrEqualStmt =
                CreateAddressEqualCheckStmt(*realDstExpr, *realSrcExpr, func, mirBuilder, *cmpResType, *cmpOpndType);
            addressEqualLabIdx = checkAddrEqualStmt->GetOffset();
            // check overlap
            auto *checkOverlapStmt = CreateOverlapCheckStmt(*realDstExpr, *realSrcExpr, static_cast<uint32>(copySize),
                                                            func, mirBuilder, *cmpResType, *cmpOpndType);
            overlapLabIdx = checkOverlapStmt->GetOffset();
            InsertBeforeAndMayPrintStmtList(block, stmt, debug, {checkAddrEqualStmt, checkOverlapStmt});
        }
    }
    if (errorNumber == ERRNO_RANGE_AND_RESET) {
        // range error: reset the dst buffer instead of copying
        AddMemsetCallStmt(stmt, func, block, addrExpr);
    } else {
        // memory copy optimization
        PregIdx tmpRegIdx1 = 0;
        PregIdx tmpRegIdx2 = 0;
        for (uint32 i = 0; i < blocks.size(); ++i) {
            uint32 curSize = blocks[i];
            // two consecutive equal-sized chunks can be paired for ldp/stp
            bool canMergedWithNextSize = (i + 1 < blocks.size()) && blocks[i + 1] == curSize;
            if (!canMergedWithNextSize) {
                // low level memcpy expand result:
                // It seems ireadoff has not been supported by cg HandleFunction, so we use iread instead of ireadoff
                // [not support] iassignoff <prim-type> <offset> (dstAddrExpr, ireadoff <prim-type> <offset>
                // (srcAddrExpr)) [ok] iassignoff <prim-type> <offset> (dstAddrExpr, iread <prim-type> <type> (add ptr
                // (srcAddrExpr, offset)))
                PrimType constType = GetIntegerPrimTypeBySizeAndSign(curSize * 8, false);
                MIRType *constMIRType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(constType));
                auto *constMIRPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*constMIRType);
                BaseNode *rhsAddrExpr = realSrcExpr;
                if (offset != 0) {
                    auto *offsetConstExpr = ConstructConstvalNode(offset, PTY_u64, *mirBuilder);
                    rhsAddrExpr = mirBuilder->CreateExprBinary(OP_add, *ptrType, realSrcExpr, offsetConstExpr);
                }
                BaseNode *rhsExpr = mirBuilder->CreateExprIread(*constMIRType, *constMIRPtrType, 0, rhsAddrExpr);
                auto *iassignoff = mirBuilder->CreateStmtIassignoff(constType, offset, realDstExpr, rhsExpr);
                InsertAndMayPrintStmt(block, stmt, debug, iassignoff);
                offset += curSize;
                continue;
            }

            // merge two str/ldr into a stp/ldp
            if (tmpRegIdx1 == 0 || tmpRegIdx2 == 0) {
                tmpRegIdx1 = func.GetPregTab()->CreatePreg(PTY_u64);
                tmpRegIdx2 = func.GetPregTab()->CreatePreg(PTY_u64);
            }
            auto pair1 = GenerateMemoryCopyPair(mirBuilder, realSrcExpr, realDstExpr, offset, curSize, tmpRegIdx1);
            auto pair2 =
                GenerateMemoryCopyPair(mirBuilder, realSrcExpr, realDstExpr, offset + curSize, curSize, tmpRegIdx2);
            // insert order: regassign1, regassign2, iassignoff1, iassignoff2
            InsertBeforeAndMayPrintStmtList(block, stmt, debug, {pair1.first, pair2.first, pair1.second, pair2.second});
            offset += (curSize << 1);
            ++i;
        }
    }
    // handle memcpy return val
    auto *retAssign = GenMemopRetAssign(stmt, func, true, memOpKind, errorNumber);
    InsertAndMayPrintStmt(block, stmt, debug, retAssign);
    if (NeedCheck(memOpKind)) {
        LabelIdx finalLabIdx = func.GetLabelTab()->CreateLabelWithPrefix('f');
        auto *finalLabelNode = mirBuilder->CreateStmtLabel(finalLabIdx);
        // Add goto final stmt for expanded body
        auto *gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
        InsertAndMayPrintStmt(block, stmt, debug, gotoFinal);
        // Add handler IR if dst == NULL
        auto *dstErrAssign = GenMemopRetAssign(stmt, func, true, memOpKind, ERRNO_INVAL);
        AddNullptrHandlerIR(stmt, mirBuilder, block, dstErrAssign, dstNullLabIdx, finalLabIdx, debug);
        // Add handler IR if src == NULL
        auto *srcErrAssign = GenMemopRetAssign(stmt, func, true, memOpKind, ERRNO_INVAL_AND_RESET);
        AddResetHandlerIR(stmt, func, block, srcErrAssign, srcNullLabIdx, finalLabIdx, addrExpr, debug);
        if (errorNumber != ERRNO_RANGE_AND_RESET) {
            // Add handler IR if dst == src
            auto *addrEqualAssign = GenMemopRetAssign(stmt, func, true, memOpKind, ERRNO_OK);
            AddNullptrHandlerIR(stmt, mirBuilder, block, addrEqualAssign, addressEqualLabIdx, finalLabIdx, debug);
            // Add handler IR if dst and src are overlapped
            auto *overlapErrAssign = GenMemopRetAssign(stmt, func, true, memOpKind, ERRNO_OVERLAP_AND_RESET);
            AddResetHandlerIR(stmt, func, block, overlapErrAssign, overlapLabIdx, finalLabIdx, addrExpr, debug);
        }
        InsertAndMayPrintStmt(block, stmt, debug, finalLabelNode);
    }
    block.RemoveStmt(&stmt);
}
1302
// High-level expansion of memcpy from `srcMem` into this entry.
// Returns true when the call was fully expanded (and removed from `block`);
// false leaves the original call for cglower to handle.
bool MemEntry::ExpandMemcpy(const MemEntry &srcMem, uint64 copySize, MIRFunction &func, StmtNode &stmt,
                            BlockNode &block, bool isLowLevel, bool debug, ErrorNumber errorNumber) const
{
    MemOpKind memOpKind = SimplifyMemOp::ComputeMemOpKind(stmt);
    MemEntryKind memKind = GetKind();
    if (!isLowLevel) { // check type consistency and memKind only for high level expand
        if (memOpKind == MEM_OP_memcpy_s) {
            MayPrintLog(debug, false, memOpKind, "all memcpy_s will be handed by cglower");
            return false;
        }
        // src and dst must share the exact same type for an element-wise copy
        if (memType != srcMem.memType) {
            return false;
        }
        CHECK_FATAL(memKind != kMemEntryUnknown, "invalid memKind");
    }
    MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
    StmtNode *newAssign = nullptr;
    if (isLowLevel) { // For cglower, replace memcpy with a series of low-level iassignoff
        ExpandMemcpyLowLevel(srcMem, copySize, func, stmt, block, memOpKind, debug, errorNumber);
        return true;
    }

    if (memKind == kMemEntryPrimitive || memKind == kMemEntryStruct) {
        // Do low level expand for all struct memcpy for now
        if (memKind == kMemEntryStruct) {
            MayPrintLog(debug, false, memOpKind, "Do low level expand for all struct memcpy for now");
            return false;
        }
        if (addrExpr->GetOpCode() == OP_addrof) { // We prefer dassign to iassign
            auto *addrof = static_cast<AddrofNode *>(addrExpr);
            auto *symbol = func.GetLocalOrGlobalSymbol(addrof->GetStIdx());
            newAssign = mirBuilder->CreateStmtDassign(*symbol, addrof->GetFieldID(), srcMem.BuildAsRhsExpr(func));
        } else {
            MIRType *memPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*memType);
            newAssign = mirBuilder->CreateStmtIassign(*memPtrType, 0, addrExpr, srcMem.BuildAsRhsExpr(func));
        }
        InsertAndMayPrintStmt(block, stmt, debug, newAssign);
        // split struct agg copy
        // NOTE(review): this branch looks unreachable — kMemEntryStruct returns
        // false above; presumably kept for when struct expand is re-enabled.
        if (memKind == kMemEntryStruct) {
            if (newAssign->GetOpCode() == OP_dassign) {
                (void)SplitDassignAggCopy(static_cast<DassignNode *>(newAssign), &block, &func);
            } else if (newAssign->GetOpCode() == OP_iassign) {
                (void)SplitIassignAggCopy(static_cast<IassignNode *>(newAssign), &block, &func);
            } else {
                CHECK_FATAL(false, "impossible");
            }
        }
    } else if (memKind == kMemEntryArray) {
        // We only consider array with dim == 1 now, and element type must be primitive type
        auto *arrayType = static_cast<MIRArrayType *>(memType);
        if (arrayType->GetDim() != 1 || (arrayType->GetElemType()->GetKind() != kTypeScalar &&
                                         arrayType->GetElemType()->GetKind() != kTypePointer)) {
            MayPrintLog(debug, false, memOpKind, "array dim != 1 or array elements are not primtive type");
            return false;
        }
        MIRType *elemType = arrayType->GetElemType();
        if (elemType->GetSize() < k4ByteSize) {
            MayPrintLog(debug, false, memOpKind,
                        "element size < 4, don't expand it to avoid to genearte lots of strb/strh");
            return false;
        }
        size_t elemCnt = arrayType->GetSizeArrayItem(0);
        if (elemType->GetSize() * elemCnt != copySize) {
            MayPrintLog(debug, false, memOpKind, "array size not equal");
            return false;
        }
        // if srcExpr is too complex (for example: addrof expr of global/static array), let cg expand it
        if (elemCnt > 1 && IsComplexExpr(srcMem.addrExpr, func)) {
            MayPrintLog(debug, false, memOpKind, "srcExpr is too complex, let cg expand it to avoid redundant inst");
            return false;
        }
        MIRType *elemPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*elemType);
        MIRType *u32Type = GlobalTables::GetTypeTable().GetUInt32();
        // copy element-by-element: dst[i] = *(&src[i])
        for (size_t i = 0; i < elemCnt; ++i) {
            ConstvalNode *indexExpr = mirBuilder->CreateConstval(
                GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(i), *u32Type));
            auto *arrayExpr = mirBuilder->CreateExprArray(*arrayType, addrExpr, indexExpr);
            auto *rhsArrayExpr = mirBuilder->CreateExprArray(*arrayType, srcMem.addrExpr, indexExpr);
            auto *rhsIreadExpr = mirBuilder->CreateExprIread(*elemType, *elemPtrType, 0, rhsArrayExpr);
            auto *arrayElemAssign = mirBuilder->CreateStmtIassign(*elemPtrType, 0, arrayExpr, rhsIreadExpr);
            InsertAndMayPrintStmt(block, stmt, debug, arrayElemAssign);
        }
    } else {
        CHECK_FATAL(false, "impossible");
    }

    // handle memcpy return val
    auto *retAssign = GenMemopRetAssign(stmt, func, isLowLevel, memOpKind, errorNumber);
    InsertAndMayPrintStmt(block, stmt, debug, retAssign);
    block.RemoveStmt(&stmt);
    return true;
}
1395
1396 // handle memset, memcpy return val
GenMemopRetAssign(StmtNode & stmt,MIRFunction & func,bool isLowLevel,MemOpKind memOpKind,ErrorNumber errorNumber)1397 StmtNode *MemEntry::GenMemopRetAssign(StmtNode &stmt, MIRFunction &func, bool isLowLevel, MemOpKind memOpKind,
1398 ErrorNumber errorNumber)
1399 {
1400 if (stmt.GetOpCode() != OP_call && stmt.GetOpCode() != OP_callassigned) {
1401 return nullptr;
1402 }
1403 auto &callStmt = static_cast<CallNode &>(stmt);
1404 const auto &retVec = callStmt.GetReturnVec();
1405 if (retVec.empty()) {
1406 return nullptr;
1407 }
1408 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
1409 BaseNode *rhs = callStmt.Opnd(0); // for memset, memcpy
1410 if (memOpKind == MEM_OP_memset_s || memOpKind == MEM_OP_memcpy_s) {
1411 // memset_s and memcpy_s must return an errorNumber
1412 MIRType *constType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(PTY_i32));
1413 MIRConst *mirConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(errorNumber, *constType);
1414 rhs = mirBuilder->CreateConstval(mirConst);
1415 }
1416 if (!retVec[0].second.IsReg()) {
1417 auto *retAssign = mirBuilder->CreateStmtDassign(retVec[0].first, 0, rhs);
1418 return retAssign;
1419 } else {
1420 PregIdx pregIdx = retVec[0].second.GetPregIdx();
1421 auto pregType = func.GetPregTab()->GetPregTableItem(static_cast<uint32>(pregIdx))->GetPrimType();
1422 auto *retAssign = mirBuilder->CreateStmtRegassign(pregType, pregIdx, rhs);
1423 if (isLowLevel) {
1424 retAssign->GetRHS()->SetPrimType(pregType);
1425 }
1426 return retAssign;
1427 }
1428 }
1429
ComputeMemOpKind(StmtNode & stmt)1430 MemOpKind SimplifyMemOp::ComputeMemOpKind(StmtNode &stmt)
1431 {
1432 if (stmt.GetOpCode() == OP_intrinsiccall) {
1433 auto intrinsicID = static_cast<IntrinsiccallNode &>(stmt).GetIntrinsic();
1434 if (intrinsicID == INTRN_C_memset) {
1435 return MEM_OP_memset;
1436 } else if (intrinsicID == INTRN_C_memcpy) {
1437 return MEM_OP_memcpy;
1438 }
1439 }
1440 // lowered memop function (such as memset) may be a call, not callassigned
1441 if (stmt.GetOpCode() != OP_callassigned && stmt.GetOpCode() != OP_call) {
1442 return MEM_OP_unknown;
1443 }
1444 auto &callStmt = static_cast<CallNode &>(stmt);
1445 MIRFunction *func = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(callStmt.GetPUIdx());
1446 const char *funcName = func->GetName().c_str();
1447 if (strcmp(funcName, kFuncNameOfMemset) == 0) {
1448 return MEM_OP_memset;
1449 }
1450 if (strcmp(funcName, kFuncNameOfMemcpy) == 0) {
1451 return MEM_OP_memcpy;
1452 }
1453 if (strcmp(funcName, kFuncNameOfMemsetS) == 0) {
1454 return MEM_OP_memset_s;
1455 }
1456 if (strcmp(funcName, kFuncNameOfMemcpyS) == 0) {
1457 return MEM_OP_memcpy_s;
1458 }
1459 return MEM_OP_unknown;
1460 }
1461
AutoSimplify(StmtNode & stmt,BlockNode & block,bool isLowLevel)1462 bool SimplifyMemOp::AutoSimplify(StmtNode &stmt, BlockNode &block, bool isLowLevel)
1463 {
1464 MemOpKind memOpKind = ComputeMemOpKind(stmt);
1465 switch (memOpKind) {
1466 case MEM_OP_memset:
1467 case MEM_OP_memset_s: {
1468 return SimplifyMemset(stmt, block, isLowLevel);
1469 }
1470 case MEM_OP_memcpy:
1471 case MEM_OP_memcpy_s: {
1472 return SimplifyMemcpy(stmt, block, isLowLevel);
1473 }
1474 default:
1475 break;
1476 }
1477 return false;
1478 }
1479
1480 // expand memset_s call statement, return pointer of memset call statement node to be expanded in the next step, return
1481 // nullptr if memset_s is expanded completely.
PartiallyExpandMemsetS(StmtNode & stmt,BlockNode & block)1482 StmtNode *SimplifyMemOp::PartiallyExpandMemsetS(StmtNode &stmt, BlockNode &block)
1483 {
1484 ErrorNumber errNum = ERRNO_OK;
1485
1486 uint64 srcSize = 0;
1487 bool isSrcSizeConst = false;
1488 BaseNode *foldSrcSizeExpr = FoldIntConst(stmt.Opnd(kMemsetSSrcSizeOpndIdx), srcSize, isSrcSizeConst);
1489 if (foldSrcSizeExpr != nullptr) {
1490 stmt.SetOpnd(foldSrcSizeExpr, kMemsetSDstSizeOpndIdx);
1491 }
1492
1493 uint64 dstSize = 0;
1494 bool isDstSizeConst = false;
1495 BaseNode *foldDstSizeExpr = FoldIntConst(stmt.Opnd(kMemsetSDstSizeOpndIdx), dstSize, isDstSizeConst);
1496 if (foldDstSizeExpr != nullptr) {
1497 stmt.SetOpnd(foldDstSizeExpr, kMemsetSDstSizeOpndIdx);
1498 }
1499 if (isDstSizeConst) {
1500 if ((srcSize > dstSize && dstSize == 0) || dstSize > kSecurecMemMaxLen) {
1501 errNum = ERRNO_RANGE;
1502 }
1503 }
1504
1505 MIRBuilder *mirBuilder = func->GetModule()->GetMIRBuilder();
1506 LabelIdx finalLabIdx = func->GetLabelTab()->CreateLabelWithPrefix('f');
1507 if (errNum != ERRNO_OK) {
1508 auto errnoAssign = MemEntry::GenMemopRetAssign(stmt, *func, true, MEM_OP_memset_s, errNum);
1509 InsertAndMayPrintStmt(block, stmt, debug, errnoAssign);
1510 block.RemoveStmt(&stmt);
1511 return nullptr;
1512 } else {
1513 LabelIdx dstSizeCheckLabIdx, srcSizeCheckLabIdx, nullPtrLabIdx;
1514 if (!isDstSizeConst) {
1515 // check if dst size is greater than maxlen
1516 dstSizeCheckLabIdx = func->GetLabelTab()->CreateLabelWithPrefix('n'); // 'n' means nullptr
1517 CreateAndInsertCheckStmt(OP_gt, stmt.Opnd(kMemsetSDstSizeOpndIdx),
1518 ConstructConstvalNode(kSecurecMemMaxLen, PTY_u64, *mirBuilder), dstSizeCheckLabIdx,
1519 stmt, block, *func, debug);
1520 }
1521
1522 // check if dst is nullptr
1523 nullPtrLabIdx = func->GetLabelTab()->CreateLabelWithPrefix('n'); // 'n' means nullptr
1524 CreateAndInsertCheckStmt(OP_eq, stmt.Opnd(kMemsetDstOpndIdx), ConstructConstvalNode(0, PTY_u64, *mirBuilder),
1525 nullPtrLabIdx, stmt, block, *func, debug);
1526
1527 if (isDstSizeConst && isSrcSizeConst) {
1528 if (srcSize > dstSize) {
1529 srcSize = dstSize;
1530 return ExpandOnSrcSizeGtDstSize(stmt, block, srcSize, finalLabIdx, nullPtrLabIdx, *func, debug);
1531 }
1532 } else {
1533 // check if src size is greater than dst size
1534 srcSizeCheckLabIdx = func->GetLabelTab()->CreateLabelWithPrefix('n'); // 'n' means nullptr
1535 CreateAndInsertCheckStmt(OP_gt, stmt.Opnd(kMemsetSSrcSizeOpndIdx), stmt.Opnd(kMemsetSDstSizeOpndIdx),
1536 srcSizeCheckLabIdx, stmt, block, *func, debug);
1537 }
1538
1539 MapleVector<BaseNode *> args(func->GetCodeMempoolAllocator().Adapter());
1540 args.push_back(stmt.Opnd(kMemsetDstOpndIdx));
1541 args.push_back(stmt.Opnd(kMemsetSSrcOpndIdx));
1542 args.push_back(stmt.Opnd(kMemsetSSrcSizeOpndIdx));
1543 auto memsetCallStmt = InsertMemsetCallStmt(args, *func, stmt, block, finalLabIdx, errNum, debug);
1544
1545 if (!isSrcSizeConst || !isDstSizeConst) {
1546 // handle src size error
1547 auto branchLabNode = mirBuilder->CreateStmtLabel(srcSizeCheckLabIdx);
1548 InsertAndMayPrintStmt(block, stmt, debug, branchLabNode);
1549 HandleZeroValueOfDstSize(stmt, block, srcSize, dstSize, finalLabIdx, dstSizeCheckLabIdx, *func,
1550 isDstSizeConst, debug);
1551 args.pop_back();
1552 args.push_back(stmt.Opnd(kMemsetSDstSizeOpndIdx));
1553 (void)InsertMemsetCallStmt(args, *func, stmt, block, finalLabIdx, ERRNO_RANGE_AND_RESET, debug);
1554 }
1555
1556 // handle dst nullptr error
1557 auto nullptrLabNode = mirBuilder->CreateStmtLabel(nullPtrLabIdx);
1558 InsertAndMayPrintStmt(block, stmt, debug, nullptrLabNode);
1559 HandleZeroValueOfDstSize(stmt, block, srcSize, dstSize, finalLabIdx, dstSizeCheckLabIdx, *func, isDstSizeConst,
1560 debug);
1561 auto gotoFinal = mirBuilder->CreateStmtGoto(OP_goto, finalLabIdx);
1562 auto errnoAssign = MemEntry::GenMemopRetAssign(stmt, *func, true, MEM_OP_memset_s, ERRNO_INVAL);
1563 InsertBeforeAndMayPrintStmtList(block, stmt, debug, {errnoAssign, gotoFinal});
1564
1565 if (!isDstSizeConst) {
1566 // handle dst size error
1567 InsertCheckFailedBranch(*func, stmt, block, dstSizeCheckLabIdx, finalLabIdx, ERRNO_RANGE, MEM_OP_memset_s,
1568 debug);
1569 }
1570 auto *finalLabelNode = mirBuilder->CreateStmtLabel(finalLabIdx);
1571 InsertAndMayPrintStmt(block, stmt, debug, finalLabelNode);
1572 block.RemoveStmt(&stmt);
1573 return memsetCallStmt;
1574 }
1575 }
1576
// Try to replace the call to memset with a series of assign operations (including dassign, iassign, iassignoff), which
// is usually profitable for small memory size.
// This function is called in two places, one in mpl2mpl simplify, another in cglower:
// (1) mpl2mpl memset expand (isLowLevel == false)
//     for primitive type, array type with element size < 4 bytes and struct type without padding
// (2) cglower memset expand
//     for array type with element size >= 4 bytes and struct type with paddings
// Returns true iff the stmt was simplified (expanded or removed).
bool SimplifyMemOp::SimplifyMemset(StmtNode &stmt, BlockNode &block, bool isLowLevel)
{
    MemOpKind memOpKind = ComputeMemOpKind(stmt);
    if (memOpKind != MEM_OP_memset && memOpKind != MEM_OP_memset_s) {
        return false;
    }
    // Operand layout assumed here is the 3-operand memset form: (dst, val, size).
    // NOTE(review): if a 4-operand memset_s (dst, dstMax, val, size) reaches this point
    // without going through PartiallyExpandMemsetS (i.e. the isLowLevel path), these
    // indices would read dstMax/val instead -- presumably memset_s is already lowered
    // to a plain memset before cglower calls this; confirm against the callers.
    uint32 dstOpndIdx = 0;
    uint32 srcOpndIdx = 1;
    uint32 srcSizeOpndIdx = 2;
    bool isSafeVersion = memOpKind == MEM_OP_memset_s;
    if (debug) {
        LogInfo::MapleLogger() << "[funcName] " << func->GetName() << std::endl;
        stmt.Dump(0);
    }

    // For memset_s in mpl2mpl, first expand the runtime checks; the returned stmt is
    // the plain memset call left over for further expansion (or nullptr if done).
    StmtNode *memsetCallStmt = &stmt;
    if (memOpKind == MEM_OP_memset_s && !isLowLevel) {
        memsetCallStmt = PartiallyExpandMemsetS(stmt, block);
        if (!memsetCallStmt) {
            return true; // Expand memset_s completely, no extra memset is generated, so just return true
        }
    }

    // Try to fold the size operand to a constant (and write the folded expr back).
    uint64 srcSize = 0;
    bool isSrcSizeConst = false;
    BaseNode *foldSrcSizeExpr = FoldIntConst(memsetCallStmt->Opnd(srcSizeOpndIdx), srcSize, isSrcSizeConst);
    if (foldSrcSizeExpr != nullptr) {
        memsetCallStmt->SetOpnd(foldSrcSizeExpr, srcSizeOpndIdx);
    }

    if (isSrcSizeConst) {
        // If the size is too big, we won't expand it
        uint32 thresholdExpand = (isSafeVersion ? thresholdMemsetSExpand : thresholdMemsetExpand);
        if (srcSize > thresholdExpand) {
            MayPrintLog(debug, false, memOpKind, "size is too big");
            return false;
        }
        if (srcSize == 0) {
            // Zero bytes to set: keep only the return-value assignment and drop the call.
            if (memOpKind == MEM_OP_memset) {
                auto *retAssign = MemEntry::GenMemopRetAssign(*memsetCallStmt, *func, isLowLevel, memOpKind);
                InsertAndMayPrintStmt(block, *memsetCallStmt, debug, retAssign);
            }
            block.RemoveStmt(memsetCallStmt);
            return true;
        }
    }

    // memset's 'src size' must be a const value, otherwise we can not expand it
    if (!isSrcSizeConst) {
        MayPrintLog(debug, false, memOpKind, "size is not int const");
        return false;
    }

    ErrorNumber errNum = ERRNO_OK;

    // Try to fold the fill value operand to a constant as well.
    uint64 val = 0;
    bool isIntConst = false;
    BaseNode *foldValExpr = FoldIntConst(memsetCallStmt->Opnd(srcOpndIdx), val, isIntConst);
    if (foldValExpr != nullptr) {
        memsetCallStmt->SetOpnd(foldValExpr, srcOpndIdx);
    }
    // memset's second argument 'val' should also be a const value
    if (!isIntConst) {
        MayPrintLog(debug, false, memOpKind, "val is not int const");
        return false;
    }

    // Resolve the destination expression into a (type, addr) memory entry.
    MemEntry dstMemEntry;
    bool valid = MemEntry::ComputeMemEntry(*(memsetCallStmt->Opnd(dstOpndIdx)), *func, dstMemEntry, isLowLevel);
    if (!valid) {
        MayPrintLog(debug, false, memOpKind, "dstMemEntry is invalid");
        return false;
    }
    bool ret = false;
    if (srcSize != 0) {
        // Expand the memset into individual assign statements.
        ret = dstMemEntry.ExpandMemset(val, static_cast<uint64>(srcSize), *func, *memsetCallStmt, block, isLowLevel,
                                       debug, errNum);
    } else {
        // if size == 0, no need to set memory, just return error nummber
        auto *retAssign = MemEntry::GenMemopRetAssign(*memsetCallStmt, *func, isLowLevel, memOpKind, errNum);
        InsertAndMayPrintStmt(block, *memsetCallStmt, debug, retAssign);
        block.RemoveStmt(memsetCallStmt);
        ret = true;
    }
    if (ret) {
        MayPrintLog(debug, true, memOpKind, "well done");
    }
    return ret;
}
1673
SimplifyMemcpy(StmtNode & stmt,BlockNode & block,bool isLowLevel)1674 bool SimplifyMemOp::SimplifyMemcpy(StmtNode &stmt, BlockNode &block, bool isLowLevel)
1675 {
1676 MemOpKind memOpKind = ComputeMemOpKind(stmt);
1677 if (memOpKind != MEM_OP_memcpy && memOpKind != MEM_OP_memcpy_s) {
1678 return false;
1679 }
1680 uint32 dstOpndIdx = 0;
1681 uint32 dstSizeOpndIdx = kFirstOpnd; // only used by memcpy_s
1682 uint32 srcOpndIdx = kSecondOpnd;
1683 uint32 srcSizeOpndIdx = kThirdOpnd;
1684 bool isSafeVersion = memOpKind == MEM_OP_memcpy_s;
1685 if (isSafeVersion) {
1686 dstSizeOpndIdx = kSecondOpnd;
1687 srcOpndIdx = kThirdOpnd;
1688 srcSizeOpndIdx = kFourthOpnd;
1689 }
1690 if (debug) {
1691 LogInfo::MapleLogger() << "[funcName] " << func->GetName() << std::endl;
1692 stmt.Dump(0);
1693 }
1694
1695 uint64 srcSize = 0;
1696 bool isIntConst = false;
1697 BaseNode *foldCopySizeExpr = FoldIntConst(stmt.Opnd(srcSizeOpndIdx), srcSize, isIntConst);
1698 if (foldCopySizeExpr != nullptr) {
1699 stmt.SetOpnd(foldCopySizeExpr, srcSizeOpndIdx);
1700 }
1701 if (!isIntConst) {
1702 MayPrintLog(debug, false, memOpKind, "src size is not an int const");
1703 return false;
1704 }
1705 uint32 thresholdExpand = (isSafeVersion ? thresholdMemcpySExpand : thresholdMemcpyExpand);
1706 if (srcSize > thresholdExpand) {
1707 MayPrintLog(debug, false, memOpKind, "size is too big");
1708 return false;
1709 }
1710 if (srcSize == 0) {
1711 MayPrintLog(debug, false, memOpKind, "memcpy with src size 0");
1712 return false;
1713 }
1714 uint64 copySize = srcSize;
1715 ErrorNumber errNum = ERRNO_OK;
1716 if (isSafeVersion) {
1717 uint64 dstSize = 0;
1718 bool isDstSizeConst = false;
1719 BaseNode *foldDstSizeExpr = FoldIntConst(stmt.Opnd(dstSizeOpndIdx), dstSize, isDstSizeConst);
1720 if (foldDstSizeExpr != nullptr) {
1721 stmt.SetOpnd(foldDstSizeExpr, dstSizeOpndIdx);
1722 }
1723 if (!isDstSizeConst) {
1724 MayPrintLog(debug, false, memOpKind, "dst size is not int const");
1725 return false;
1726 }
1727 if (dstSize == 0 || dstSize > kSecurecMemMaxLen) {
1728 copySize = 0;
1729 errNum = ERRNO_RANGE;
1730 } else if (srcSize > dstSize) {
1731 copySize = dstSize;
1732 errNum = ERRNO_RANGE_AND_RESET;
1733 }
1734 }
1735
1736 MemEntry dstMemEntry;
1737 bool valid = MemEntry::ComputeMemEntry(*stmt.Opnd(dstOpndIdx), *func, dstMemEntry, isLowLevel);
1738 if (!valid) {
1739 MayPrintLog(debug, false, memOpKind, "dstMemEntry is invalid");
1740 return false;
1741 }
1742 MemEntry srcMemEntry;
1743 valid = MemEntry::ComputeMemEntry(*stmt.Opnd(srcOpndIdx), *func, srcMemEntry, isLowLevel);
1744 if (!valid) {
1745 MayPrintLog(debug, false, memOpKind, "srcMemEntry is invalid");
1746 return false;
1747 }
1748 // We don't check type consistency when doing low level expand
1749 if (!isLowLevel) {
1750 if (dstMemEntry.memType != srcMemEntry.memType) {
1751 MayPrintLog(debug, false, memOpKind, "dst and src have different type");
1752 return false; // entryType must be identical
1753 }
1754 if (dstMemEntry.memType->GetSize() != static_cast<uint64>(srcSize)) {
1755 MayPrintLog(debug, false, memOpKind, "copy size != dst memory size");
1756 return false; // copy size should equal to dst memory size, we maybe allow smaller copy size later
1757 }
1758 }
1759 bool ret = false;
1760 if (copySize != 0) {
1761 ret = dstMemEntry.ExpandMemcpy(srcMemEntry, copySize, *func, stmt, block, isLowLevel, debug, errNum);
1762 } else {
1763 // if copySize == 0, no need to copy memory, just return error number
1764 auto *retAssign = MemEntry::GenMemopRetAssign(stmt, *func, isLowLevel, memOpKind, errNum);
1765 InsertAndMayPrintStmt(block, stmt, debug, retAssign);
1766 block.RemoveStmt(&stmt);
1767 ret = true;
1768 }
1769 if (ret) {
1770 MayPrintLog(debug, true, memOpKind, "well done");
1771 }
1772 return ret;
1773 }
1774
1775 } // namespace maple
1776