1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "triple.h"
17 #include "simplify.h"
18 #include <functional>
19 #include <initializer_list>
20 #include <iostream>
21 #include <algorithm>
22 #include "constantfold.h"
23 #include "mpl_logging.h"
24
25 namespace maple {
26
27 namespace {
28
// Names of the C library memory routines this pass recognizes and may expand inline.
constexpr char kFuncNameOfMemset[] = "memset";
constexpr char kFuncNameOfMemcpy[] = "memcpy";
constexpr char kFuncNameOfMemsetS[] = "memset_s";
constexpr char kFuncNameOfMemcpyS[] = "memcpy_s";
33
34 // Truncate the constant field of 'union' if it's written as scalar type (e.g. int),
35 // but accessed as bit-field type with smaller size.
36 //
37 // Return the truncated constant or nullptr if the constant doesn't need to be truncated.
TruncateUnionConstant(const MIRStructType & unionType,MIRConst * fieldCst,const MIRType & unionFieldType)38 MIRConst *TruncateUnionConstant(const MIRStructType &unionType, MIRConst *fieldCst, const MIRType &unionFieldType)
39 {
40 if (unionType.GetKind() != kTypeUnion) {
41 return nullptr;
42 }
43
44 auto *bitFieldType = safe_cast<MIRBitFieldType>(unionFieldType);
45 auto *intCst = safe_cast<MIRIntConst>(fieldCst);
46 if (!bitFieldType || !intCst) {
47 return nullptr;
48 }
49
50 bool isBigEndian = Triple::GetTriple().IsBigEndian();
51 IntVal val = intCst->GetValue();
52 uint8 bitSize = bitFieldType->GetFieldSize();
53 if (bitSize >= val.GetBitWidth()) {
54 return nullptr;
55 }
56
57 if (isBigEndian) {
58 val = val.LShr(val.GetBitWidth() - bitSize);
59 } else {
60 val = val & ((uint64(1) << bitSize) - 1);
61 }
62
63 return GlobalTables::GetIntConstTable().GetOrCreateIntConst(val, fieldCst->GetType());
64 }
65
66 } // namespace
67
// If size (in byte) is bigger than this threshold, we won't expand memop
const uint32 SimplifyMemOp::thresholdMemsetExpand = 512;
const uint32 SimplifyMemOp::thresholdMemcpyExpand = 512;
// thresholds for the checked variants (memset_s/memcpy_s)
const uint32 SimplifyMemOp::thresholdMemsetSExpand = 1024;
const uint32 SimplifyMemOp::thresholdMemcpySExpand = 1024;
static const uint32 kMaxMemoryBlockSizeToAssign = 8; // in byte
74
// Stub: math-call recognition is disabled in this build; never reports sqrt.
bool Simplify::IsMathSqrt(const std::string funcName)
{
    return false;
}
79
// Stub: math-call recognition is disabled in this build; never reports abs.
bool Simplify::IsMathAbs(const std::string funcName)
{
    return false;
}
84
// Stub: math-call recognition is disabled in this build; never reports max.
bool Simplify::IsMathMax(const std::string funcName)
{
    return false;
}
89
// Stub: math-call recognition is disabled in this build; never reports min.
bool Simplify::IsMathMin(const std::string funcName)
{
    return false;
}
94
// Try to replace a callassigned to a recognized math function (sqrt/abs/max/min)
// with the corresponding MIR operator, dassigning the result to the call's first
// return symbol. Only applies to C modules. Returns true when the call statement
// was replaced. Currently inert: the IsMath* recognizers all return false.
bool Simplify::SimplifyMathMethod(const StmtNode &stmt, BlockNode &block)
{
    if (stmt.GetOpCode() != OP_callassigned) {
        return false;
    }
    auto &cnode = static_cast<const CallNode &>(stmt);
    MIRFunction *calleeFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(cnode.GetPUIdx());
    DEBUG_ASSERT(calleeFunc != nullptr, "null ptr check");
    const std::string &funcName = calleeFunc->GetName();
    if (funcName.empty()) {
        return false;
    }
    if (!mirMod.IsCModule()) {
        return false;  // this rewrite targets C call semantics only
    }
    if (cnode.GetNumOpnds() == 0 || cnode.GetReturnVec().empty()) {
        return false;  // need at least one argument and a value to assign
    }

    auto *opnd0 = cnode.Opnd(0);
    DEBUG_ASSERT(opnd0 != nullptr, "null ptr check");
    auto *type = GlobalTables::GetTypeTable().GetTypeFromTyIdx(opnd0->GetPrimType());

    BaseNode *opExpr = nullptr;
    // NOTE(review): the `!IsPrimitiveFloat` guard looks inverted for sqrt (sqrt
    // normally takes a floating-point operand) — confirm intent before enabling
    // IsMathSqrt; this branch is currently dead since IsMathSqrt returns false.
    if (IsMathSqrt(funcName) && !IsPrimitiveFloat(opnd0->GetPrimType())) {
        opExpr = builder->CreateExprUnary(OP_sqrt, *type, opnd0);
    } else if (IsMathAbs(funcName)) {
        opExpr = builder->CreateExprUnary(OP_abs, *type, opnd0);
    } else if (IsMathMax(funcName)) {
        // assumes max/min calls carry two operands — TODO confirm when enabling recognition
        opExpr = builder->CreateExprBinary(OP_max, *type, opnd0, cnode.Opnd(1));
    } else if (IsMathMin(funcName)) {
        opExpr = builder->CreateExprBinary(OP_min, *type, opnd0, cnode.Opnd(1));
    }
    if (opExpr != nullptr) {
        auto stIdx = cnode.GetNthReturnVec(0).first;
        auto *dassign = builder->CreateStmtDassign(stIdx, 0, opExpr);
        block.ReplaceStmt1WithStmt2(&stmt, dassign);
        return true;
    }
    return false;
}
136
SimplifyCallAssigned(StmtNode & stmt,BlockNode & block)137 void Simplify::SimplifyCallAssigned(StmtNode &stmt, BlockNode &block)
138 {
139 if (SimplifyMathMethod(stmt, block)) {
140 return;
141 }
142 simplifyMemOp.SetDebug(dump);
143 simplifyMemOp.SetFunction(currFunc);
144 if (simplifyMemOp.AutoSimplify(stmt, block, false)) {
145 return;
146 }
147 }
148
149 constexpr uint32 kUpperLimitOfFieldNum = 10;
GetDassignedStructType(const DassignNode * dassign,MIRFunction * func)150 static MIRStructType *GetDassignedStructType(const DassignNode *dassign, MIRFunction *func)
151 {
152 const auto &lhsStIdx = dassign->GetStIdx();
153 auto lhsSymbol = func->GetLocalOrGlobalSymbol(lhsStIdx);
154 DEBUG_ASSERT(lhsSymbol != nullptr, "lhsSymbol should not be nullptr");
155 auto lhsAggType = lhsSymbol->GetType();
156 if (!lhsAggType->IsStructType()) {
157 return nullptr;
158 }
159 if (lhsAggType->GetKind() == kTypeUnion) { // no need to split union's field
160 return nullptr;
161 }
162 auto lhsFieldID = dassign->GetFieldID();
163 if (lhsFieldID != 0) {
164 CHECK_FATAL(lhsAggType->IsStructType(), "only struct has non-zero fieldID");
165 lhsAggType = static_cast<MIRStructType *>(lhsAggType)->GetFieldType(lhsFieldID);
166 if (!lhsAggType->IsStructType()) {
167 return nullptr;
168 }
169 if (lhsAggType->GetKind() == kTypeUnion) { // no need to split union's field
170 return nullptr;
171 }
172 }
173 if (static_cast<MIRStructType *>(lhsAggType)->NumberOfFieldIDs() > kUpperLimitOfFieldNum) {
174 return nullptr;
175 }
176 return static_cast<MIRStructType *>(lhsAggType);
177 }
178
GetIassignedStructType(const IassignNode * iassign)179 static MIRStructType *GetIassignedStructType(const IassignNode *iassign)
180 {
181 auto ptrTyIdx = iassign->GetTyIdx();
182 auto *ptrType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(ptrTyIdx);
183 CHECK_FATAL(ptrType->IsMIRPtrType(), "must be pointer type");
184 auto aggTyIdx = static_cast<MIRPtrType *>(ptrType)->GetPointedTyIdxWithFieldID(iassign->GetFieldID());
185 auto *lhsAggType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(aggTyIdx);
186 if (!lhsAggType->IsStructType()) {
187 return nullptr;
188 }
189 if (lhsAggType->GetKind() == kTypeUnion) {
190 return nullptr;
191 }
192 if (static_cast<MIRStructType *>(lhsAggType)->NumberOfFieldIDs() > kUpperLimitOfFieldNum) {
193 return nullptr;
194 }
195 return static_cast<MIRStructType *>(lhsAggType);
196 }
197
GetReadedStructureType(const DreadNode * dread,const MIRFunction * func)198 static MIRStructType *GetReadedStructureType(const DreadNode *dread, const MIRFunction *func)
199 {
200 const auto &rhsStIdx = dread->GetStIdx();
201 auto rhsSymbol = func->GetLocalOrGlobalSymbol(rhsStIdx);
202 DEBUG_ASSERT(rhsSymbol != nullptr, "rhsSymbol should not be nullptr");
203 auto rhsAggType = rhsSymbol->GetType();
204 auto rhsFieldID = dread->GetFieldID();
205 if (rhsFieldID != 0) {
206 CHECK_FATAL(rhsAggType->IsStructType(), "only struct has non-zero fieldID");
207 rhsAggType = static_cast<MIRStructType *>(rhsAggType)->GetFieldType(rhsFieldID);
208 }
209 if (!rhsAggType->IsStructType()) {
210 return nullptr;
211 }
212 return static_cast<MIRStructType *>(rhsAggType);
213 }
214
GetReadedStructureType(const IreadNode * iread,const MIRFunction *)215 static MIRStructType *GetReadedStructureType(const IreadNode *iread, const MIRFunction *)
216 {
217 auto rhsPtrType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx());
218 CHECK_FATAL(rhsPtrType->IsMIRPtrType(), "must be pointer type");
219 auto rhsAggType = static_cast<MIRPtrType *>(rhsPtrType)->GetPointedType();
220 auto rhsFieldID = iread->GetFieldID();
221 if (rhsFieldID != 0) {
222 CHECK_FATAL(rhsAggType->IsStructType(), "only struct has non-zero fieldID");
223 rhsAggType = static_cast<MIRStructType *>(rhsAggType)->GetFieldType(rhsFieldID);
224 }
225 if (!rhsAggType->IsStructType()) {
226 return nullptr;
227 }
228 return static_cast<MIRStructType *>(rhsAggType);
229 }
230
// Split an aggregate copy (an assignment whose RHS reads a struct of the same
// type) into one scalar assignment per field: each new assignment is the
// original statement cloned with both sides' field IDs advanced to the current
// field. The original aggregate assignment is removed; returns the statement
// that followed it (the first of the inserted assignments), or nullptr when
// the lhs/rhs struct types differ and the copy cannot be split.
template <class RhsType, class AssignType>
static StmtNode *SplitAggCopy(const AssignType *assignNode, MIRStructType *structureType, BlockNode *block,
                              MIRFunction *func)
{
    auto *readNode = static_cast<RhsType *>(assignNode->GetRHS());
    auto rhsFieldID = readNode->GetFieldID();
    auto *rhsAggType = GetReadedStructureType(readNode, func);
    if (structureType != rhsAggType) {
        return nullptr;
    }

    for (FieldID id = 1; id <= static_cast<FieldID>(structureType->NumberOfFieldIDs()); ++id) {
        MIRType *fieldType = structureType->GetFieldType(id);
        if (fieldType->GetSize() == 0) {
            continue; // field size is zero for empty struct/union;
        }
        if (fieldType->GetKind() == kTypeBitField && static_cast<MIRBitFieldType *>(fieldType)->GetFieldSize() == 0) {
            continue; // bitfield size is zero
        }
        // Clone the whole assignment, then retarget it at this field on both sides.
        auto *newDassign = assignNode->CloneTree(func->GetCodeMemPoolAllocator());
        newDassign->SetFieldID(assignNode->GetFieldID() + id);
        auto *newRHS = static_cast<RhsType *>(newDassign->GetRHS());
        newRHS->SetFieldID(rhsFieldID + id);
        newRHS->SetPrimType(fieldType->GetPrimType());
        block->InsertAfter(assignNode, newDassign);
        if (fieldType->IsMIRUnionType()) {
            // a union is copied as one unit: skip over its members' field IDs
            id += static_cast<FieldID>(fieldType->NumberOfFieldIDs());
        }
    }
    // The clones were inserted right after the original; continue from the
    // first of them (or whatever followed when no field produced a clone).
    auto newAssign = assignNode->GetNext();
    block->RemoveStmt(assignNode);
    return newAssign;
}
264
SplitDassignAggCopy(DassignNode * dassign,BlockNode * block,MIRFunction * func)265 static StmtNode *SplitDassignAggCopy(DassignNode *dassign, BlockNode *block, MIRFunction *func)
266 {
267 auto *rhs = dassign->GetRHS();
268 if (rhs->GetPrimType() != PTY_agg) {
269 return nullptr;
270 }
271
272 auto *lhsAggType = GetDassignedStructType(dassign, func);
273 if (lhsAggType == nullptr) {
274 return nullptr;
275 }
276
277 if (rhs->GetOpCode() == OP_dread) {
278 auto *lhsSymbol = func->GetLocalOrGlobalSymbol(dassign->GetStIdx());
279 auto *rhsSymbol = func->GetLocalOrGlobalSymbol(static_cast<DreadNode *>(rhs)->GetStIdx());
280 if (!lhsSymbol->IsLocal() && !rhsSymbol->IsLocal()) {
281 return nullptr;
282 }
283
284 return SplitAggCopy<DreadNode>(dassign, lhsAggType, block, func);
285 } else if (rhs->GetOpCode() == OP_iread) {
286 return SplitAggCopy<IreadNode>(dassign, lhsAggType, block, func);
287 }
288 return nullptr;
289 }
290
SplitIassignAggCopy(IassignNode * iassign,BlockNode * block,MIRFunction * func)291 static StmtNode *SplitIassignAggCopy(IassignNode *iassign, BlockNode *block, MIRFunction *func)
292 {
293 auto rhs = iassign->GetRHS();
294 if (rhs->GetPrimType() != PTY_agg) {
295 return nullptr;
296 }
297
298 auto *lhsAggType = GetIassignedStructType(iassign);
299 if (lhsAggType == nullptr) {
300 return nullptr;
301 }
302
303 if (rhs->GetOpCode() == OP_dread) {
304 return SplitAggCopy<DreadNode>(iassign, lhsAggType, block, func);
305 } else if (rhs->GetOpCode() == OP_iread) {
306 return SplitAggCopy<IreadNode>(iassign, lhsAggType, block, func);
307 }
308 return nullptr;
309 }
310
UseGlobalVar(const BaseNode * expr)311 bool UseGlobalVar(const BaseNode *expr)
312 {
313 if (expr->GetOpCode() == OP_addrof || expr->GetOpCode() == OP_dread) {
314 StIdx stIdx = static_cast<const AddrofNode *>(expr)->GetStIdx();
315 if (stIdx.IsGlobal()) {
316 return true;
317 }
318 }
319 for (size_t i = 0; i < expr->GetNumOpnds(); ++i) {
320 if (UseGlobalVar(expr->Opnd(i))) {
321 return true;
322 }
323 }
324 return false;
325 }
326
ProcessStmt(StmtNode & stmt)327 void Simplify::ProcessStmt(StmtNode &stmt)
328 {
329 switch (stmt.GetOpCode()) {
330 case OP_callassigned: {
331 SimplifyCallAssigned(stmt, *currBlock);
332 break;
333 }
334 case OP_intrinsiccall: {
335 simplifyMemOp.SetDebug(dump);
336 simplifyMemOp.SetFunction(currFunc);
337 (void)simplifyMemOp.AutoSimplify(stmt, *currBlock, false);
338 break;
339 }
340 case OP_dassign: {
341 auto *newStmt = SplitDassignAggCopy(static_cast<DassignNode *>(&stmt), currBlock, currFunc);
342 if (newStmt) {
343 ProcessBlock(*newStmt);
344 }
345 break;
346 }
347 case OP_iassign: {
348 auto *newStmt = SplitIassignAggCopy(static_cast<IassignNode *>(&stmt), currBlock, currFunc);
349 if (newStmt) {
350 ProcessBlock(*newStmt);
351 }
352 break;
353 }
354 case OP_if:
355 case OP_while:
356 case OP_dowhile: {
357 auto unaryStmt = static_cast<UnaryStmtNode &>(stmt);
358 unaryStmt.SetRHS(SimplifyExpr(*unaryStmt.GetRHS()));
359 return;
360 }
361 default: {
362 break;
363 }
364 }
365 for (size_t i = 0; i < stmt.NumOpnds(); ++i) {
366 if (stmt.Opnd(i)) {
367 stmt.SetOpnd(SimplifyExpr(*stmt.Opnd(i)), i);
368 }
369 }
370 }
371
SimplifyExpr(BaseNode & expr)372 BaseNode *Simplify::SimplifyExpr(BaseNode &expr)
373 {
374 switch (expr.GetOpCode()) {
375 case OP_dread: {
376 auto &dread = static_cast<DreadNode &>(expr);
377 return ReplaceExprWithConst(dread);
378 }
379 default: {
380 for (auto i = 0; i < expr.GetNumOpnds(); i++) {
381 if (expr.Opnd(i)) {
382 expr.SetOpnd(SimplifyExpr(*expr.Opnd(i)), i);
383 }
384 }
385 break;
386 }
387 }
388 return &expr;
389 }
390
ReplaceExprWithConst(DreadNode & dread)391 BaseNode *Simplify::ReplaceExprWithConst(DreadNode &dread)
392 {
393 auto stIdx = dread.GetStIdx();
394 auto fieldId = dread.GetFieldID();
395 auto *symbol = currFunc->GetLocalOrGlobalSymbol(stIdx);
396 DEBUG_ASSERT(symbol != nullptr, "nullptr check");
397 auto *symbolConst = symbol->GetKonst();
398 if (!currFunc->GetModule()->IsCModule() || !symbolConst || !stIdx.IsGlobal() ||
399 !IsSymbolReplaceableWithConst(*symbol)) {
400 return &dread;
401 }
402 if (fieldId != 0) {
403 symbolConst = GetElementConstFromFieldId(fieldId, symbolConst);
404 }
405 if (!symbolConst || !IsConstRepalceable(*symbolConst)) {
406 return &dread;
407 }
408 return currFunc->GetModule()->GetMIRBuilder()->CreateConstval(symbolConst);
409 }
410
IsSymbolReplaceableWithConst(const MIRSymbol & symbol) const411 bool Simplify::IsSymbolReplaceableWithConst(const MIRSymbol &symbol) const
412 {
413 return (symbol.GetStorageClass() == kScFstatic && !symbol.HasPotentialAssignment()) ||
414 symbol.GetAttrs().GetAttr(ATTR_const);
415 }
416
IsConstRepalceable(const MIRConst & mirConst) const417 bool Simplify::IsConstRepalceable(const MIRConst &mirConst) const
418 {
419 switch (mirConst.GetKind()) {
420 case kConstInt:
421 case kConstFloatConst:
422 case kConstDoubleConst:
423 case kConstFloat128Const:
424 case kConstLblConst:
425 return true;
426 default:
427 return false;
428 }
429 }
430
// Walk an aggregate constant's initializer to find the sub-constant matching
// `fieldId` (MIR numbers fields depth-first across nested structs/unions).
// A union stores a single element (index 1) shared by all members, so the
// found value may need truncation when accessed through a narrower bit-field
// member — see TruncateUnionConstant. Aborts (CHECK_FATAL) if the field ID is
// not found in the initializer.
MIRConst *Simplify::GetElementConstFromFieldId(FieldID fieldId, MIRConst *mirConst)
{
    FieldID currFieldId = 1;
    MIRConst *resultConst = nullptr;
    auto originAggConst = static_cast<MIRAggConst *>(mirConst);
    // NOTE(review): `auto` here deduces a value, so this copies the struct type
    // object rather than binding a reference — appears benign but confirm intended.
    auto originAggType = static_cast<MIRStructType &>(originAggConst->GetType());
    bool hasReached = false;
    std::function<void(MIRConst *)> traverseAgg = [&](MIRConst *currConst) {
        auto *currAggConst = safe_cast<MIRAggConst>(currConst);
        ASSERT_NOT_NULL(currAggConst);
        auto *currAggType = safe_cast<MIRStructType>(currAggConst->GetType());
        ASSERT_NOT_NULL(currAggType);
        for (size_t iter = 0; iter < currAggType->GetFieldsSize() && !hasReached; ++iter) {
            // a union stores only one element (index 1), shared by all members
            size_t constIdx = currAggType->GetKind() == kTypeUnion ? 1 : iter + 1;
            auto *fieldConst = currAggConst->GetAggConstElement(constIdx);
            auto *fieldType = originAggType.GetFieldType(currFieldId);

            if (currFieldId == fieldId) {
                // found the target field; truncate union bit-field reads if needed
                if (auto *truncCst = TruncateUnionConstant(*currAggType, fieldConst, *fieldType)) {
                    resultConst = truncCst;
                } else {
                    resultConst = fieldConst;
                }

                hasReached = true;
                return;
            }

            ++currFieldId;
            // nested aggregates own the following field IDs — descend into them
            if (fieldType->GetKind() == kTypeUnion || fieldType->GetKind() == kTypeStruct) {
                traverseAgg(fieldConst);
            }
        }
    };
    traverseAgg(mirConst);
    CHECK_FATAL(hasReached, "const not found");
    return resultConst;
}
469
Finish()470 void Simplify::Finish() {}
471
472 // Join `num` `byte`s into a number
473 // Example:
474 // byte num output
475 // 0x0a 2 0x0a0a
476 // 0x12 4 0x12121212
477 // 0xff 8 0xffffffffffffffff
JoinBytes(int byte,uint32 num)478 static uint64 JoinBytes(int byte, uint32 num)
479 {
480 CHECK_FATAL(num <= 8, "not support"); // just support num less or equal 8, see comment above
481 uint64 realByte = static_cast<uint64>(byte % 256);
482 if (realByte == 0) {
483 return 0;
484 }
485 uint64 result = 0;
486 for (uint32 i = 0; i < num; ++i) {
487 result += (realByte << (i * k8BitSize));
488 }
489 return result;
490 }
491
ConstructConstvalNode(uint64 val,PrimType primType,MIRBuilder & mirBuilder)492 static BaseNode *ConstructConstvalNode(uint64 val, PrimType primType, MIRBuilder &mirBuilder)
493 {
494 PrimType constPrimType = primType;
495 if (IsPrimitiveFloat(primType)) {
496 constPrimType = GetIntegerPrimTypeBySizeAndSign(GetPrimTypeBitSize(primType), false);
497 }
498 MIRType *constType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(constPrimType));
499 MIRConst *mirConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(val, *constType);
500 BaseNode *ret = mirBuilder.CreateConstval(mirConst);
501 if (IsPrimitiveFloat(primType)) {
502 MIRType *floatType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(primType));
503 ret = mirBuilder.CreateExprRetype(*floatType, constPrimType, ret);
504 }
505 return ret;
506 }
507
ConstructConstvalNode(int64 byte,uint64 num,PrimType primType,MIRBuilder & mirBuilder)508 static BaseNode *ConstructConstvalNode(int64 byte, uint64 num, PrimType primType, MIRBuilder &mirBuilder)
509 {
510 auto val = JoinBytes(byte, static_cast<uint32>(num));
511 return ConstructConstvalNode(val, primType, mirBuilder);
512 }
513
514 // Input total size of memory, split the memory into several blocks, the max block size is 8 bytes
515 // Example:
516 // input output
517 // 40 [ 8, 8, 8, 8, 8 ]
518 // 31 [ 8, 8, 8, 4, 2, 1 ]
SplitMemoryIntoBlocks(size_t totalMemorySize,std::vector<uint32> & blocks)519 static void SplitMemoryIntoBlocks(size_t totalMemorySize, std::vector<uint32> &blocks)
520 {
521 size_t leftSize = totalMemorySize;
522 size_t curBlockSize = kMaxMemoryBlockSizeToAssign; // max block size in byte
523 while (curBlockSize > 0) {
524 size_t n = leftSize / curBlockSize;
525 blocks.insert(blocks.end(), n, curBlockSize);
526 leftSize -= (n * curBlockSize);
527 curBlockSize = curBlockSize >> 1;
528 }
529 }
530
IsComplexExpr(const BaseNode * expr,MIRFunction & func)531 static bool IsComplexExpr(const BaseNode *expr, MIRFunction &func)
532 {
533 Opcode op = expr->GetOpCode();
534 if (op == OP_regread) {
535 return false;
536 }
537 if (op == OP_dread) {
538 auto *symbol = func.GetLocalOrGlobalSymbol(static_cast<const DreadNode *>(expr)->GetStIdx());
539 DEBUG_ASSERT(symbol != nullptr, "nullptr check");
540 if (symbol->IsGlobal() || symbol->GetStorageClass() == kScPstatic) {
541 return true; // dread global/static var is complex expr because it will be lowered to adrp + add
542 } else {
543 return false;
544 }
545 }
546 if (op == OP_addrof) {
547 auto *symbol = func.GetLocalOrGlobalSymbol(static_cast<const AddrofNode *>(expr)->GetStIdx());
548 if (symbol->IsGlobal() || symbol->GetStorageClass() == kScPstatic) {
549 return true; // addrof global/static var is complex expr because it will be lowered to adrp + add
550 } else {
551 return false;
552 }
553 }
554 return true;
555 }
556
BuildAsRhsExpr(MIRFunction & func) const557 BaseNode *MemEntry::BuildAsRhsExpr(MIRFunction &func) const
558 {
559 BaseNode *expr = nullptr;
560 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
561 if (addrExpr->GetOpCode() == OP_addrof) {
562 // We prefer dread to iread
563 // consider iaddrof if possible
564 auto *addrof = static_cast<AddrofNode *>(addrExpr);
565 auto *symbol = func.GetLocalOrGlobalSymbol(addrof->GetStIdx());
566 expr = mirBuilder->CreateExprDread(*memType, addrof->GetFieldID(), *symbol);
567 } else {
568 MIRType *structPtrType = GlobalTables::GetTypeTable().GetOrCreatePointerType(*memType);
569 expr = mirBuilder->CreateExprIread(*memType, *structPtrType, 0, addrExpr);
570 }
571 return expr;
572 }
573
InsertAndMayPrintStmt(BlockNode & block,const StmtNode & anchor,bool debug,StmtNode * stmt)574 static void InsertAndMayPrintStmt(BlockNode &block, const StmtNode &anchor, bool debug, StmtNode *stmt)
575 {
576 if (stmt == nullptr) {
577 return;
578 }
579 block.InsertBefore(&anchor, stmt);
580 if (debug) {
581 stmt->Dump(0);
582 }
583 }
584
TryToExtractComplexExpr(BaseNode * expr,MIRFunction & func,BlockNode & block,const StmtNode & anchor,bool debug)585 static BaseNode *TryToExtractComplexExpr(BaseNode *expr, MIRFunction &func, BlockNode &block, const StmtNode &anchor,
586 bool debug)
587 {
588 if (!IsComplexExpr(expr, func)) {
589 return expr;
590 }
591 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
592 auto pregIdx = func.GetPregTab()->CreatePreg(PTY_ptr);
593 StmtNode *regassign = mirBuilder->CreateStmtRegassign(PTY_ptr, pregIdx, expr);
594 InsertAndMayPrintStmt(block, anchor, debug, regassign);
595 auto *extractedExpr = mirBuilder->CreateExprRegread(PTY_ptr, pregIdx);
596 return extractedExpr;
597 }
598
ExpandMemsetLowLevel(int64 byte,uint64 size,MIRFunction & func,StmtNode & stmt,BlockNode & block,MemOpKind memOpKind,bool debug,ErrorNumber errorNumber) const599 void MemEntry::ExpandMemsetLowLevel(int64 byte, uint64 size, MIRFunction &func, StmtNode &stmt, BlockNode &block,
600 MemOpKind memOpKind, bool debug, ErrorNumber errorNumber) const
601 {
602 MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
603 std::vector<uint32> blocks;
604 SplitMemoryIntoBlocks(size, blocks);
605 int32 offset = 0;
606 // If blocks.size() > 1 and `dst` is not a leaf node,
607 // we should extract common expr to avoid redundant expression
608 BaseNode *realDstExpr = addrExpr;
609 if (blocks.size() > 1) {
610 realDstExpr = TryToExtractComplexExpr(addrExpr, func, block, stmt, debug);
611 }
612 BaseNode *readConst = nullptr;
613 // rhs const is big, extract it to avoid redundant expression
614 bool shouldExtractRhs = blocks.size() > 1 && (byte & 0xff) != 0;
615 for (auto curSize : blocks) {
616 // low level memset expand result:
617 // iassignoff <prim-type> <offset> (dstAddrExpr, constval <prim-type> xx)
618 PrimType constType = GetIntegerPrimTypeBySizeAndSign(curSize * 8, false);
619 BaseNode *rhsExpr = ConstructConstvalNode(byte, curSize, constType, *mirBuilder);
620 if (shouldExtractRhs) {
621 // we only need to extract u64 const once
622 PregIdx pregIdx = func.GetPregTab()->CreatePreg(constType);
623 auto *constAssign = mirBuilder->CreateStmtRegassign(constType, pregIdx, rhsExpr);
624 InsertAndMayPrintStmt(block, stmt, debug, constAssign);
625 readConst = mirBuilder->CreateExprRegread(constType, pregIdx);
626 shouldExtractRhs = false;
627 }
628 if (readConst != nullptr && curSize == kMaxMemoryBlockSizeToAssign) {
629 rhsExpr = readConst;
630 }
631 auto *iassignoff = mirBuilder->CreateStmtIassignoff(constType, offset, realDstExpr, rhsExpr);
632 InsertAndMayPrintStmt(block, stmt, debug, iassignoff);
633 if (debug) {
634 ASSERT_NOT_NULL(iassignoff);
635 iassignoff->Dump(0);
636 }
637 offset += static_cast<int32>(curSize);
638 }
639 // handle memset return val
640 auto *retAssign = GenMemopRetAssign(stmt, func, true, memOpKind, errorNumber);
641 InsertAndMayPrintStmt(block, stmt, debug, retAssign);
642 // return ERRNO_INVAL if memset_s dest is NULL
643 block.RemoveStmt(&stmt);
644 }
645
// handle memset, memcpy return val
// Build the statement assigning the call's return value: the first operand
// (the dst pointer) for memset/memcpy, or the errorNumber constant for the
// checked memset_s/memcpy_s variants. Returns nullptr when the statement is
// not a call or carries no return value.
StmtNode *MemEntry::GenMemopRetAssign(StmtNode &stmt, MIRFunction &func, bool isLowLevel, MemOpKind memOpKind,
                                      ErrorNumber errorNumber)
{
    if (stmt.GetOpCode() != OP_call && stmt.GetOpCode() != OP_callassigned) {
        return nullptr;
    }
    auto &callStmt = static_cast<CallNode &>(stmt);
    const auto &retVec = callStmt.GetReturnVec();
    if (retVec.empty()) {
        return nullptr;
    }
    MIRBuilder *mirBuilder = func.GetModule()->GetMIRBuilder();
    BaseNode *rhs = callStmt.Opnd(0); // for memset, memcpy
    if (memOpKind == MEM_OP_memset_s || memOpKind == MEM_OP_memcpy_s) {
        // memset_s and memcpy_s must return an errorNumber
        MIRType *constType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(PTY_i32));
        MIRConst *mirConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(errorNumber, *constType);
        rhs = mirBuilder->CreateConstval(mirConst);
    }
    // assign into either the return symbol or the return preg
    if (!retVec[0].second.IsReg()) {
        auto *retAssign = mirBuilder->CreateStmtDassign(retVec[0].first, 0, rhs);
        return retAssign;
    } else {
        PregIdx pregIdx = retVec[0].second.GetPregIdx();
        auto pregType = func.GetPregTab()->GetPregTableItem(static_cast<uint32>(pregIdx))->GetPrimType();
        auto *retAssign = mirBuilder->CreateStmtRegassign(pregType, pregIdx, rhs);
        if (isLowLevel) {
            // low-level form: the RHS must carry the preg's own primitive type
            retAssign->GetRHS()->SetPrimType(pregType);
        }
        return retAssign;
    }
}
679
ComputeMemOpKind(StmtNode & stmt)680 MemOpKind SimplifyMemOp::ComputeMemOpKind(StmtNode &stmt)
681 {
682 if (stmt.GetOpCode() == OP_intrinsiccall) {
683 auto intrinsicID = static_cast<IntrinsiccallNode &>(stmt).GetIntrinsic();
684 if (intrinsicID == INTRN_C_memset) {
685 return MEM_OP_memset;
686 } else if (intrinsicID == INTRN_C_memcpy) {
687 return MEM_OP_memcpy;
688 }
689 }
690 // lowered memop function (such as memset) may be a call, not callassigned
691 if (stmt.GetOpCode() != OP_callassigned && stmt.GetOpCode() != OP_call) {
692 return MEM_OP_unknown;
693 }
694 auto &callStmt = static_cast<CallNode &>(stmt);
695 MIRFunction *func = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(callStmt.GetPUIdx());
696 const char *funcName = func->GetName().c_str();
697 if (strcmp(funcName, kFuncNameOfMemset) == 0) {
698 return MEM_OP_memset;
699 }
700 if (strcmp(funcName, kFuncNameOfMemcpy) == 0) {
701 return MEM_OP_memcpy;
702 }
703 if (strcmp(funcName, kFuncNameOfMemsetS) == 0) {
704 return MEM_OP_memset_s;
705 }
706 if (strcmp(funcName, kFuncNameOfMemcpyS) == 0) {
707 return MEM_OP_memcpy_s;
708 }
709 return MEM_OP_unknown;
710 }
711
// Stub: automatic memop simplification is disabled in this build; always
// reports that no simplification was performed.
bool SimplifyMemOp::AutoSimplify(StmtNode &stmt, BlockNode &block, bool isLowLevel)
{
    return false;
}
716 } // namespace maple
717