1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "mir_lower.h"
17 #include "constantfold.h"
18 #include "ext_constantfold.h"
19 #include "me_option.h"
20
21 #define DO_LT_0_CHECK 1
22
23 namespace maple {
24
RoundUpConst(uint64 offset,uint32 align)25 static constexpr uint64 RoundUpConst(uint64 offset, uint32 align)
26 {
27 DEBUG_ASSERT(offset <= UINT64_MAX - align, "must not be zero");
28 return (-align) & (offset + align - 1);
29 }
30
RoundUp(uint64 offset,uint32 align)31 static inline uint64 RoundUp(uint64 offset, uint32 align)
32 {
33 if (align == 0) {
34 return offset;
35 }
36 return RoundUpConst(offset, align);
37 }
38
39 // Remove intrinsicop __builtin_expect and record likely info to brStmt
40 // Target condExpr example:
41 // ne u1 i64 (
42 // intrinsicop i64 C___builtin_expect (
43 // cvt i64 i32 (dread i32 %levVar_9354), cvt i64 i32 (constval i32 0)),
44 // constval i64 0)
// Invoked by LowerBuiltinExpect for brtrue/brfalse statements.
// When the branch condition has the shape
//   ne (intrinsicop C___builtin_expect(x, c), constval 0)
// with `c` a constant 0 or 1 -- either directly, or reached through a `dread`
// of a symbol that the immediately preceding dassign (field 0) assigns, the
// dread optionally wrapped in `ne (dread, constval 0)` -- this records a
// likely/unlikely branch probability on brStmt and replaces the intrinsic
// call with its first argument `x`. Any other shape leaves brStmt untouched.
void LowerCondGotoStmtWithBuiltinExpect(CondGotoNode &brStmt)
{
    BaseNode *condExpr = brStmt.Opnd(0);
    // Poke ne for dread shortCircuit
    // Example:
    // dassign %shortCircuit 0 (ne u1 i64 (
    // intrinsicop i64 C___builtin_expect (
    // cvt i64 i32 (dread i32 %levVar_32349),
    // cvt i64 i32 (constval i32 0)),
    // constval i64 0))
    // dassign %shortCircuit 0 (ne u1 u32 (dread u32 %shortCircuit, constval u1 0))
    // Strip a `ne (dread, <integer constant 0>)` wrapper so the dread is examined next.
    if (condExpr->GetOpCode() == OP_ne && condExpr->Opnd(0)->GetOpCode() == OP_dread &&
        condExpr->Opnd(1)->GetOpCode() == OP_constval) {
        auto *constVal = static_cast<ConstvalNode *>(condExpr->Opnd(1))->GetConstVal();
        if (constVal->GetKind() == kConstInt && static_cast<MIRIntConst *>(constVal)->GetValue() == 0) {
            condExpr = condExpr->Opnd(0);
        }
    }
    if (condExpr->GetOpCode() == OP_dread) {
        // Example:
        // dassign %shortCircuit 0 (ne u1 i64 (
        // intrinsicop i64 C___builtin_expect (
        // cvt i64 i32 (dread i32 %levVar_9488),
        // cvt i64 i32 (constval i32 1)),
        // constval i64 0))
        // brfalse @shortCircuit_label_13351 (dread u32 %shortCircuit)
        StIdx stIdx = static_cast<DreadNode *>(condExpr)->GetStIdx();
        FieldID fieldId = static_cast<DreadNode *>(condExpr)->GetFieldID();
        // Only whole-variable reads are followed; field reads are left alone.
        if (fieldId != 0) {
            return;
        }
        // The read value must be produced by a dassign directly before the branch.
        if (brStmt.GetPrev() == nullptr || brStmt.GetPrev()->GetOpCode() != OP_dassign) {
            return; // prev stmt may be a label, we skip it too
        }
        auto *dassign = static_cast<DassignNode *>(brStmt.GetPrev());
        // That dassign must write the same symbol, as a whole (field 0).
        if (stIdx != dassign->GetStIdx() || dassign->GetFieldID() != 0) {
            return;
        }
        // Continue matching on the assigned value instead of on the dread.
        condExpr = dassign->GetRHS();
    }
    if (condExpr->GetOpCode() == OP_ne) {
        // opnd1 must be int const 0
        BaseNode *opnd1 = condExpr->Opnd(1);
        if (opnd1->GetOpCode() != OP_constval) {
            return;
        }
        auto *constVal = static_cast<ConstvalNode *>(opnd1)->GetConstVal();
        if (constVal->GetKind() != kConstInt || static_cast<MIRIntConst *>(constVal)->GetValue() != 0) {
            return;
        }
        // opnd0 must be intrinsicop C___builtin_expect
        BaseNode *opnd0 = condExpr->Opnd(0);
        if (opnd0->GetOpCode() != OP_intrinsicop ||
            static_cast<IntrinsicopNode *>(opnd0)->GetIntrinsic() != INTRN_C___builtin_expect) {
            return;
        }
        // We trust constant fold: the expected value is assumed to already be a
        // constant, possibly under a single cvt; otherwise no hint is recorded.
        auto *expectedConstExpr = opnd0->Opnd(1);
        if (expectedConstExpr->GetOpCode() == OP_cvt) {
            expectedConstExpr = expectedConstExpr->Opnd(0);
        }
        if (expectedConstExpr->GetOpCode() != OP_constval) {
            return;
        }
        auto *expectedConstNode = static_cast<ConstvalNode *>(expectedConstExpr)->GetConstVal();
        // A non-integer constant as the expected value is treated as fatal.
        CHECK_FATAL(expectedConstNode->GetKind() == kConstInt, "must be");
        auto expectedVal = static_cast<MIRIntConst *>(expectedConstNode)->GetValue();
        // Only the hints 0 (expected false) and 1 (expected true) are handled.
        if (expectedVal != 0 && expectedVal != 1) {
            return;
        }
        bool likelyTrue = (expectedVal == 1); // The condition is likely to be true
        // brtrue jumps when the condition holds, brfalse when it does not, so
        // the jump is likely exactly when its trigger matches the hint.
        bool likelyBranch = (brStmt.GetOpCode() == OP_brtrue ? likelyTrue : !likelyTrue); // High probability jump
        if (likelyBranch) {
            brStmt.SetBranchProb(kProbLikely);
        } else {
            brStmt.SetBranchProb(kProbUnlikely);
        }
        // Remove __builtin_expect: `ne` now compares the intrinsic's first argument directly.
        condExpr->SetOpnd(opnd0->Opnd(0), 0);
    }
}
126
LowerBuiltinExpect(BlockNode & block)127 void MIRLower::LowerBuiltinExpect(BlockNode &block)
128 {
129 auto *stmt = block.GetFirst();
130 auto *last = block.GetLast();
131 while (stmt != last) {
132 if (stmt->GetOpCode() == OP_brtrue || stmt->GetOpCode() == OP_brfalse) {
133 LowerCondGotoStmtWithBuiltinExpect(*static_cast<CondGotoNode *>(stmt));
134 }
135 stmt = stmt->GetNext();
136 }
137 }
138
CreateCondGotoStmt(Opcode op,BlockNode & blk,const IfStmtNode & ifStmt)139 LabelIdx MIRLower::CreateCondGotoStmt(Opcode op, BlockNode &blk, const IfStmtNode &ifStmt)
140 {
141 auto *brStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(op);
142 brStmt->SetOpnd(ifStmt.Opnd(), 0);
143 brStmt->SetSrcPos(ifStmt.GetSrcPos());
144 DEBUG_ASSERT(mirModule.CurFunction() != nullptr, "mirModule.CurFunction() should not be nullptr");
145 LabelIdx lableIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
146 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lableIdx);
147 brStmt->SetOffset(lableIdx);
148 blk.AddStatement(brStmt);
149 if (GetFuncProfData()) {
150 GetFuncProfData()->CopyStmtFreq(brStmt->GetStmtID(), ifStmt.GetStmtID());
151 }
152 bool thenEmpty = (ifStmt.GetThenPart() == nullptr) || (ifStmt.GetThenPart()->GetFirst() == nullptr);
153 if (thenEmpty) {
154 blk.AppendStatementsFromBlock(*ifStmt.GetElsePart());
155 } else {
156 blk.AppendStatementsFromBlock(*ifStmt.GetThenPart());
157 }
158 return lableIdx;
159 }
160
CreateBrFalseStmt(BlockNode & blk,const IfStmtNode & ifStmt)161 void MIRLower::CreateBrFalseStmt(BlockNode &blk, const IfStmtNode &ifStmt)
162 {
163 LabelIdx labelIdx = CreateCondGotoStmt(OP_brfalse, blk, ifStmt);
164 auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
165 lableStmt->SetLabelIdx(labelIdx);
166 blk.AddStatement(lableStmt);
167 // set stmtfreqs
168 if (GetFuncProfData()) {
169 DEBUG_ASSERT(GetFuncProfData()->GetStmtFreq(ifStmt.GetThenPart()->GetStmtID()) >= 0, "sanity check");
170 int64_t freq = GetFuncProfData()->GetStmtFreq(ifStmt.GetStmtID()) -
171 GetFuncProfData()->GetStmtFreq(ifStmt.GetThenPart()->GetStmtID());
172 GetFuncProfData()->SetStmtFreq(lableStmt->GetStmtID(), freq);
173 }
174 }
175
CreateBrTrueStmt(BlockNode & blk,const IfStmtNode & ifStmt)176 void MIRLower::CreateBrTrueStmt(BlockNode &blk, const IfStmtNode &ifStmt)
177 {
178 LabelIdx labelIdx = CreateCondGotoStmt(OP_brtrue, blk, ifStmt);
179 auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
180 lableStmt->SetLabelIdx(labelIdx);
181 blk.AddStatement(lableStmt);
182 // set stmtfreqs
183 if (GetFuncProfData()) {
184 DEBUG_ASSERT(GetFuncProfData()->GetStmtFreq(ifStmt.GetElsePart()->GetStmtID()) >= 0, "sanity check");
185 int64_t freq = GetFuncProfData()->GetStmtFreq(ifStmt.GetStmtID()) -
186 GetFuncProfData()->GetStmtFreq(ifStmt.GetElsePart()->GetStmtID());
187 GetFuncProfData()->SetStmtFreq(lableStmt->GetStmtID(), freq);
188 }
189 }
190
CreateBrFalseAndGotoStmt(BlockNode & blk,const IfStmtNode & ifStmt)191 void MIRLower::CreateBrFalseAndGotoStmt(BlockNode &blk, const IfStmtNode &ifStmt)
192 {
193 LabelIdx labelIdx = CreateCondGotoStmt(OP_brfalse, blk, ifStmt);
194 bool fallThroughFromThen = !IfStmtNoFallThrough(ifStmt);
195 LabelIdx gotoLableIdx = 0;
196 if (fallThroughFromThen) {
197 auto *gotoStmt = mirModule.CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
198 DEBUG_ASSERT(mirModule.CurFunction() != nullptr, "mirModule.CurFunction() should not be nullptr");
199 gotoLableIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
200 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(gotoLableIdx);
201 gotoStmt->SetOffset(gotoLableIdx);
202 blk.AddStatement(gotoStmt);
203 // set stmtfreqs
204 if (GetFuncProfData()) {
205 GetFuncProfData()->CopyStmtFreq(gotoStmt->GetStmtID(), ifStmt.GetThenPart()->GetStmtID());
206 }
207 }
208 auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
209 lableStmt->SetLabelIdx(labelIdx);
210 blk.AddStatement(lableStmt);
211 blk.AppendStatementsFromBlock(*ifStmt.GetElsePart());
212 // set stmtfreqs
213 if (GetFuncProfData()) {
214 GetFuncProfData()->CopyStmtFreq(lableStmt->GetStmtID(), ifStmt.GetElsePart()->GetStmtID());
215 }
216 if (fallThroughFromThen) {
217 lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
218 lableStmt->SetLabelIdx(gotoLableIdx);
219 blk.AddStatement(lableStmt);
220 // set endlabel stmtfreqs
221 if (GetFuncProfData()) {
222 GetFuncProfData()->CopyStmtFreq(lableStmt->GetStmtID(), ifStmt.GetStmtID());
223 }
224 }
225 }
226
LowerIfStmt(IfStmtNode & ifStmt,bool recursive)227 BlockNode *MIRLower::LowerIfStmt(IfStmtNode &ifStmt, bool recursive)
228 {
229 bool thenEmpty = (ifStmt.GetThenPart() == nullptr) || (ifStmt.GetThenPart()->GetFirst() == nullptr);
230 bool elseEmpty = (ifStmt.GetElsePart() == nullptr) || (ifStmt.GetElsePart()->GetFirst() == nullptr);
231 if (recursive) {
232 if (!thenEmpty) {
233 ifStmt.SetThenPart(LowerBlock(*ifStmt.GetThenPart()));
234 }
235 if (!elseEmpty) {
236 ifStmt.SetElsePart(LowerBlock(*ifStmt.GetElsePart()));
237 }
238 }
239 auto *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
240 if (thenEmpty && elseEmpty) {
241 // generate EVAL <cond> statement
242 auto *evalStmt = mirModule.CurFuncCodeMemPool()->New<UnaryStmtNode>(OP_eval);
243 evalStmt->SetOpnd(ifStmt.Opnd(), 0);
244 evalStmt->SetSrcPos(ifStmt.GetSrcPos());
245 blk->AddStatement(evalStmt);
246 if (GetFuncProfData()) {
247 GetFuncProfData()->CopyStmtFreq(evalStmt->GetStmtID(), ifStmt.GetStmtID());
248 }
249 } else if (elseEmpty) {
250 // brfalse <cond> <endlabel>
251 // <thenPart>
252 // label <endlabel>
253 CreateBrFalseStmt(*blk, ifStmt);
254 } else if (thenEmpty) {
255 // brtrue <cond> <endlabel>
256 // <elsePart>
257 // label <endlabel>
258 CreateBrTrueStmt(*blk, ifStmt);
259 } else {
260 // brfalse <cond> <elselabel>
261 // <thenPart>
262 // goto <endlabel>
263 // label <elselabel>
264 // <elsePart>
265 // label <endlabel>
266 CreateBrFalseAndGotoStmt(*blk, ifStmt);
267 }
268 return blk;
269 }
270
ConsecutiveCaseValsAndSameTarget(const CaseVector * switchTable)271 static bool ConsecutiveCaseValsAndSameTarget(const CaseVector *switchTable)
272 {
273 size_t caseNum = switchTable->size();
274 int lastVal = static_cast<int>((*switchTable)[0].first);
275 LabelIdx lblIdx = (*switchTable)[0].second;
276 for (size_t id = 1; id < caseNum; id++) {
277 lastVal++;
278 if (lastVal != (*switchTable)[id].first) {
279 return false;
280 }
281 if (lblIdx != (*switchTable)[id].second) {
282 return false;
283 }
284 }
285 return true;
286 }
287
// If the switch has no cases, or all of its case values are consecutive and
// share one target label, replace it with a goto / conditional branch(es) and
// return the new statements; otherwise, return nullptr.
LowerSwitchStmt(SwitchNode * switchNode)290 BlockNode *MIRLower::LowerSwitchStmt(SwitchNode *switchNode)
291 {
292 CaseVector *switchTable = &switchNode->GetSwitchTable();
293 if (switchTable->empty()) { // goto @defaultLabel
294 BlockNode *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
295 LabelIdx defaultLabel = switchNode->GetDefaultLabel();
296 MIRBuilder *builder = mirModule.GetMIRBuilder();
297 GotoNode *gotoStmt = builder->CreateStmtGoto(OP_goto, defaultLabel);
298 blk->AddStatement(gotoStmt);
299 return blk;
300 }
301 if (!ConsecutiveCaseValsAndSameTarget(switchTable)) {
302 return nullptr;
303 }
304 BlockNode *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
305 LabelIdx caseGotoLabel = switchTable->front().second;
306 LabelIdx defaultLabel = switchNode->GetDefaultLabel();
307 int64 minCaseVal = switchTable->front().first;
308 int64 maxCaseVal = switchTable->back().first;
309 BaseNode *switchOpnd = switchNode->Opnd(0);
310 MIRBuilder *builder = mirModule.GetMIRBuilder();
311 ConstvalNode *minCaseNode = builder->CreateIntConst(minCaseVal, switchOpnd->GetPrimType());
312 ConstvalNode *maxCaseNode = builder->CreateIntConst(maxCaseVal, switchOpnd->GetPrimType());
313 if (minCaseVal == maxCaseVal) {
314 // brtrue (x == minCaseVal) @case_goto_label
315 // goto @default_label
316 CompareNode *eqNode = builder->CreateExprCompare(
317 OP_eq, *GlobalTables::GetTypeTable().GetInt32(),
318 *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(switchOpnd->GetPrimType())), switchOpnd, minCaseNode);
319 CondGotoNode *condGoto = builder->CreateStmtCondGoto(eqNode, OP_brtrue, caseGotoLabel);
320 blk->AddStatement(condGoto);
321 GotoNode *gotoStmt = builder->CreateStmtGoto(OP_goto, defaultLabel);
322 blk->AddStatement(gotoStmt);
323 } else {
324 // brtrue (x < minCaseVal) @default_label
325 // brtrue (x > maxCaseVal) @default_label
326 // goto @case_goto_label
327 CompareNode *ltNode = builder->CreateExprCompare(
328 OP_lt, *GlobalTables::GetTypeTable().GetInt32(),
329 *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(switchOpnd->GetPrimType())), switchOpnd, minCaseNode);
330 CondGotoNode *condGoto = builder->CreateStmtCondGoto(ltNode, OP_brtrue, defaultLabel);
331 blk->AddStatement(condGoto);
332 CompareNode *gtNode = builder->CreateExprCompare(
333 OP_gt, *GlobalTables::GetTypeTable().GetInt32(),
334 *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(switchOpnd->GetPrimType())), switchOpnd, maxCaseNode);
335 condGoto = builder->CreateStmtCondGoto(gtNode, OP_brtrue, defaultLabel);
336 blk->AddStatement(condGoto);
337 GotoNode *gotoStmt = builder->CreateStmtGoto(OP_goto, caseGotoLabel);
338 blk->AddStatement(gotoStmt);
339 }
340 return blk;
341 }
342
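// while <cond> <body>
// is lowered to:
// brfalse <cond> <endlabel>
// label <bodylabel>
// <body>
// brtrue <cond> <bodylabel>
// label <endlabel>
// When MeOption::optForSize is set, the while-do shape is kept instead:
// label <whilelabel>; brfalse <cond> <endlabel>; <body>; goto <whilelabel>; label <endlabel>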
LowerWhileStmt(WhileStmtNode & whileStmt)350 BlockNode *MIRLower::LowerWhileStmt(WhileStmtNode &whileStmt)
351 {
352 DEBUG_ASSERT(whileStmt.GetBody() != nullptr, "nullptr check");
353 whileStmt.SetBody(LowerBlock(*whileStmt.GetBody()));
354 auto *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
355 auto *brFalseStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brfalse);
356 brFalseStmt->SetOpnd(whileStmt.Opnd(0), 0);
357 brFalseStmt->SetSrcPos(whileStmt.GetSrcPos());
358 LabelIdx lalbeIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
359 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lalbeIdx);
360 brFalseStmt->SetOffset(lalbeIdx);
361 blk->AddStatement(brFalseStmt);
362 blk->AppendStatementsFromBlock(*whileStmt.GetBody());
363 if (MeOption::optForSize) {
364 // still keep while-do format to avoid coping too much condition-related stmt
365 LabelIdx whileLalbeIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
366 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(whileLalbeIdx);
367 auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
368 lableStmt->SetLabelIdx(whileLalbeIdx);
369 blk->InsertBefore(brFalseStmt, lableStmt);
370 auto *whilegotonode = mirModule.CurFuncCodeMemPool()->New<GotoNode>(OP_goto, whileLalbeIdx);
371 if (GetFuncProfData() && blk->GetLast()) {
372 GetFuncProfData()->CopyStmtFreq(whilegotonode->GetStmtID(), blk->GetLast()->GetStmtID());
373 }
374 blk->AddStatement(whilegotonode);
375 } else {
376 LabelIdx bodyLableIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
377 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(bodyLableIdx);
378 auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
379 lableStmt->SetLabelIdx(bodyLableIdx);
380 blk->InsertAfter(brFalseStmt, lableStmt);
381 // update frequency
382 if (GetFuncProfData()) {
383 GetFuncProfData()->CopyStmtFreq(lableStmt->GetStmtID(), whileStmt.GetStmtID());
384 GetFuncProfData()->CopyStmtFreq(brFalseStmt->GetStmtID(), whileStmt.GetStmtID());
385 }
386 auto *brTrueStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brtrue);
387 brTrueStmt->SetOpnd(whileStmt.Opnd(0)->CloneTree(mirModule.GetCurFuncCodeMPAllocator()), 0);
388 brTrueStmt->SetOffset(bodyLableIdx);
389 if (GetFuncProfData() && blk->GetLast()) {
390 GetFuncProfData()->CopyStmtFreq(brTrueStmt->GetStmtID(), whileStmt.GetBody()->GetStmtID());
391 }
392 blk->AddStatement(brTrueStmt);
393 }
394 auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
395 lableStmt->SetLabelIdx(lalbeIdx);
396 blk->AddStatement(lableStmt);
397 if (GetFuncProfData()) {
398 int64_t freq = GetFuncProfData()->GetStmtFreq(whileStmt.GetStmtID()) -
399 GetFuncProfData()->GetStmtFreq(blk->GetLast()->GetStmtID());
400 DEBUG_ASSERT(freq >= 0, "sanity check");
401 GetFuncProfData()->SetStmtFreq(lableStmt->GetStmtID(), freq);
402 }
403 return blk;
404 }
405
// doloop <do-var>(<start-expr>,<cont-expr>,<incr-amt>) {<body-stmts>}
// is lowered to:
// dassign <do-var> (<start-expr>)
// brfalse <cont-expr> <endlabel>
// label <bodylabel>
// <body-stmts>
// dassign <do-var> (add <do-var> <incr-amt>)
// brtrue <cont-expr> <bodylabel>
// label <endlabel>
// (regassign/regread are used instead of dassign/dread when <do-var> is a preg)
LowerDoloopStmt(DoloopNode & doloop)415 BlockNode *MIRLower::LowerDoloopStmt(DoloopNode &doloop)
416 {
417 DEBUG_ASSERT(doloop.GetDoBody() != nullptr, "nullptr check");
418 DEBUG_ASSERT(mirModule.CurFunction() != nullptr, "mirModule.CurFunction() should not be nullptr");
419 doloop.SetDoBody(LowerBlock(*doloop.GetDoBody()));
420 int64_t doloopnodeFreq = 0, bodynodeFreq = 0;
421 if (GetFuncProfData()) {
422 doloopnodeFreq = GetFuncProfData()->GetStmtFreq(doloop.GetStmtID());
423 bodynodeFreq = GetFuncProfData()->GetStmtFreq(doloop.GetDoBody()->GetStmtID());
424 }
425 auto *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
426 if (doloop.IsPreg()) {
427 PregIdx regIdx = static_cast<PregIdx>(doloop.GetDoVarStIdx().FullIdx());
428 MIRPreg *mirPreg = mirModule.CurFunction()->GetPregTab()->PregFromPregIdx(regIdx);
429 PrimType primType = mirPreg->GetPrimType();
430 DEBUG_ASSERT(primType != kPtyInvalid, "runtime check error");
431 auto *startRegassign = mirModule.CurFuncCodeMemPool()->New<RegassignNode>();
432 startRegassign->SetRegIdx(regIdx);
433 startRegassign->SetPrimType(primType);
434 startRegassign->SetOpnd(doloop.GetStartExpr(), 0);
435 startRegassign->SetSrcPos(doloop.GetSrcPos());
436 blk->AddStatement(startRegassign);
437 } else {
438 auto *startDassign = mirModule.CurFuncCodeMemPool()->New<DassignNode>();
439 startDassign->SetStIdx(doloop.GetDoVarStIdx());
440 startDassign->SetRHS(doloop.GetStartExpr());
441 startDassign->SetSrcPos(doloop.GetSrcPos());
442 blk->AddStatement(startDassign);
443 }
444 if (GetFuncProfData()) {
445 GetFuncProfData()->SetStmtFreq(blk->GetLast()->GetStmtID(), doloopnodeFreq - bodynodeFreq);
446 }
447 auto *brFalseStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brfalse);
448 brFalseStmt->SetOpnd(doloop.GetCondExpr(), 0);
449 LabelIdx lIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
450 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lIdx);
451 brFalseStmt->SetOffset(lIdx);
452 blk->AddStatement(brFalseStmt);
453 // udpate stmtFreq
454 if (GetFuncProfData()) {
455 GetFuncProfData()->SetStmtFreq(brFalseStmt->GetStmtID(), (doloopnodeFreq - bodynodeFreq));
456 }
457 LabelIdx bodyLabelIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
458 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(bodyLabelIdx);
459 auto *labelStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
460 labelStmt->SetLabelIdx(bodyLabelIdx);
461 blk->AddStatement(labelStmt);
462 // udpate stmtFreq
463 if (GetFuncProfData()) {
464 GetFuncProfData()->SetStmtFreq(labelStmt->GetStmtID(), bodynodeFreq);
465 }
466 blk->AppendStatementsFromBlock(*doloop.GetDoBody());
467 if (doloop.IsPreg()) {
468 PregIdx regIdx = (PregIdx)doloop.GetDoVarStIdx().FullIdx();
469 MIRPreg *mirPreg = mirModule.CurFunction()->GetPregTab()->PregFromPregIdx(regIdx);
470 PrimType doVarPType = mirPreg->GetPrimType();
471 DEBUG_ASSERT(doVarPType != kPtyInvalid, "runtime check error");
472 auto *readDoVar = mirModule.CurFuncCodeMemPool()->New<RegreadNode>();
473 readDoVar->SetRegIdx(regIdx);
474 readDoVar->SetPrimType(doVarPType);
475 auto *add =
476 mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add, doVarPType, doloop.GetIncrExpr(), readDoVar);
477 auto *endRegassign = mirModule.CurFuncCodeMemPool()->New<RegassignNode>();
478 endRegassign->SetRegIdx(regIdx);
479 endRegassign->SetPrimType(doVarPType);
480 endRegassign->SetOpnd(add, 0);
481 blk->AddStatement(endRegassign);
482 } else {
483 const MIRSymbol *doVarSym = mirModule.CurFunction()->GetLocalOrGlobalSymbol(doloop.GetDoVarStIdx());
484 DEBUG_ASSERT(doVarSym != nullptr, "nullptr check");
485 PrimType doVarPType = doVarSym->GetType()->GetPrimType();
486 auto *readDovar =
487 mirModule.CurFuncCodeMemPool()->New<DreadNode>(OP_dread, doVarPType, doloop.GetDoVarStIdx(), 0);
488 auto *add =
489 mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add, doVarPType, readDovar, doloop.GetIncrExpr());
490 auto *endDassign = mirModule.CurFuncCodeMemPool()->New<DassignNode>();
491 endDassign->SetStIdx(doloop.GetDoVarStIdx());
492 endDassign->SetRHS(add);
493 blk->AddStatement(endDassign);
494 }
495 auto *brTrueStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brtrue);
496 brTrueStmt->SetOpnd(doloop.GetCondExpr()->CloneTree(mirModule.GetCurFuncCodeMPAllocator()), 0);
497 brTrueStmt->SetOffset(bodyLabelIdx);
498 blk->AddStatement(brTrueStmt);
499 // udpate stmtFreq
500 if (GetFuncProfData()) {
501 GetFuncProfData()->SetStmtFreq(brTrueStmt->GetStmtID(), bodynodeFreq);
502 }
503 labelStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
504 labelStmt->SetLabelIdx(lIdx);
505 blk->AddStatement(labelStmt);
506 // udpate stmtFreq
507 if (GetFuncProfData()) {
508 GetFuncProfData()->SetStmtFreq(labelStmt->GetStmtID(), (doloopnodeFreq - bodynodeFreq));
509 }
510 return blk;
511 }
512
513 // dowhile <body> <cond>
514 // is lowered to:
515 // label <bodylabel>
516 // <body>
517 // brtrue <cond> <bodylabel>
LowerDowhileStmt(WhileStmtNode & doWhileStmt)518 BlockNode *MIRLower::LowerDowhileStmt(WhileStmtNode &doWhileStmt)
519 {
520 DEBUG_ASSERT(doWhileStmt.GetBody() != nullptr, "nullptr check");
521 doWhileStmt.SetBody(LowerBlock(*doWhileStmt.GetBody()));
522 auto *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
523 DEBUG_ASSERT(mirModule.CurFunction() != nullptr, "mirModule.CurFunction() should not be nullptr");
524 LabelIdx lIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
525 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lIdx);
526 auto *labelStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
527 labelStmt->SetLabelIdx(lIdx);
528 blk->AddStatement(labelStmt);
529 blk->AppendStatementsFromBlock(*doWhileStmt.GetBody());
530 auto *brTrueStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brtrue);
531 brTrueStmt->SetOpnd(doWhileStmt.Opnd(0), 0);
532 brTrueStmt->SetOffset(lIdx);
533 blk->AddStatement(brTrueStmt);
534 return blk;
535 }
536
LowerBlock(BlockNode & block)537 BlockNode *MIRLower::LowerBlock(BlockNode &block)
538 {
539 auto *newBlock = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
540 BlockNode *tmp = nullptr;
541 if (block.GetFirst() == nullptr) {
542 newBlock->SetStmtID(block.GetStmtID()); // keep original block stmtid
543 return newBlock;
544 }
545 StmtNode *nextStmt = block.GetFirst();
546 DEBUG_ASSERT(nextStmt != nullptr, "nullptr check");
547 do {
548 StmtNode *stmt = nextStmt;
549 nextStmt = stmt->GetNext();
550 switch (stmt->GetOpCode()) {
551 case OP_if:
552 tmp = LowerIfStmt(static_cast<IfStmtNode &>(*stmt), true);
553 newBlock->AppendStatementsFromBlock(*tmp);
554 break;
555 case OP_switch:
556 tmp = LowerSwitchStmt(static_cast<SwitchNode *>(stmt));
557 if (tmp != nullptr) {
558 newBlock->AppendStatementsFromBlock(*tmp);
559 } else {
560 newBlock->AddStatement(stmt);
561 }
562 break;
563 case OP_while:
564 newBlock->AppendStatementsFromBlock(*LowerWhileStmt(static_cast<WhileStmtNode &>(*stmt)));
565 break;
566 case OP_dowhile:
567 newBlock->AppendStatementsFromBlock(*LowerDowhileStmt(static_cast<WhileStmtNode &>(*stmt)));
568 break;
569 case OP_doloop:
570 newBlock->AppendStatementsFromBlock(*LowerDoloopStmt(static_cast<DoloopNode &>(*stmt)));
571 break;
572 case OP_icallassigned:
573 case OP_icall: {
574 if (mirModule.IsCModule()) {
575 // convert to icallproto/icallprotoassigned
576 IcallNode *ic = static_cast<IcallNode *>(stmt);
577 ic->SetOpCode(stmt->GetOpCode() == OP_icall ? OP_icallproto : OP_icallprotoassigned);
578 MIRFuncType *funcType = FuncTypeFromFuncPtrExpr(stmt->Opnd(0));
579 CHECK_FATAL(funcType != nullptr, "MIRLower::LowerBlock: cannot find prototype for icall");
580 ic->SetRetTyIdx(funcType->GetTypeIndex());
581 MIRType *retType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(funcType->GetRetTyIdx());
582 if (retType->GetPrimType() == PTY_agg && retType->GetSize() > k16BitSize) {
583 funcType->funcAttrs.SetAttr(FUNCATTR_firstarg_return);
584 }
585 }
586 newBlock->AddStatement(stmt);
587 break;
588 }
589 case OP_block:
590 tmp = LowerBlock(static_cast<BlockNode &>(*stmt));
591 newBlock->AppendStatementsFromBlock(*tmp);
592 break;
593 default:
594 newBlock->AddStatement(stmt);
595 break;
596 }
597 } while (nextStmt != nullptr);
598 newBlock->SetStmtID(block.GetStmtID()); // keep original block stmtid
599 return newBlock;
600 }
601
602 // for lowering OP_cand and OP_cior embedded in the expression x which belongs
603 // to curstmt
// Recursively rewrite every OP_cand / OP_cior inside expression `x` (an
// operand of `curstmt`) into explicit control flow inserted into `blk` just
// before `curstmt`, and return the expression to use in place of `x`.
//
// For `cand a b` / `cior a b` the inserted sequence is (p is a fresh preg of
// x's prim type, L a fresh label):
//   regassign p = a
//   brfalse p L   (cand)   /   brtrue p L   (cior)   -- short-circuit past b
//   regassign p = b
//   label L
// and `regread p` is returned. Any other expression is returned unchanged
// after its operands have been rewritten in place.
BaseNode *MIRLower::LowerEmbeddedCandCior(BaseNode *x, StmtNode *curstmt, BlockNode *blk)
{
    DEBUG_ASSERT(x != nullptr, "nullptr check");
    if (x->GetOpCode() == OP_cand || x->GetOpCode() == OP_cior) {
        MIRBuilder *builder = mirModule.GetMIRBuilder();
        BinaryNode *bnode = static_cast<BinaryNode *>(x);
        // Lower the left operand first so its inserted statements precede the
        // regassign below.
        bnode->SetOpnd(LowerEmbeddedCandCior(bnode->Opnd(0), curstmt, blk), 0);
        PregIdx pregIdx = mirFunc->GetPregTab()->CreatePreg(x->GetPrimType());
        RegassignNode *regass = builder->CreateStmtRegassign(x->GetPrimType(), pregIdx, bnode->Opnd(0));
        blk->InsertBefore(curstmt, regass);
        LabelIdx labIdx = mirFunc->GetLabelTab()->CreateLabel();
        mirFunc->GetLabelTab()->AddToStringLabelMap(labIdx);
        BaseNode *cond = builder->CreateExprRegread(x->GetPrimType(), pregIdx);
        // cior: skip the right operand when the left one is true;
        // cand: skip it when the left one is false.
        CondGotoNode *cgoto =
            mirFunc->GetCodeMempool()->New<CondGotoNode>(x->GetOpCode() == OP_cior ? OP_brtrue : OP_brfalse);
        cgoto->SetOpnd(cond, 0);
        cgoto->SetOffset(labIdx);
        blk->InsertBefore(curstmt, cgoto);

        // The right operand is lowered only now, so any statements it needs are
        // inserted after the conditional branch, i.e. only on the
        // non-short-circuited path.
        bnode->SetOpnd(LowerEmbeddedCandCior(bnode->Opnd(1), curstmt, blk), 1);
        regass = builder->CreateStmtRegassign(x->GetPrimType(), pregIdx, bnode->Opnd(1));
        blk->InsertBefore(curstmt, regass);
        LabelNode *lbl = mirFunc->GetCodeMempool()->New<LabelNode>();
        lbl->SetLabelIdx(labIdx);
        blk->InsertBefore(curstmt, lbl);
        return builder->CreateExprRegread(x->GetPrimType(), pregIdx);
    } else {
        for (size_t i = 0; i < x->GetNumOpnds(); i++) {
            x->SetOpnd(LowerEmbeddedCandCior(x->Opnd(i), curstmt, blk), i);
        }
        return x;
    }
}
637
// for lowering all appearances of OP_cand and OP_cior associated with conditional
// branches in the block
LowerCandCior(BlockNode & block)640 void MIRLower::LowerCandCior(BlockNode &block)
641 {
642 if (block.GetFirst() == nullptr) {
643 return;
644 }
645 StmtNode *nextStmt = block.GetFirst();
646 do {
647 StmtNode *stmt = nextStmt;
648 nextStmt = stmt->GetNext();
649 if (stmt->IsCondBr() && (stmt->Opnd(0) != nullptr &&
650 (stmt->Opnd(0)->GetOpCode() == OP_cand || stmt->Opnd(0)->GetOpCode() == OP_cior))) {
651 CondGotoNode *condGoto = static_cast<CondGotoNode *>(stmt);
652 BinaryNode *cond = static_cast<BinaryNode *>(condGoto->Opnd(0));
653 if ((stmt->GetOpCode() == OP_brfalse && cond->GetOpCode() == OP_cand) ||
654 (stmt->GetOpCode() == OP_brtrue && cond->GetOpCode() == OP_cior)) {
655 // short-circuit target label is same as original condGoto stmt
656 condGoto->SetOpnd(cond->GetBOpnd(0), 0);
657 auto *newCondGoto = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(Opcode(stmt->GetOpCode()));
658 newCondGoto->SetOpnd(cond->GetBOpnd(1), 0);
659 newCondGoto->SetOffset(condGoto->GetOffset());
660 block.InsertAfter(condGoto, newCondGoto);
661 nextStmt = stmt; // so it will be re-processed if another cand/cior
662 } else { // short-circuit target is next statement
663 LabelIdx lIdx;
664 LabelNode *labelStmt = nullptr;
665 if (nextStmt->GetOpCode() == OP_label) {
666 labelStmt = static_cast<LabelNode *>(nextStmt);
667 lIdx = labelStmt->GetLabelIdx();
668 } else {
669 DEBUG_ASSERT(mirModule.CurFunction() != nullptr, "mirModule.CurFunction() should not be nullptr");
670 lIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
671 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lIdx);
672 labelStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
673 labelStmt->SetLabelIdx(lIdx);
674 block.InsertAfter(condGoto, labelStmt);
675 }
676 auto *newCondGoto = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(
677 stmt->GetOpCode() == OP_brfalse ? OP_brtrue : OP_brfalse);
678 newCondGoto->SetOpnd(cond->GetBOpnd(0), 0);
679 newCondGoto->SetOffset(lIdx);
680 block.InsertBefore(condGoto, newCondGoto);
681 condGoto->SetOpnd(cond->GetBOpnd(1), 0);
682 nextStmt = newCondGoto; // so it will be re-processed if another cand/cior
683 }
684 } else { // call LowerEmbeddedCandCior() for all the expression operands
685 for (size_t i = 0; i < stmt->GetNumOpnds(); i++) {
686 stmt->SetOpnd(LowerEmbeddedCandCior(stmt->Opnd(i), stmt, &block), i);
687 }
688 }
689 } while (nextStmt != nullptr);
690 }
691
LowerFunc(MIRFunction & func)692 void MIRLower::LowerFunc(MIRFunction &func)
693 {
694 if (GetOptLevel() > 0) {
695 ExtConstantFold ecf(func.GetModule());
696 (void)ecf.ExtSimplify(func.GetBody());
697 ;
698 }
699
700 mirModule.SetCurFunction(&func);
701 if (IsLowerExpandArray()) {
702 ExpandArrayMrt(func);
703 }
704 BlockNode *origBody = func.GetBody();
705 DEBUG_ASSERT(origBody != nullptr, "nullptr check");
706 BlockNode *newBody = LowerBlock(*origBody);
707 DEBUG_ASSERT(newBody != nullptr, "nullptr check");
708 LowerBuiltinExpect(*newBody);
709 if (!InLFO()) {
710 LowerCandCior(*newBody);
711 }
712 func.SetBody(newBody);
713 }
714
LowerFarray(ArrayNode * array)715 BaseNode *MIRLower::LowerFarray(ArrayNode *array)
716 {
717 auto *farrayType = static_cast<MIRFarrayType *>(array->GetArrayType(GlobalTables::GetTypeTable()));
718 size_t eSize = GlobalTables::GetTypeTable().GetTypeFromTyIdx(farrayType->GetElemTyIdx())->GetSize();
719 MIRType &arrayType = *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType()));
720 /* how about multi-dimension array? */
721 if (array->GetIndex(0)->GetOpCode() == OP_constval) {
722 const ConstvalNode *constvalNode = static_cast<const ConstvalNode *>(array->GetIndex(0));
723 if (constvalNode->GetConstVal()->GetKind() == kConstInt) {
724 const MIRIntConst *pIntConst = static_cast<const MIRIntConst *>(constvalNode->GetConstVal());
725 CHECK_FATAL(!pIntConst->IsNegative(), "Array index should >= 0.");
726 int64 eleOffset = pIntConst->GetExtValue() * static_cast<int64>(eSize);
727
728 BaseNode *baseNode = array->GetBase();
729 if (eleOffset == 0) {
730 return baseNode;
731 }
732
733 MIRIntConst *eleConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(eleOffset, arrayType);
734 BaseNode *offsetNode = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(eleConst);
735 offsetNode->SetPrimType(array->GetPrimType());
736
737 BaseNode *rAdd = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add);
738 rAdd->SetPrimType(array->GetPrimType());
739 rAdd->SetOpnd(baseNode, 0);
740 rAdd->SetOpnd(offsetNode, 1);
741 return rAdd;
742 }
743 }
744
745 BaseNode *rMul = nullptr;
746
747 BaseNode *baseNode = array->GetBase();
748
749 BaseNode *rAdd = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add);
750 rAdd->SetPrimType(array->GetPrimType());
751 rAdd->SetOpnd(baseNode, 0);
752 rAdd->SetOpnd(rMul, 1);
753 auto *newAdd = ConstantFold(mirModule).Fold(rAdd);
754 rAdd = (newAdd != nullptr ? newAdd : rAdd);
755 return rAdd;
756 }
757
/*
 * Lower an OP_array node over a C-style array into explicit address arithmetic:
 *     base + linearIndex * esize
 * linearIndex is built from the node's index operands (see the dim > 1 loop below).
 * esize is the byte size of the addressed element.
 * Java arrays (kTypeJArray) are returned unchanged; flexible arrays (kTypeFArray) are
 * delegated to LowerFarray(). The final add is passed through ConstantFold. The folded
 * node is returned when folding produced one, otherwise the unfolded add.
 */
BaseNode *MIRLower::LowerCArray(ArrayNode *array)
{
    MIRType *aType = array->GetArrayType(GlobalTables::GetTypeTable());
    if (aType->GetKind() == kTypeJArray) {
        return array;
    }
    if (aType->GetKind() == kTypeFArray) {
        return LowerFarray(array);
    }

    MIRArrayType *arrayType = static_cast<MIRArrayType *>(aType);
    /* There are two cases where dimension > 1.
     * 1) arrayType->dim > 1. Process the current arrayType. (nestedArray = false)
     * 2) arrayType->dim == 1, but arraytype->eTyIdx is another array. (nestedArray = true)
     * Assume at this time 1) and 2) cannot mix.
     * Along with the array dimension, there is the array indexing.
     * It is allowed to index arrays less than the dimension.
     * This is dictated by the number of indexes.
     */
    bool nestedArray = false;
    uint64 dim = arrayType->GetDim();
    MIRType *innerType = nullptr;
    MIRArrayType *innerArrayType = nullptr;
    uint64 elemSize = 0;
    if (dim == 1) {
        innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(arrayType->GetElemTyIdx());
        if (innerType->GetKind() == kTypeArray) {
            nestedArray = true;
            // Walk the chain of nested array types. At the end, dim is the total nesting depth.
            // elemSize is the size of the innermost (non-array) element type, rounded up to the
            // alignment of arrayType's direct element type.
            do {
                innerArrayType = static_cast<MIRArrayType *>(innerType);
                elemSize = RoundUp(innerArrayType->GetElemType()->GetSize(), arrayType->GetElemType()->GetAlign());
                dim++;
                innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(innerArrayType->GetElemTyIdx());
            } while (innerType->GetKind() == kTypeArray);
        }
    }

    // One operand is not an index (the array address); the rest are index expressions.
    CHECK_FATAL(array->NumOpnds() > 0, "must not be zero");
    size_t numIndex = array->NumOpnds() - 1;
    MIRArrayType *curArrayType = arrayType;
    // Initial value; used as the linear index as-is when dim == 1.
    BaseNode *resNode = array->GetIndex(0);
    if (dim > 1) {
        // Linearize the indexes. prevNode accumulates sum(index_i * mpyDim_i), where mpyDim_i
        // is the product of the sizes of all dimensions after dimension i. Only the first
        // min(dim, numIndex) indexes are consumed.
        BaseNode *prevNode = nullptr;
        for (size_t i = 0; (i < dim) && (i < numIndex); ++i) {
            uint32 mpyDim = 1;
            if (nestedArray) {
                // NOTE(review): this checks the outermost dimension on every iteration,
                // not the dimension currently being processed.
                CHECK_FATAL(arrayType->GetSizeArrayItem(0) > 0, "Zero size array dimension");
                // Step one nesting level down, then multiply the sizes of every deeper level.
                innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(curArrayType->GetElemTyIdx());
                curArrayType = static_cast<MIRArrayType *>(innerType);
                while (innerType->GetKind() == kTypeArray) {
                    innerArrayType = static_cast<MIRArrayType *>(innerType);
                    mpyDim *= innerArrayType->GetSizeArrayItem(0);
                    innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(innerArrayType->GetElemTyIdx());
                }
            } else {
                CHECK_FATAL(arrayType->GetSizeArrayItem(static_cast<uint32>(i)) > 0, "Zero size array dimension");
                for (size_t j = i + 1; j < dim; ++j) {
                    mpyDim *= arrayType->GetSizeArrayItem(static_cast<uint32>(j));
                }
            }
            // NOTE(review): mpyDim is a uint32 product of dimension sizes with no overflow check.

            BaseNode *index = static_cast<ConstvalNode *>(array->GetIndex(i));
            bool isConst = false;
            uint64 indexVal = 0;
            if (index->op == OP_constval) {
                // Constant index: its contribution index * mpyDim is folded into a constant.
                // NOTE(review): the constant is cast to MIRIntConst without checking its kind.
                ConstvalNode *constNode = static_cast<ConstvalNode *>(index);
                indexVal = static_cast<uint64>((static_cast<MIRIntConst *>(constNode->GetConstVal()))->GetExtValue());
                isConst = true;
                MIRIntConst *newConstNode = mirModule.GetMemPool()->New<MIRIntConst>(
                    indexVal * mpyDim, *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())));
                BaseNode *newValNode = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(newConstNode);
                newValNode->SetPrimType(array->GetPrimType());
                if (i == 0) {
                    prevNode = newValNode;
                    continue;
                } else {
                    // NOTE(review): for i > 0 this node is not consumed afterwards; the isConst
                    // branch below builds an equivalent constant for mpyNode.
                    resNode = newValNode;
                }
            }
            if (i > 0 && isConst == false) {
                resNode = array->GetIndex(i);
            }

            // mpyNode becomes this index's contribution (index * mpyDim).
            BaseNode *mpyNode;
            if (isConst) {
                MIRIntConst *mulConst = mirModule.GetMemPool()->New<MIRIntConst>(
                    static_cast<int64>(mpyDim) * indexVal,
                    *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())));
                BaseNode *mulSize = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(mulConst);
                mulSize->SetPrimType(array->GetPrimType());
                mpyNode = mulSize;
            } else if (mpyDim == 1 && prevNode) {
                // No multiply needed: swap roles so the add below computes prevNode + index.
                mpyNode = prevNode;
                prevNode = resNode;
            } else {
                // mpyNode = index * mpyDim. The index is converted to the array's prim type
                // (or its signed counterpart for non-integer indexes) when needed.
                mpyNode = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_mul);
                mpyNode->SetPrimType(array->GetPrimType());
                MIRIntConst *mulConst = mirModule.GetMemPool()->New<MIRIntConst>(
                    mpyDim, *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())));
                BaseNode *mulSize = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(mulConst);
                mulSize->SetPrimType(array->GetPrimType());
                mpyNode->SetOpnd(mulSize, 1);
                PrimType signedInt4AddressCompute = GetSignedPrimType(array->GetPrimType());
                DEBUG_ASSERT(resNode != nullptr, "resNode should not be nullptr");
                if (!IsPrimitiveInteger(resNode->GetPrimType())) {
                    resNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, signedInt4AddressCompute,
                                                                               resNode->GetPrimType(), resNode);
                } else if (GetPrimTypeSize(resNode->GetPrimType()) != GetPrimTypeSize(array->GetPrimType())) {
                    resNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(
                        OP_cvt, array->GetPrimType(), GetRegPrimType(resNode->GetPrimType()), resNode);
                }
                mpyNode->SetOpnd(resNode, 0);
            }
            if (i == 0) {
                prevNode = mpyNode;
                continue;
            }
            // Accumulate: prevNode = mpyNode + prevNode. prevNode is converted to the array's
            // prim type when NeedCvtOrRetype says so.
            BaseNode *newResNode = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add);
            newResNode->SetPrimType(array->GetPrimType());
            newResNode->SetOpnd(mpyNode, 0);
            if (prevNode != nullptr && NeedCvtOrRetype(prevNode->GetPrimType(), array->GetPrimType())) {
                prevNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(
                    OP_cvt, array->GetPrimType(), GetRegPrimType(prevNode->GetPrimType()), prevNode);
            }
            newResNode->SetOpnd(prevNode, 1);
            prevNode = newResNode;
        }
        resNode = prevNode;
    }

    BaseNode *rMul = nullptr;
    // esize is the size of the array element (eg. int = 4 long = 8)
    uint64 esize;
    if (nestedArray) {
        esize = elemSize;
    } else {
        esize = arrayType->GetElemType()->GetSize();
    }
    Opcode opadd = OP_add;
    MIRIntConst *econst = mirModule.GetMemPool()->New<MIRIntConst>(
        esize, *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())));
    BaseNode *eSize = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(econst);
    eSize->SetPrimType(array->GetPrimType());
    rMul = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_mul);
    // Convert the linear index to the array's address type before scaling it by esize.
    PrimType signedInt4AddressCompute = GetSignedPrimType(array->GetPrimType());
    DEBUG_ASSERT(resNode != nullptr, "nullptr check");
    if (!IsPrimitiveInteger(resNode->GetPrimType())) {
        resNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, signedInt4AddressCompute,
                                                                   resNode->GetPrimType(), resNode);
    } else if (GetPrimTypeSize(resNode->GetPrimType()) != GetPrimTypeSize(array->GetPrimType())) {
        resNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, array->GetPrimType(),
                                                                   GetRegPrimType(resNode->GetPrimType()), resNode);
    }
    // address = base + linearIndex * esize
    rMul->SetPrimType(resNode->GetPrimType());
    rMul->SetOpnd(resNode, 0);
    rMul->SetOpnd(eSize, 1);
    BaseNode *baseNode = array->GetBase();
    BaseNode *rAdd = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(opadd);
    rAdd->SetPrimType(array->GetPrimType());
    rAdd->SetOpnd(baseNode, 0);
    rAdd->SetOpnd(rMul, 1);
    auto *newAdd = ConstantFold(mirModule).Fold(rAdd);
    rAdd = (newAdd != nullptr ? newAdd : rAdd);
    return rAdd;
}
923
ExpandArrayMrtIfBlock(IfStmtNode & node)924 IfStmtNode *MIRLower::ExpandArrayMrtIfBlock(IfStmtNode &node)
925 {
926 if (node.GetThenPart() != nullptr) {
927 node.SetThenPart(ExpandArrayMrtBlock(*node.GetThenPart()));
928 }
929 if (node.GetElsePart() != nullptr) {
930 node.SetElsePart(ExpandArrayMrtBlock(*node.GetElsePart()));
931 }
932 return &node;
933 }
934
ExpandArrayMrtWhileBlock(WhileStmtNode & node)935 WhileStmtNode *MIRLower::ExpandArrayMrtWhileBlock(WhileStmtNode &node)
936 {
937 if (node.GetBody() != nullptr) {
938 node.SetBody(ExpandArrayMrtBlock(*node.GetBody()));
939 }
940 return &node;
941 }
942
ExpandArrayMrtDoloopBlock(DoloopNode & node)943 DoloopNode *MIRLower::ExpandArrayMrtDoloopBlock(DoloopNode &node)
944 {
945 if (node.GetDoBody() != nullptr) {
946 node.SetDoBody(ExpandArrayMrtBlock(*node.GetDoBody()));
947 }
948 return &node;
949 }
950
ExpandArrayMrtForeachelemBlock(ForeachelemNode & node)951 ForeachelemNode *MIRLower::ExpandArrayMrtForeachelemBlock(ForeachelemNode &node)
952 {
953 if (node.GetLoopBody() != nullptr) {
954 node.SetLoopBody(ExpandArrayMrtBlock(*node.GetLoopBody()));
955 }
956 return &node;
957 }
958
AddArrayMrtMpl(BaseNode & exp,BlockNode & newBlock)959 void MIRLower::AddArrayMrtMpl(BaseNode &exp, BlockNode &newBlock)
960 {
961 MIRModule &mod = mirModule;
962 MIRBuilder *builder = mod.GetMIRBuilder();
963 for (size_t i = 0; i < exp.NumOpnds(); ++i) {
964 DEBUG_ASSERT(exp.Opnd(i) != nullptr, "nullptr check");
965 AddArrayMrtMpl(*exp.Opnd(i), newBlock);
966 }
967 if (exp.GetOpCode() == OP_array) {
968 auto &arrayNode = static_cast<ArrayNode &>(exp);
969 if (arrayNode.GetBoundsCheck()) {
970 BaseNode *arrAddr = arrayNode.Opnd(0);
971 UnaryStmtNode *nullCheck = builder->CreateStmtUnary(OP_assertnonnull, arrAddr);
972 newBlock.AddStatement(nullCheck);
973 }
974 }
975 }
976
ExpandArrayMrtBlock(BlockNode & block)977 BlockNode *MIRLower::ExpandArrayMrtBlock(BlockNode &block)
978 {
979 auto *newBlock = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
980 if (block.GetFirst() == nullptr) {
981 return newBlock;
982 }
983 StmtNode *nextStmt = block.GetFirst();
984 do {
985 StmtNode *stmt = nextStmt;
986 DEBUG_ASSERT(stmt != nullptr, "nullptr check");
987 nextStmt = stmt->GetNext();
988 switch (stmt->GetOpCode()) {
989 case OP_if:
990 newBlock->AddStatement(ExpandArrayMrtIfBlock(static_cast<IfStmtNode &>(*stmt)));
991 break;
992 case OP_while:
993 newBlock->AddStatement(ExpandArrayMrtWhileBlock(static_cast<WhileStmtNode &>(*stmt)));
994 break;
995 case OP_dowhile:
996 newBlock->AddStatement(ExpandArrayMrtWhileBlock(static_cast<WhileStmtNode &>(*stmt)));
997 break;
998 case OP_doloop:
999 newBlock->AddStatement(ExpandArrayMrtDoloopBlock(static_cast<DoloopNode &>(*stmt)));
1000 break;
1001 case OP_foreachelem:
1002 newBlock->AddStatement(ExpandArrayMrtForeachelemBlock(static_cast<ForeachelemNode &>(*stmt)));
1003 break;
1004 case OP_block:
1005 newBlock->AddStatement(ExpandArrayMrtBlock(static_cast<BlockNode &>(*stmt)));
1006 break;
1007 default:
1008 AddArrayMrtMpl(*stmt, *newBlock);
1009 newBlock->AddStatement(stmt);
1010 break;
1011 }
1012 } while (nextStmt != nullptr);
1013 return newBlock;
1014 }
1015
ExpandArrayMrt(MIRFunction & func)1016 void MIRLower::ExpandArrayMrt(MIRFunction &func)
1017 {
1018 if (ShouldOptArrayMrt(func)) {
1019 BlockNode *origBody = func.GetBody();
1020 DEBUG_ASSERT(origBody != nullptr, "nullptr check");
1021 BlockNode *newBody = ExpandArrayMrtBlock(*origBody);
1022 func.SetBody(newBody);
1023 }
1024 }
1025
FuncTypeFromFuncPtrExpr(BaseNode * x)1026 MIRFuncType *MIRLower::FuncTypeFromFuncPtrExpr(BaseNode *x)
1027 {
1028 DEBUG_ASSERT(x != nullptr, "nullptr check");
1029 MIRFuncType *res = nullptr;
1030 MIRFunction *func = mirModule.CurFunction();
1031 switch (x->GetOpCode()) {
1032 case OP_regread: {
1033 RegreadNode *regread = static_cast<RegreadNode *>(x);
1034 MIRPreg *preg = func->GetPregTab()->PregFromPregIdx(regread->GetRegIdx());
1035 // see if it is promoted from a symbol
1036 if (preg->GetOp() == OP_dread) {
1037 const MIRSymbol *symbol = preg->rematInfo.sym;
1038 MIRType *mirType = symbol->GetType();
1039 if (preg->fieldID != 0) {
1040 MIRStructType *structty = static_cast<MIRStructType *>(mirType);
1041 FieldPair thepair = structty->TraverseToField(preg->fieldID);
1042 mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(thepair.second.first);
1043 }
1044
1045 if (mirType->GetKind() == kTypePointer) {
1046 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1047 }
1048 if (res != nullptr) {
1049 break;
1050 }
1051 }
1052 // check if a formal promoted to preg
1053 for (FormalDef &formalDef : func->GetFormalDefVec()) {
1054 if (!formalDef.formalSym->IsPreg()) {
1055 continue;
1056 }
1057 if (formalDef.formalSym->GetPreg() == preg) {
1058 MIRType *mirType = formalDef.formalSym->GetType();
1059 if (mirType->GetKind() == kTypePointer) {
1060 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1061 }
1062 break;
1063 }
1064 }
1065 break;
1066 }
1067 case OP_dread: {
1068 DreadNode *dread = static_cast<DreadNode *>(x);
1069 MIRSymbol *symbol = func->GetLocalOrGlobalSymbol(dread->GetStIdx());
1070 MIRType *mirType = symbol->GetType();
1071 if (dread->GetFieldID() != 0) {
1072 MIRStructType *structty = static_cast<MIRStructType *>(mirType);
1073 FieldPair thepair = structty->TraverseToField(dread->GetFieldID());
1074 mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(thepair.second.first);
1075 }
1076 if (mirType->GetKind() == kTypePointer) {
1077 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1078 }
1079 break;
1080 }
1081 case OP_iread: {
1082 IreadNode *iread = static_cast<IreadNode *>(x);
1083 MIRPtrType *ptrType = static_cast<MIRPtrType *>(iread->GetType());
1084 MIRType *mirType = ptrType->GetPointedType();
1085 if (mirType->GetKind() == kTypeFunction) {
1086 res = static_cast<MIRFuncType *>(mirType);
1087 } else if (mirType->GetKind() == kTypePointer) {
1088 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1089 }
1090 break;
1091 }
1092 case OP_addroffunc: {
1093 AddroffuncNode *addrofFunc = static_cast<AddroffuncNode *>(x);
1094 PUIdx puIdx = addrofFunc->GetPUIdx();
1095 MIRFunction *f = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(puIdx);
1096 res = f->GetMIRFuncType();
1097 break;
1098 }
1099 case OP_retype: {
1100 MIRType *mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(static_cast<RetypeNode *>(x)->GetTyIdx());
1101 if (mirType->GetKind() == kTypePointer) {
1102 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1103 }
1104 if (res == nullptr) {
1105 res = FuncTypeFromFuncPtrExpr(x->Opnd(kNodeFirstOpnd));
1106 }
1107 break;
1108 }
1109 case OP_select: {
1110 res = FuncTypeFromFuncPtrExpr(x->Opnd(kNodeSecondOpnd));
1111 if (res == nullptr) {
1112 res = FuncTypeFromFuncPtrExpr(x->Opnd(kNodeThirdOpnd));
1113 }
1114 break;
1115 }
1116 default:
1117 CHECK_FATAL(false, "LMBCLowerer::FuncTypeFromFuncPtrExpr: NYI");
1118 }
1119 return res;
1120 }
1121
1122 const std::set<std::string> MIRLower::kSetArrayHotFunc = {};
1123
ShouldOptArrayMrt(const MIRFunction & func)1124 bool MIRLower::ShouldOptArrayMrt(const MIRFunction &func)
1125 {
1126 return (MIRLower::kSetArrayHotFunc.find(func.GetName()) != MIRLower::kSetArrayHotFunc.end());
1127 }
1128 } // namespace maple
1129