1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "mir_lower.h"
17 #include "constantfold.h"
18 #include "ext_constantfold.h"
19 #include "me_option.h"
20
21 #define DO_LT_0_CHECK 1
22
23 namespace maple {
24
RoundUpConst(uint64 offset,uint32 align)25 static constexpr uint64 RoundUpConst(uint64 offset, uint32 align)
26 {
27 return (-align) & (offset + align - 1);
28 }
29
RoundUp(uint64 offset,uint32 align)30 static inline uint64 RoundUp(uint64 offset, uint32 align)
31 {
32 if (align == 0) {
33 return offset;
34 }
35 return RoundUpConst(offset, align);
36 }
37
38 // Remove intrinsicop __builtin_expect and record likely info to brStmt
39 // Target condExpr example:
40 // ne u1 i64 (
41 // intrinsicop i64 C___builtin_expect (
42 // cvt i64 i32 (dread i32 %levVar_9354), cvt i64 i32 (constval i32 0)),
43 // constval i64 0)
LowerCondGotoStmtWithBuiltinExpect(CondGotoNode & brStmt)44 void LowerCondGotoStmtWithBuiltinExpect(CondGotoNode &brStmt)
45 {
46 BaseNode *condExpr = brStmt.Opnd(0);
47 // Poke ne for dread shortCircuit
48 // Example:
49 // dassign %shortCircuit 0 (ne u1 i64 (
50 // intrinsicop i64 C___builtin_expect (
51 // cvt i64 i32 (dread i32 %levVar_32349),
52 // cvt i64 i32 (constval i32 0)),
53 // constval i64 0))
54 // dassign %shortCircuit 0 (ne u1 u32 (dread u32 %shortCircuit, constval u1 0))
55 if (condExpr->GetOpCode() == OP_ne && condExpr->Opnd(0)->GetOpCode() == OP_dread &&
56 condExpr->Opnd(1)->GetOpCode() == OP_constval) {
57 auto *constVal = static_cast<ConstvalNode *>(condExpr->Opnd(1))->GetConstVal();
58 if (constVal->GetKind() == kConstInt && static_cast<MIRIntConst *>(constVal)->GetValue() == 0) {
59 condExpr = condExpr->Opnd(0);
60 }
61 }
62 if (condExpr->GetOpCode() == OP_dread) {
63 // Example:
64 // dassign %shortCircuit 0 (ne u1 i64 (
65 // intrinsicop i64 C___builtin_expect (
66 // cvt i64 i32 (dread i32 %levVar_9488),
67 // cvt i64 i32 (constval i32 1)),
68 // constval i64 0))
69 // brfalse @shortCircuit_label_13351 (dread u32 %shortCircuit)
70 StIdx stIdx = static_cast<DreadNode *>(condExpr)->GetStIdx();
71 FieldID fieldId = static_cast<DreadNode *>(condExpr)->GetFieldID();
72 if (fieldId != 0) {
73 return;
74 }
75 if (brStmt.GetPrev() == nullptr || brStmt.GetPrev()->GetOpCode() != OP_dassign) {
76 return; // prev stmt may be a label, we skip it too
77 }
78 auto *dassign = static_cast<DassignNode *>(brStmt.GetPrev());
79 if (stIdx != dassign->GetStIdx() || dassign->GetFieldID() != 0) {
80 return;
81 }
82 condExpr = dassign->GetRHS();
83 }
84 if (condExpr->GetOpCode() == OP_ne) {
85 // opnd1 must be int const 0
86 BaseNode *opnd1 = condExpr->Opnd(1);
87 if (opnd1->GetOpCode() != OP_constval) {
88 return;
89 }
90 auto *constVal = static_cast<ConstvalNode *>(opnd1)->GetConstVal();
91 if (constVal->GetKind() != kConstInt || static_cast<MIRIntConst *>(constVal)->GetValue() != 0) {
92 return;
93 }
94 // opnd0 must be intrinsicop C___builtin_expect
95 BaseNode *opnd0 = condExpr->Opnd(0);
96 if (opnd0->GetOpCode() != OP_intrinsicop ||
97 static_cast<IntrinsicopNode *>(opnd0)->GetIntrinsic() != INTRN_C___builtin_expect) {
98 return;
99 }
100 // We trust constant fold
101 auto *expectedConstExpr = opnd0->Opnd(1);
102 if (expectedConstExpr->GetOpCode() == OP_cvt) {
103 expectedConstExpr = expectedConstExpr->Opnd(0);
104 }
105 if (expectedConstExpr->GetOpCode() != OP_constval) {
106 return;
107 }
108 auto *expectedConstNode = static_cast<ConstvalNode *>(expectedConstExpr)->GetConstVal();
109 CHECK_FATAL(expectedConstNode->GetKind() == kConstInt, "must be");
110 auto expectedVal = static_cast<MIRIntConst *>(expectedConstNode)->GetValue();
111 if (expectedVal != 0 && expectedVal != 1) {
112 return;
113 }
114 bool likelyTrue = (expectedVal == 1); // The condition is likely to be true
115 bool likelyBranch = (brStmt.GetOpCode() == OP_brtrue ? likelyTrue : !likelyTrue); // High probability jump
116 if (likelyBranch) {
117 brStmt.SetBranchProb(kProbLikely);
118 } else {
119 brStmt.SetBranchProb(kProbUnlikely);
120 }
121 // Remove __builtin_expect
122 condExpr->SetOpnd(opnd0->Opnd(0), 0);
123 }
124 }
125
LowerBuiltinExpect(BlockNode & block)126 void MIRLower::LowerBuiltinExpect(BlockNode &block)
127 {
128 auto *stmt = block.GetFirst();
129 auto *last = block.GetLast();
130 while (stmt != last) {
131 if (stmt->GetOpCode() == OP_brtrue || stmt->GetOpCode() == OP_brfalse) {
132 LowerCondGotoStmtWithBuiltinExpect(*static_cast<CondGotoNode *>(stmt));
133 }
134 stmt = stmt->GetNext();
135 }
136 }
137
CreateCondGotoStmt(Opcode op,BlockNode & blk,const IfStmtNode & ifStmt)138 LabelIdx MIRLower::CreateCondGotoStmt(Opcode op, BlockNode &blk, const IfStmtNode &ifStmt)
139 {
140 auto *brStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(op);
141 brStmt->SetOpnd(ifStmt.Opnd(), 0);
142 brStmt->SetSrcPos(ifStmt.GetSrcPos());
143 LabelIdx lableIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
144 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lableIdx);
145 brStmt->SetOffset(lableIdx);
146 blk.AddStatement(brStmt);
147 if (GetFuncProfData()) {
148 GetFuncProfData()->CopyStmtFreq(brStmt->GetStmtID(), ifStmt.GetStmtID());
149 }
150 bool thenEmpty = (ifStmt.GetThenPart() == nullptr) || (ifStmt.GetThenPart()->GetFirst() == nullptr);
151 if (thenEmpty) {
152 blk.AppendStatementsFromBlock(*ifStmt.GetElsePart());
153 } else {
154 blk.AppendStatementsFromBlock(*ifStmt.GetThenPart());
155 }
156 return lableIdx;
157 }
158
CreateBrFalseStmt(BlockNode & blk,const IfStmtNode & ifStmt)159 void MIRLower::CreateBrFalseStmt(BlockNode &blk, const IfStmtNode &ifStmt)
160 {
161 LabelIdx labelIdx = CreateCondGotoStmt(OP_brfalse, blk, ifStmt);
162 auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
163 lableStmt->SetLabelIdx(labelIdx);
164 blk.AddStatement(lableStmt);
165 // set stmtfreqs
166 if (GetFuncProfData()) {
167 DEBUG_ASSERT(GetFuncProfData()->GetStmtFreq(ifStmt.GetThenPart()->GetStmtID()) >= 0, "sanity check");
168 int64_t freq = GetFuncProfData()->GetStmtFreq(ifStmt.GetStmtID()) -
169 GetFuncProfData()->GetStmtFreq(ifStmt.GetThenPart()->GetStmtID());
170 GetFuncProfData()->SetStmtFreq(lableStmt->GetStmtID(), freq);
171 }
172 }
173
CreateBrTrueStmt(BlockNode & blk,const IfStmtNode & ifStmt)174 void MIRLower::CreateBrTrueStmt(BlockNode &blk, const IfStmtNode &ifStmt)
175 {
176 LabelIdx labelIdx = CreateCondGotoStmt(OP_brtrue, blk, ifStmt);
177 auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
178 lableStmt->SetLabelIdx(labelIdx);
179 blk.AddStatement(lableStmt);
180 // set stmtfreqs
181 if (GetFuncProfData()) {
182 DEBUG_ASSERT(GetFuncProfData()->GetStmtFreq(ifStmt.GetElsePart()->GetStmtID()) >= 0, "sanity check");
183 int64_t freq = GetFuncProfData()->GetStmtFreq(ifStmt.GetStmtID()) -
184 GetFuncProfData()->GetStmtFreq(ifStmt.GetElsePart()->GetStmtID());
185 GetFuncProfData()->SetStmtFreq(lableStmt->GetStmtID(), freq);
186 }
187 }
188
CreateBrFalseAndGotoStmt(BlockNode & blk,const IfStmtNode & ifStmt)189 void MIRLower::CreateBrFalseAndGotoStmt(BlockNode &blk, const IfStmtNode &ifStmt)
190 {
191 LabelIdx labelIdx = CreateCondGotoStmt(OP_brfalse, blk, ifStmt);
192 bool fallThroughFromThen = !IfStmtNoFallThrough(ifStmt);
193 LabelIdx gotoLableIdx = 0;
194 if (fallThroughFromThen) {
195 auto *gotoStmt = mirModule.CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
196 gotoLableIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
197 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(gotoLableIdx);
198 gotoStmt->SetOffset(gotoLableIdx);
199 blk.AddStatement(gotoStmt);
200 // set stmtfreqs
201 if (GetFuncProfData()) {
202 GetFuncProfData()->CopyStmtFreq(gotoStmt->GetStmtID(), ifStmt.GetThenPart()->GetStmtID());
203 }
204 }
205 auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
206 lableStmt->SetLabelIdx(labelIdx);
207 blk.AddStatement(lableStmt);
208 blk.AppendStatementsFromBlock(*ifStmt.GetElsePart());
209 // set stmtfreqs
210 if (GetFuncProfData()) {
211 GetFuncProfData()->CopyStmtFreq(lableStmt->GetStmtID(), ifStmt.GetElsePart()->GetStmtID());
212 }
213 if (fallThroughFromThen) {
214 lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
215 lableStmt->SetLabelIdx(gotoLableIdx);
216 blk.AddStatement(lableStmt);
217 // set endlabel stmtfreqs
218 if (GetFuncProfData()) {
219 GetFuncProfData()->CopyStmtFreq(lableStmt->GetStmtID(), ifStmt.GetStmtID());
220 }
221 }
222 }
223
LowerIfStmt(IfStmtNode & ifStmt,bool recursive)224 BlockNode *MIRLower::LowerIfStmt(IfStmtNode &ifStmt, bool recursive)
225 {
226 bool thenEmpty = (ifStmt.GetThenPart() == nullptr) || (ifStmt.GetThenPart()->GetFirst() == nullptr);
227 bool elseEmpty = (ifStmt.GetElsePart() == nullptr) || (ifStmt.GetElsePart()->GetFirst() == nullptr);
228 if (recursive) {
229 if (!thenEmpty) {
230 ifStmt.SetThenPart(LowerBlock(*ifStmt.GetThenPart()));
231 }
232 if (!elseEmpty) {
233 ifStmt.SetElsePart(LowerBlock(*ifStmt.GetElsePart()));
234 }
235 }
236 auto *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
237 if (thenEmpty && elseEmpty) {
238 // generate EVAL <cond> statement
239 auto *evalStmt = mirModule.CurFuncCodeMemPool()->New<UnaryStmtNode>(OP_eval);
240 evalStmt->SetOpnd(ifStmt.Opnd(), 0);
241 evalStmt->SetSrcPos(ifStmt.GetSrcPos());
242 blk->AddStatement(evalStmt);
243 if (GetFuncProfData()) {
244 GetFuncProfData()->CopyStmtFreq(evalStmt->GetStmtID(), ifStmt.GetStmtID());
245 }
246 } else if (elseEmpty) {
247 // brfalse <cond> <endlabel>
248 // <thenPart>
249 // label <endlabel>
250 CreateBrFalseStmt(*blk, ifStmt);
251 } else if (thenEmpty) {
252 // brtrue <cond> <endlabel>
253 // <elsePart>
254 // label <endlabel>
255 CreateBrTrueStmt(*blk, ifStmt);
256 } else {
257 // brfalse <cond> <elselabel>
258 // <thenPart>
259 // goto <endlabel>
260 // label <elselabel>
261 // <elsePart>
262 // label <endlabel>
263 CreateBrFalseAndGotoStmt(*blk, ifStmt);
264 }
265 return blk;
266 }
267
ConsecutiveCaseValsAndSameTarget(const CaseVector * switchTable)268 static bool ConsecutiveCaseValsAndSameTarget(const CaseVector *switchTable)
269 {
270 size_t caseNum = switchTable->size();
271 int lastVal = static_cast<int>((*switchTable)[0].first);
272 LabelIdx lblIdx = (*switchTable)[0].second;
273 for (size_t id = 1; id < caseNum; id++) {
274 lastVal++;
275 if (lastVal != (*switchTable)[id].first) {
276 return false;
277 }
278 if (lblIdx != (*switchTable)[id].second) {
279 return false;
280 }
281 }
282 return true;
283 }
284
285 // if there is only 1 case branch, replace with conditional branch(es) and
286 // return the optimized multiple statements; otherwise, return nullptr
LowerSwitchStmt(SwitchNode * switchNode)287 BlockNode *MIRLower::LowerSwitchStmt(SwitchNode *switchNode)
288 {
289 CaseVector *switchTable = &switchNode->GetSwitchTable();
290 if (switchTable->empty()) { // goto @defaultLabel
291 BlockNode *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
292 LabelIdx defaultLabel = switchNode->GetDefaultLabel();
293 MIRBuilder *builder = mirModule.GetMIRBuilder();
294 GotoNode *gotoStmt = builder->CreateStmtGoto(OP_goto, defaultLabel);
295 blk->AddStatement(gotoStmt);
296 return blk;
297 }
298 if (!ConsecutiveCaseValsAndSameTarget(switchTable)) {
299 return nullptr;
300 }
301 BlockNode *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
302 LabelIdx caseGotoLabel = switchTable->front().second;
303 LabelIdx defaultLabel = switchNode->GetDefaultLabel();
304 int64 minCaseVal = switchTable->front().first;
305 int64 maxCaseVal = switchTable->back().first;
306 BaseNode *switchOpnd = switchNode->Opnd(0);
307 MIRBuilder *builder = mirModule.GetMIRBuilder();
308 ConstvalNode *minCaseNode = builder->CreateIntConst(minCaseVal, switchOpnd->GetPrimType());
309 ConstvalNode *maxCaseNode = builder->CreateIntConst(maxCaseVal, switchOpnd->GetPrimType());
310 if (minCaseVal == maxCaseVal) {
311 // brtrue (x == minCaseVal) @case_goto_label
312 // goto @default_label
313 CompareNode *eqNode = builder->CreateExprCompare(
314 OP_eq, *GlobalTables::GetTypeTable().GetInt32(),
315 *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(switchOpnd->GetPrimType())), switchOpnd, minCaseNode);
316 CondGotoNode *condGoto = builder->CreateStmtCondGoto(eqNode, OP_brtrue, caseGotoLabel);
317 blk->AddStatement(condGoto);
318 GotoNode *gotoStmt = builder->CreateStmtGoto(OP_goto, defaultLabel);
319 blk->AddStatement(gotoStmt);
320 } else {
321 // brtrue (x < minCaseVal) @default_label
322 // brtrue (x > maxCaseVal) @default_label
323 // goto @case_goto_label
324 CompareNode *ltNode = builder->CreateExprCompare(
325 OP_lt, *GlobalTables::GetTypeTable().GetInt32(),
326 *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(switchOpnd->GetPrimType())), switchOpnd, minCaseNode);
327 CondGotoNode *condGoto = builder->CreateStmtCondGoto(ltNode, OP_brtrue, defaultLabel);
328 blk->AddStatement(condGoto);
329 CompareNode *gtNode = builder->CreateExprCompare(
330 OP_gt, *GlobalTables::GetTypeTable().GetInt32(),
331 *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(switchOpnd->GetPrimType())), switchOpnd, maxCaseNode);
332 condGoto = builder->CreateStmtCondGoto(gtNode, OP_brtrue, defaultLabel);
333 blk->AddStatement(condGoto);
334 GotoNode *gotoStmt = builder->CreateStmtGoto(OP_goto, caseGotoLabel);
335 blk->AddStatement(gotoStmt);
336 }
337 return blk;
338 }
339
340 // while <cond> <body>
341 // is lowered to:
342 // brfalse <cond> <endlabel>
343 // label <bodylabel>
344 // <body>
345 // brtrue <cond> <bodylabel>
346 // label <endlabel>
LowerWhileStmt(WhileStmtNode & whileStmt)347 BlockNode *MIRLower::LowerWhileStmt(WhileStmtNode &whileStmt)
348 {
349 DEBUG_ASSERT(whileStmt.GetBody() != nullptr, "nullptr check");
350 whileStmt.SetBody(LowerBlock(*whileStmt.GetBody()));
351 auto *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
352 auto *brFalseStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brfalse);
353 brFalseStmt->SetOpnd(whileStmt.Opnd(0), 0);
354 brFalseStmt->SetSrcPos(whileStmt.GetSrcPos());
355 LabelIdx lalbeIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
356 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lalbeIdx);
357 brFalseStmt->SetOffset(lalbeIdx);
358 blk->AddStatement(brFalseStmt);
359 blk->AppendStatementsFromBlock(*whileStmt.GetBody());
360 if (MeOption::optForSize) {
361 // still keep while-do format to avoid coping too much condition-related stmt
362 LabelIdx whileLalbeIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
363 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(whileLalbeIdx);
364 auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
365 lableStmt->SetLabelIdx(whileLalbeIdx);
366 blk->InsertBefore(brFalseStmt, lableStmt);
367 auto *whilegotonode = mirModule.CurFuncCodeMemPool()->New<GotoNode>(OP_goto, whileLalbeIdx);
368 if (GetFuncProfData() && blk->GetLast()) {
369 GetFuncProfData()->CopyStmtFreq(whilegotonode->GetStmtID(), blk->GetLast()->GetStmtID());
370 }
371 blk->AddStatement(whilegotonode);
372 } else {
373 LabelIdx bodyLableIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
374 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(bodyLableIdx);
375 auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
376 lableStmt->SetLabelIdx(bodyLableIdx);
377 blk->InsertAfter(brFalseStmt, lableStmt);
378 // update frequency
379 if (GetFuncProfData()) {
380 GetFuncProfData()->CopyStmtFreq(lableStmt->GetStmtID(), whileStmt.GetStmtID());
381 GetFuncProfData()->CopyStmtFreq(brFalseStmt->GetStmtID(), whileStmt.GetStmtID());
382 }
383 auto *brTrueStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brtrue);
384 brTrueStmt->SetOpnd(whileStmt.Opnd(0)->CloneTree(mirModule.GetCurFuncCodeMPAllocator()), 0);
385 brTrueStmt->SetOffset(bodyLableIdx);
386 if (GetFuncProfData() && blk->GetLast()) {
387 GetFuncProfData()->CopyStmtFreq(brTrueStmt->GetStmtID(), whileStmt.GetBody()->GetStmtID());
388 }
389 blk->AddStatement(brTrueStmt);
390 }
391 auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
392 lableStmt->SetLabelIdx(lalbeIdx);
393 blk->AddStatement(lableStmt);
394 if (GetFuncProfData()) {
395 int64_t freq = GetFuncProfData()->GetStmtFreq(whileStmt.GetStmtID()) -
396 GetFuncProfData()->GetStmtFreq(blk->GetLast()->GetStmtID());
397 DEBUG_ASSERT(freq >= 0, "sanity check");
398 GetFuncProfData()->SetStmtFreq(lableStmt->GetStmtID(), freq);
399 }
400 return blk;
401 }
402
403 // doloop <do-var>(<start-expr>,<cont-expr>,<incr-amt>) {<body-stmts>}
404 // is lowered to:
405 // dassign <do-var> (<start-expr>)
406 // brfalse <cond-expr> <endlabel>
407 // label <bodylabel>
408 // <body-stmts>
409 // dassign <do-var> (<incr-amt>)
410 // brtrue <cond-expr> <bodylabel>
411 // label <endlabel>
LowerDoloopStmt(DoloopNode & doloop)412 BlockNode *MIRLower::LowerDoloopStmt(DoloopNode &doloop)
413 {
414 DEBUG_ASSERT(doloop.GetDoBody() != nullptr, "nullptr check");
415 doloop.SetDoBody(LowerBlock(*doloop.GetDoBody()));
416 int64_t doloopnodeFreq = 0, bodynodeFreq = 0;
417 if (GetFuncProfData()) {
418 doloopnodeFreq = GetFuncProfData()->GetStmtFreq(doloop.GetStmtID());
419 bodynodeFreq = GetFuncProfData()->GetStmtFreq(doloop.GetDoBody()->GetStmtID());
420 }
421 auto *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
422 if (doloop.IsPreg()) {
423 PregIdx regIdx = static_cast<PregIdx>(doloop.GetDoVarStIdx().FullIdx());
424 MIRPreg *mirPreg = mirModule.CurFunction()->GetPregTab()->PregFromPregIdx(regIdx);
425 PrimType primType = mirPreg->GetPrimType();
426 DEBUG_ASSERT(primType != kPtyInvalid, "runtime check error");
427 auto *startRegassign = mirModule.CurFuncCodeMemPool()->New<RegassignNode>();
428 startRegassign->SetRegIdx(regIdx);
429 startRegassign->SetPrimType(primType);
430 startRegassign->SetOpnd(doloop.GetStartExpr(), 0);
431 startRegassign->SetSrcPos(doloop.GetSrcPos());
432 blk->AddStatement(startRegassign);
433 } else {
434 auto *startDassign = mirModule.CurFuncCodeMemPool()->New<DassignNode>();
435 startDassign->SetStIdx(doloop.GetDoVarStIdx());
436 startDassign->SetRHS(doloop.GetStartExpr());
437 startDassign->SetSrcPos(doloop.GetSrcPos());
438 blk->AddStatement(startDassign);
439 }
440 if (GetFuncProfData()) {
441 GetFuncProfData()->SetStmtFreq(blk->GetLast()->GetStmtID(), doloopnodeFreq - bodynodeFreq);
442 }
443 auto *brFalseStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brfalse);
444 brFalseStmt->SetOpnd(doloop.GetCondExpr(), 0);
445 LabelIdx lIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
446 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lIdx);
447 brFalseStmt->SetOffset(lIdx);
448 blk->AddStatement(brFalseStmt);
449 // udpate stmtFreq
450 if (GetFuncProfData()) {
451 GetFuncProfData()->SetStmtFreq(brFalseStmt->GetStmtID(), (doloopnodeFreq - bodynodeFreq));
452 }
453 LabelIdx bodyLabelIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
454 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(bodyLabelIdx);
455 auto *labelStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
456 labelStmt->SetLabelIdx(bodyLabelIdx);
457 blk->AddStatement(labelStmt);
458 // udpate stmtFreq
459 if (GetFuncProfData()) {
460 GetFuncProfData()->SetStmtFreq(labelStmt->GetStmtID(), bodynodeFreq);
461 }
462 blk->AppendStatementsFromBlock(*doloop.GetDoBody());
463 if (doloop.IsPreg()) {
464 PregIdx regIdx = (PregIdx)doloop.GetDoVarStIdx().FullIdx();
465 MIRPreg *mirPreg = mirModule.CurFunction()->GetPregTab()->PregFromPregIdx(regIdx);
466 PrimType doVarPType = mirPreg->GetPrimType();
467 DEBUG_ASSERT(doVarPType != kPtyInvalid, "runtime check error");
468 auto *readDoVar = mirModule.CurFuncCodeMemPool()->New<RegreadNode>();
469 readDoVar->SetRegIdx(regIdx);
470 readDoVar->SetPrimType(doVarPType);
471 auto *add =
472 mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add, doVarPType, doloop.GetIncrExpr(), readDoVar);
473 auto *endRegassign = mirModule.CurFuncCodeMemPool()->New<RegassignNode>();
474 endRegassign->SetRegIdx(regIdx);
475 endRegassign->SetPrimType(doVarPType);
476 endRegassign->SetOpnd(add, 0);
477 blk->AddStatement(endRegassign);
478 } else {
479 const MIRSymbol *doVarSym = mirModule.CurFunction()->GetLocalOrGlobalSymbol(doloop.GetDoVarStIdx());
480 PrimType doVarPType = doVarSym->GetType()->GetPrimType();
481 auto *readDovar =
482 mirModule.CurFuncCodeMemPool()->New<DreadNode>(OP_dread, doVarPType, doloop.GetDoVarStIdx(), 0);
483 auto *add =
484 mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add, doVarPType, readDovar, doloop.GetIncrExpr());
485 auto *endDassign = mirModule.CurFuncCodeMemPool()->New<DassignNode>();
486 endDassign->SetStIdx(doloop.GetDoVarStIdx());
487 endDassign->SetRHS(add);
488 blk->AddStatement(endDassign);
489 }
490 auto *brTrueStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brtrue);
491 brTrueStmt->SetOpnd(doloop.GetCondExpr()->CloneTree(mirModule.GetCurFuncCodeMPAllocator()), 0);
492 brTrueStmt->SetOffset(bodyLabelIdx);
493 blk->AddStatement(brTrueStmt);
494 // udpate stmtFreq
495 if (GetFuncProfData()) {
496 GetFuncProfData()->SetStmtFreq(brTrueStmt->GetStmtID(), bodynodeFreq);
497 }
498 labelStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
499 labelStmt->SetLabelIdx(lIdx);
500 blk->AddStatement(labelStmt);
501 // udpate stmtFreq
502 if (GetFuncProfData()) {
503 GetFuncProfData()->SetStmtFreq(labelStmt->GetStmtID(), (doloopnodeFreq - bodynodeFreq));
504 }
505 return blk;
506 }
507
508 // dowhile <body> <cond>
509 // is lowered to:
510 // label <bodylabel>
511 // <body>
512 // brtrue <cond> <bodylabel>
LowerDowhileStmt(WhileStmtNode & doWhileStmt)513 BlockNode *MIRLower::LowerDowhileStmt(WhileStmtNode &doWhileStmt)
514 {
515 DEBUG_ASSERT(doWhileStmt.GetBody() != nullptr, "nullptr check");
516 doWhileStmt.SetBody(LowerBlock(*doWhileStmt.GetBody()));
517 auto *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
518 LabelIdx lIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
519 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lIdx);
520 auto *labelStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
521 labelStmt->SetLabelIdx(lIdx);
522 blk->AddStatement(labelStmt);
523 blk->AppendStatementsFromBlock(*doWhileStmt.GetBody());
524 auto *brTrueStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brtrue);
525 brTrueStmt->SetOpnd(doWhileStmt.Opnd(0), 0);
526 brTrueStmt->SetOffset(lIdx);
527 blk->AddStatement(brTrueStmt);
528 return blk;
529 }
530
LowerBlock(BlockNode & block)531 BlockNode *MIRLower::LowerBlock(BlockNode &block)
532 {
533 auto *newBlock = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
534 BlockNode *tmp = nullptr;
535 if (block.GetFirst() == nullptr) {
536 newBlock->SetStmtID(block.GetStmtID()); // keep original block stmtid
537 return newBlock;
538 }
539 StmtNode *nextStmt = block.GetFirst();
540 DEBUG_ASSERT(nextStmt != nullptr, "nullptr check");
541 do {
542 StmtNode *stmt = nextStmt;
543 nextStmt = stmt->GetNext();
544 switch (stmt->GetOpCode()) {
545 case OP_if:
546 tmp = LowerIfStmt(static_cast<IfStmtNode &>(*stmt), true);
547 newBlock->AppendStatementsFromBlock(*tmp);
548 break;
549 case OP_switch:
550 tmp = LowerSwitchStmt(static_cast<SwitchNode *>(stmt));
551 if (tmp != nullptr) {
552 newBlock->AppendStatementsFromBlock(*tmp);
553 } else {
554 newBlock->AddStatement(stmt);
555 }
556 break;
557 case OP_while:
558 newBlock->AppendStatementsFromBlock(*LowerWhileStmt(static_cast<WhileStmtNode &>(*stmt)));
559 break;
560 case OP_dowhile:
561 newBlock->AppendStatementsFromBlock(*LowerDowhileStmt(static_cast<WhileStmtNode &>(*stmt)));
562 break;
563 case OP_doloop:
564 newBlock->AppendStatementsFromBlock(*LowerDoloopStmt(static_cast<DoloopNode &>(*stmt)));
565 break;
566 case OP_icallassigned:
567 case OP_icall: {
568 if (mirModule.IsCModule()) {
569 // convert to icallproto/icallprotoassigned
570 IcallNode *ic = static_cast<IcallNode *>(stmt);
571 ic->SetOpCode(stmt->GetOpCode() == OP_icall ? OP_icallproto : OP_icallprotoassigned);
572 MIRFuncType *funcType = FuncTypeFromFuncPtrExpr(stmt->Opnd(0));
573 CHECK_FATAL(funcType != nullptr, "MIRLower::LowerBlock: cannot find prototype for icall");
574 ic->SetRetTyIdx(funcType->GetTypeIndex());
575 MIRType *retType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(funcType->GetRetTyIdx());
576 if (retType->GetPrimType() == PTY_agg && retType->GetSize() > k16BitSize) {
577 funcType->funcAttrs.SetAttr(FUNCATTR_firstarg_return);
578 }
579 }
580 newBlock->AddStatement(stmt);
581 break;
582 }
583 case OP_block:
584 tmp = LowerBlock(static_cast<BlockNode &>(*stmt));
585 newBlock->AppendStatementsFromBlock(*tmp);
586 break;
587 default:
588 newBlock->AddStatement(stmt);
589 break;
590 }
591 } while (nextStmt != nullptr);
592 newBlock->SetStmtID(block.GetStmtID()); // keep original block stmtid
593 return newBlock;
594 }
595
596 // for lowering OP_cand and OP_cior embedded in the expression x which belongs
597 // to curstmt
LowerEmbeddedCandCior(BaseNode * x,StmtNode * curstmt,BlockNode * blk)598 BaseNode *MIRLower::LowerEmbeddedCandCior(BaseNode *x, StmtNode *curstmt, BlockNode *blk)
599 {
600 if (x->GetOpCode() == OP_cand || x->GetOpCode() == OP_cior) {
601 MIRBuilder *builder = mirModule.GetMIRBuilder();
602 BinaryNode *bnode = static_cast<BinaryNode *>(x);
603 bnode->SetOpnd(LowerEmbeddedCandCior(bnode->Opnd(0), curstmt, blk), 0);
604 PregIdx pregIdx = mirFunc->GetPregTab()->CreatePreg(x->GetPrimType());
605 RegassignNode *regass = builder->CreateStmtRegassign(x->GetPrimType(), pregIdx, bnode->Opnd(0));
606 blk->InsertBefore(curstmt, regass);
607 LabelIdx labIdx = mirFunc->GetLabelTab()->CreateLabel();
608 mirFunc->GetLabelTab()->AddToStringLabelMap(labIdx);
609 BaseNode *cond = builder->CreateExprRegread(x->GetPrimType(), pregIdx);
610 CondGotoNode *cgoto =
611 mirFunc->GetCodeMempool()->New<CondGotoNode>(x->GetOpCode() == OP_cior ? OP_brtrue : OP_brfalse);
612 cgoto->SetOpnd(cond, 0);
613 cgoto->SetOffset(labIdx);
614 blk->InsertBefore(curstmt, cgoto);
615
616 bnode->SetOpnd(LowerEmbeddedCandCior(bnode->Opnd(1), curstmt, blk), 1);
617 regass = builder->CreateStmtRegassign(x->GetPrimType(), pregIdx, bnode->Opnd(1));
618 blk->InsertBefore(curstmt, regass);
619 LabelNode *lbl = mirFunc->GetCodeMempool()->New<LabelNode>();
620 lbl->SetLabelIdx(labIdx);
621 blk->InsertBefore(curstmt, lbl);
622 return builder->CreateExprRegread(x->GetPrimType(), pregIdx);
623 } else {
624 for (size_t i = 0; i < x->GetNumOpnds(); i++) {
625 x->SetOpnd(LowerEmbeddedCandCior(x->Opnd(i), curstmt, blk), i);
626 }
627 return x;
628 }
629 }
630
631 // for lowering all appearances of OP_cand and OP_cior associated with condional
632 // branches in the block
LowerCandCior(BlockNode & block)633 void MIRLower::LowerCandCior(BlockNode &block)
634 {
635 if (block.GetFirst() == nullptr) {
636 return;
637 }
638 StmtNode *nextStmt = block.GetFirst();
639 do {
640 StmtNode *stmt = nextStmt;
641 nextStmt = stmt->GetNext();
642 if (stmt->IsCondBr() && (stmt->Opnd(0)->GetOpCode() == OP_cand || stmt->Opnd(0)->GetOpCode() == OP_cior)) {
643 CondGotoNode *condGoto = static_cast<CondGotoNode *>(stmt);
644 BinaryNode *cond = static_cast<BinaryNode *>(condGoto->Opnd(0));
645 if ((stmt->GetOpCode() == OP_brfalse && cond->GetOpCode() == OP_cand) ||
646 (stmt->GetOpCode() == OP_brtrue && cond->GetOpCode() == OP_cior)) {
647 // short-circuit target label is same as original condGoto stmt
648 condGoto->SetOpnd(cond->GetBOpnd(0), 0);
649 auto *newCondGoto = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(Opcode(stmt->GetOpCode()));
650 newCondGoto->SetOpnd(cond->GetBOpnd(1), 0);
651 newCondGoto->SetOffset(condGoto->GetOffset());
652 block.InsertAfter(condGoto, newCondGoto);
653 nextStmt = stmt; // so it will be re-processed if another cand/cior
654 } else { // short-circuit target is next statement
655 LabelIdx lIdx;
656 LabelNode *labelStmt = nullptr;
657 if (nextStmt->GetOpCode() == OP_label) {
658 labelStmt = static_cast<LabelNode *>(nextStmt);
659 lIdx = labelStmt->GetLabelIdx();
660 } else {
661 lIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
662 mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lIdx);
663 labelStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
664 labelStmt->SetLabelIdx(lIdx);
665 block.InsertAfter(condGoto, labelStmt);
666 }
667 auto *newCondGoto = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(
668 stmt->GetOpCode() == OP_brfalse ? OP_brtrue : OP_brfalse);
669 newCondGoto->SetOpnd(cond->GetBOpnd(0), 0);
670 newCondGoto->SetOffset(lIdx);
671 block.InsertBefore(condGoto, newCondGoto);
672 condGoto->SetOpnd(cond->GetBOpnd(1), 0);
673 nextStmt = newCondGoto; // so it will be re-processed if another cand/cior
674 }
675 } else { // call LowerEmbeddedCandCior() for all the expression operands
676 for (size_t i = 0; i < stmt->GetNumOpnds(); i++) {
677 stmt->SetOpnd(LowerEmbeddedCandCior(stmt->Opnd(i), stmt, &block), i);
678 }
679 }
680 } while (nextStmt != nullptr);
681 }
682
LowerFunc(MIRFunction & func)683 void MIRLower::LowerFunc(MIRFunction &func)
684 {
685 if (GetOptLevel() > 0) {
686 ExtConstantFold ecf(func.GetModule());
687 (void)ecf.ExtSimplify(func.GetBody());
688 ;
689 }
690
691 mirModule.SetCurFunction(&func);
692 if (IsLowerExpandArray()) {
693 ExpandArrayMrt(func);
694 }
695 BlockNode *origBody = func.GetBody();
696 DEBUG_ASSERT(origBody != nullptr, "nullptr check");
697 BlockNode *newBody = LowerBlock(*origBody);
698 DEBUG_ASSERT(newBody != nullptr, "nullptr check");
699 LowerBuiltinExpect(*newBody);
700 if (!InLFO()) {
701 LowerCandCior(*newBody);
702 }
703 func.SetBody(newBody);
704 }
705
LowerFarray(ArrayNode * array)706 BaseNode *MIRLower::LowerFarray(ArrayNode *array)
707 {
708 auto *farrayType = static_cast<MIRFarrayType *>(array->GetArrayType(GlobalTables::GetTypeTable()));
709 size_t eSize = GlobalTables::GetTypeTable().GetTypeFromTyIdx(farrayType->GetElemTyIdx())->GetSize();
710 MIRType &arrayType = *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType()));
711 /* how about multi-dimension array? */
712 if (array->GetIndex(0)->GetOpCode() == OP_constval) {
713 const ConstvalNode *constvalNode = static_cast<const ConstvalNode *>(array->GetIndex(0));
714 if (constvalNode->GetConstVal()->GetKind() == kConstInt) {
715 const MIRIntConst *pIntConst = static_cast<const MIRIntConst *>(constvalNode->GetConstVal());
716 CHECK_FATAL(mirModule.IsJavaModule() || !pIntConst->IsNegative(), "Array index should >= 0.");
717 int64 eleOffset = pIntConst->GetExtValue() * eSize;
718
719 BaseNode *baseNode = array->GetBase();
720 if (eleOffset == 0) {
721 return baseNode;
722 }
723
724 MIRIntConst *eleConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(eleOffset, arrayType);
725 BaseNode *offsetNode = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(eleConst);
726 offsetNode->SetPrimType(array->GetPrimType());
727
728 BaseNode *rAdd = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add);
729 rAdd->SetPrimType(array->GetPrimType());
730 rAdd->SetOpnd(baseNode, 0);
731 rAdd->SetOpnd(offsetNode, 1);
732 return rAdd;
733 }
734 }
735
736 BaseNode *rMul = nullptr;
737
738 BaseNode *baseNode = array->GetBase();
739
740 BaseNode *rAdd = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add);
741 rAdd->SetPrimType(array->GetPrimType());
742 rAdd->SetOpnd(baseNode, 0);
743 rAdd->SetOpnd(rMul, 1);
744 auto *newAdd = ConstantFold(mirModule).Fold(rAdd);
745 rAdd = (newAdd != nullptr ? newAdd : rAdd);
746 return rAdd;
747 }
748
LowerCArray(ArrayNode * array)749 BaseNode *MIRLower::LowerCArray(ArrayNode *array)
750 {
751 MIRType *aType = array->GetArrayType(GlobalTables::GetTypeTable());
752 if (aType->GetKind() == kTypeJArray) {
753 return array;
754 }
755 if (aType->GetKind() == kTypeFArray) {
756 return LowerFarray(array);
757 }
758
759 MIRArrayType *arrayType = static_cast<MIRArrayType *>(aType);
760 /* There are two cases where dimension > 1.
761 * 1) arrayType->dim > 1. Process the current arrayType. (nestedArray = false)
762 * 2) arrayType->dim == 1, but arraytype->eTyIdx is another array. (nestedArray = true)
763 * Assume at this time 1) and 2) cannot mix.
764 * Along with the array dimension, there is the array indexing.
765 * It is allowed to index arrays less than the dimension.
766 * This is dictated by the number of indexes.
767 */
768 bool nestedArray = false;
769 uint64 dim = arrayType->GetDim();
770 MIRType *innerType = nullptr;
771 MIRArrayType *innerArrayType = nullptr;
772 uint64 elemSize = 0;
773 if (dim == 1) {
774 innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(arrayType->GetElemTyIdx());
775 if (innerType->GetKind() == kTypeArray) {
776 nestedArray = true;
777 do {
778 innerArrayType = static_cast<MIRArrayType *>(innerType);
779 elemSize = RoundUp(innerArrayType->GetElemType()->GetSize(), arrayType->GetElemType()->GetAlign());
780 dim++;
781 innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(innerArrayType->GetElemTyIdx());
782 } while (innerType->GetKind() == kTypeArray);
783 }
784 }
785
786 size_t numIndex = array->NumOpnds() - 1;
787 MIRArrayType *curArrayType = arrayType;
788 BaseNode *resNode = array->GetIndex(0);
789 if (dim > 1) {
790 BaseNode *prevNode = nullptr;
791 for (size_t i = 0; (i < dim) && (i < numIndex); ++i) {
792 uint32 mpyDim = 1;
793 if (nestedArray) {
794 CHECK_FATAL(arrayType->GetSizeArrayItem(0) > 0, "Zero size array dimension");
795 innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(curArrayType->GetElemTyIdx());
796 curArrayType = static_cast<MIRArrayType *>(innerType);
797 while (innerType->GetKind() == kTypeArray) {
798 innerArrayType = static_cast<MIRArrayType *>(innerType);
799 mpyDim *= innerArrayType->GetSizeArrayItem(0);
800 innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(innerArrayType->GetElemTyIdx());
801 }
802 } else {
803 CHECK_FATAL(arrayType->GetSizeArrayItem(static_cast<uint32>(i)) > 0, "Zero size array dimension");
804 for (size_t j = i + 1; j < dim; ++j) {
805 mpyDim *= arrayType->GetSizeArrayItem(static_cast<uint32>(j));
806 }
807 }
808
809 BaseNode *index = static_cast<ConstvalNode *>(array->GetIndex(i));
810 bool isConst = false;
811 uint64 indexVal = 0;
812 if (index->op == OP_constval) {
813 ConstvalNode *constNode = static_cast<ConstvalNode *>(index);
814 indexVal = (static_cast<MIRIntConst *>(constNode->GetConstVal()))->GetExtValue();
815 isConst = true;
816 MIRIntConst *newConstNode = mirModule.GetMemPool()->New<MIRIntConst>(
817 indexVal * mpyDim, *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())));
818 BaseNode *newValNode = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(newConstNode);
819 newValNode->SetPrimType(array->GetPrimType());
820 if (i == 0) {
821 prevNode = newValNode;
822 continue;
823 } else {
824 resNode = newValNode;
825 }
826 }
827 if (i > 0 && isConst == false) {
828 resNode = array->GetIndex(i);
829 }
830
831 BaseNode *mpyNode;
832 if (isConst) {
833 MIRIntConst *mulConst = mirModule.GetMemPool()->New<MIRIntConst>(
834 static_cast<int64>(mpyDim) * indexVal,
835 *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())));
836 BaseNode *mulSize = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(mulConst);
837 mulSize->SetPrimType(array->GetPrimType());
838 mpyNode = mulSize;
839 } else if (mpyDim == 1 && prevNode) {
840 mpyNode = prevNode;
841 prevNode = resNode;
842 } else {
843 mpyNode = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_mul);
844 mpyNode->SetPrimType(array->GetPrimType());
845 MIRIntConst *mulConst = mirModule.GetMemPool()->New<MIRIntConst>(
846 mpyDim, *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())));
847 BaseNode *mulSize = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(mulConst);
848 mulSize->SetPrimType(array->GetPrimType());
849 mpyNode->SetOpnd(mulSize, 1);
850 PrimType signedInt4AddressCompute = GetSignedPrimType(array->GetPrimType());
851 if (!IsPrimitiveInteger(resNode->GetPrimType())) {
852 resNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, signedInt4AddressCompute,
853 resNode->GetPrimType(), resNode);
854 } else if (GetPrimTypeSize(resNode->GetPrimType()) != GetPrimTypeSize(array->GetPrimType())) {
855 resNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(
856 OP_cvt, array->GetPrimType(), GetRegPrimType(resNode->GetPrimType()), resNode);
857 }
858 mpyNode->SetOpnd(resNode, 0);
859 }
860 if (i == 0) {
861 prevNode = mpyNode;
862 continue;
863 }
864 BaseNode *newResNode = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add);
865 newResNode->SetPrimType(array->GetPrimType());
866 newResNode->SetOpnd(mpyNode, 0);
867 if (NeedCvtOrRetype(prevNode->GetPrimType(), array->GetPrimType())) {
868 prevNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(
869 OP_cvt, array->GetPrimType(), GetRegPrimType(prevNode->GetPrimType()), prevNode);
870 }
871 newResNode->SetOpnd(prevNode, 1);
872 prevNode = newResNode;
873 }
874 resNode = prevNode;
875 }
876
877 BaseNode *rMul = nullptr;
878 // esize is the size of the array element (eg. int = 4 long = 8)
879 uint64 esize;
880 if (nestedArray) {
881 esize = elemSize;
882 } else {
883 esize = arrayType->GetElemType()->GetSize();
884 }
885 Opcode opadd = OP_add;
886 MIRIntConst *econst = mirModule.GetMemPool()->New<MIRIntConst>(
887 esize, *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())));
888 BaseNode *eSize = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(econst);
889 eSize->SetPrimType(array->GetPrimType());
890 rMul = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_mul);
891 PrimType signedInt4AddressCompute = GetSignedPrimType(array->GetPrimType());
892 if (!IsPrimitiveInteger(resNode->GetPrimType())) {
893 resNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, signedInt4AddressCompute,
894 resNode->GetPrimType(), resNode);
895 } else if (GetPrimTypeSize(resNode->GetPrimType()) != GetPrimTypeSize(array->GetPrimType())) {
896 resNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, array->GetPrimType(),
897 GetRegPrimType(resNode->GetPrimType()), resNode);
898 }
899 rMul->SetPrimType(resNode->GetPrimType());
900 rMul->SetOpnd(resNode, 0);
901 rMul->SetOpnd(eSize, 1);
902 BaseNode *baseNode = array->GetBase();
903 BaseNode *rAdd = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(opadd);
904 rAdd->SetPrimType(array->GetPrimType());
905 rAdd->SetOpnd(baseNode, 0);
906 rAdd->SetOpnd(rMul, 1);
907 auto *newAdd = ConstantFold(mirModule).Fold(rAdd);
908 rAdd = (newAdd != nullptr ? newAdd : rAdd);
909 return rAdd;
910 }
911
ExpandArrayMrtIfBlock(IfStmtNode & node)912 IfStmtNode *MIRLower::ExpandArrayMrtIfBlock(IfStmtNode &node)
913 {
914 if (node.GetThenPart() != nullptr) {
915 node.SetThenPart(ExpandArrayMrtBlock(*node.GetThenPart()));
916 }
917 if (node.GetElsePart() != nullptr) {
918 node.SetElsePart(ExpandArrayMrtBlock(*node.GetElsePart()));
919 }
920 return &node;
921 }
922
ExpandArrayMrtWhileBlock(WhileStmtNode & node)923 WhileStmtNode *MIRLower::ExpandArrayMrtWhileBlock(WhileStmtNode &node)
924 {
925 if (node.GetBody() != nullptr) {
926 node.SetBody(ExpandArrayMrtBlock(*node.GetBody()));
927 }
928 return &node;
929 }
930
ExpandArrayMrtDoloopBlock(DoloopNode & node)931 DoloopNode *MIRLower::ExpandArrayMrtDoloopBlock(DoloopNode &node)
932 {
933 if (node.GetDoBody() != nullptr) {
934 node.SetDoBody(ExpandArrayMrtBlock(*node.GetDoBody()));
935 }
936 return &node;
937 }
938
ExpandArrayMrtForeachelemBlock(ForeachelemNode & node)939 ForeachelemNode *MIRLower::ExpandArrayMrtForeachelemBlock(ForeachelemNode &node)
940 {
941 if (node.GetLoopBody() != nullptr) {
942 node.SetLoopBody(ExpandArrayMrtBlock(*node.GetLoopBody()));
943 }
944 return &node;
945 }
946
AddArrayMrtMpl(BaseNode & exp,BlockNode & newBlock)947 void MIRLower::AddArrayMrtMpl(BaseNode &exp, BlockNode &newBlock)
948 {
949 MIRModule &mod = mirModule;
950 MIRBuilder *builder = mod.GetMIRBuilder();
951 for (size_t i = 0; i < exp.NumOpnds(); ++i) {
952 DEBUG_ASSERT(exp.Opnd(i) != nullptr, "nullptr check");
953 AddArrayMrtMpl(*exp.Opnd(i), newBlock);
954 }
955 if (exp.GetOpCode() == OP_array) {
956 auto &arrayNode = static_cast<ArrayNode &>(exp);
957 if (arrayNode.GetBoundsCheck()) {
958 BaseNode *arrAddr = arrayNode.Opnd(0);
959 BaseNode *index = arrayNode.Opnd(1);
960 DEBUG_ASSERT(index != nullptr, "null ptr check");
961 MIRType *indexType = GlobalTables::GetTypeTable().GetPrimType(index->GetPrimType());
962 UnaryStmtNode *nullCheck = builder->CreateStmtUnary(OP_assertnonnull, arrAddr);
963 newBlock.AddStatement(nullCheck);
964 #if DO_LT_0_CHECK
965 ConstvalNode *indexZero = builder->GetConstUInt32(0);
966 CompareNode *lessZero =
967 builder->CreateExprCompare(OP_lt, *GlobalTables::GetTypeTable().GetUInt1(),
968 *GlobalTables::GetTypeTable().GetUInt32(), index, indexZero);
969 #endif
970 MIRType *infoLenType = GlobalTables::GetTypeTable().GetInt32();
971 MapleVector<BaseNode *> arguments(builder->GetCurrentFuncCodeMpAllocator()->Adapter());
972 arguments.push_back(arrAddr);
973 BaseNode *arrLen =
974 builder->CreateExprIntrinsicop(INTRN_JAVA_ARRAY_LENGTH, OP_intrinsicop, *infoLenType, arguments);
975 BaseNode *cpmIndex = index;
976 if (arrLen->GetPrimType() != index->GetPrimType()) {
977 cpmIndex = builder->CreateExprTypeCvt(OP_cvt, *infoLenType, *indexType, index);
978 }
979 CompareNode *largeLen =
980 builder->CreateExprCompare(OP_ge, *GlobalTables::GetTypeTable().GetUInt1(),
981 *GlobalTables::GetTypeTable().GetUInt32(), cpmIndex, arrLen);
982 // maybe should use cior
983 #if DO_LT_0_CHECK
984 BinaryNode *indexCon =
985 builder->CreateExprBinary(OP_lior, *GlobalTables::GetTypeTable().GetUInt1(), lessZero, largeLen);
986 #endif
987 MapleVector<BaseNode *> args(builder->GetCurrentFuncCodeMpAllocator()->Adapter());
988 #if DO_LT_0_CHECK
989 args.push_back(indexCon);
990 IntrinsiccallNode *boundaryTrinsicCall = builder->CreateStmtIntrinsicCall(INTRN_MPL_BOUNDARY_CHECK, args);
991 #else
992 args.push_back(largeLen);
993 IntrinsiccallNode *boundaryTrinsicCall = builder->CreateStmtIntrinsicCall(INTRN_MPL_BOUNDARY_CHECK, args);
994 #endif
995 newBlock.AddStatement(boundaryTrinsicCall);
996 }
997 }
998 }
999
ExpandArrayMrtBlock(BlockNode & block)1000 BlockNode *MIRLower::ExpandArrayMrtBlock(BlockNode &block)
1001 {
1002 auto *newBlock = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
1003 if (block.GetFirst() == nullptr) {
1004 return newBlock;
1005 }
1006 StmtNode *nextStmt = block.GetFirst();
1007 do {
1008 StmtNode *stmt = nextStmt;
1009 DEBUG_ASSERT(stmt != nullptr, "nullptr check");
1010 nextStmt = stmt->GetNext();
1011 switch (stmt->GetOpCode()) {
1012 case OP_if:
1013 newBlock->AddStatement(ExpandArrayMrtIfBlock(static_cast<IfStmtNode &>(*stmt)));
1014 break;
1015 case OP_while:
1016 newBlock->AddStatement(ExpandArrayMrtWhileBlock(static_cast<WhileStmtNode &>(*stmt)));
1017 break;
1018 case OP_dowhile:
1019 newBlock->AddStatement(ExpandArrayMrtWhileBlock(static_cast<WhileStmtNode &>(*stmt)));
1020 break;
1021 case OP_doloop:
1022 newBlock->AddStatement(ExpandArrayMrtDoloopBlock(static_cast<DoloopNode &>(*stmt)));
1023 break;
1024 case OP_foreachelem:
1025 newBlock->AddStatement(ExpandArrayMrtForeachelemBlock(static_cast<ForeachelemNode &>(*stmt)));
1026 break;
1027 case OP_block:
1028 newBlock->AddStatement(ExpandArrayMrtBlock(static_cast<BlockNode &>(*stmt)));
1029 break;
1030 default:
1031 AddArrayMrtMpl(*stmt, *newBlock);
1032 newBlock->AddStatement(stmt);
1033 break;
1034 }
1035 } while (nextStmt != nullptr);
1036 return newBlock;
1037 }
1038
ExpandArrayMrt(MIRFunction & func)1039 void MIRLower::ExpandArrayMrt(MIRFunction &func)
1040 {
1041 if (ShouldOptArrayMrt(func)) {
1042 BlockNode *origBody = func.GetBody();
1043 DEBUG_ASSERT(origBody != nullptr, "nullptr check");
1044 BlockNode *newBody = ExpandArrayMrtBlock(*origBody);
1045 func.SetBody(newBody);
1046 }
1047 }
1048
FuncTypeFromFuncPtrExpr(BaseNode * x)1049 MIRFuncType *MIRLower::FuncTypeFromFuncPtrExpr(BaseNode *x)
1050 {
1051 MIRFuncType *res = nullptr;
1052 MIRFunction *func = mirModule.CurFunction();
1053 switch (x->GetOpCode()) {
1054 case OP_regread: {
1055 RegreadNode *regread = static_cast<RegreadNode *>(x);
1056 MIRPreg *preg = func->GetPregTab()->PregFromPregIdx(regread->GetRegIdx());
1057 // see if it is promoted from a symbol
1058 if (preg->GetOp() == OP_dread) {
1059 const MIRSymbol *symbol = preg->rematInfo.sym;
1060 MIRType *mirType = symbol->GetType();
1061 if (preg->fieldID != 0) {
1062 MIRStructType *structty = static_cast<MIRStructType *>(mirType);
1063 FieldPair thepair = structty->TraverseToField(preg->fieldID);
1064 mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(thepair.second.first);
1065 }
1066
1067 if (mirType->GetKind() == kTypePointer) {
1068 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1069 }
1070 if (res != nullptr) {
1071 break;
1072 }
1073 }
1074 // check if a formal promoted to preg
1075 for (FormalDef &formalDef : func->GetFormalDefVec()) {
1076 if (!formalDef.formalSym->IsPreg()) {
1077 continue;
1078 }
1079 if (formalDef.formalSym->GetPreg() == preg) {
1080 MIRType *mirType = formalDef.formalSym->GetType();
1081 if (mirType->GetKind() == kTypePointer) {
1082 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1083 }
1084 break;
1085 }
1086 }
1087 break;
1088 }
1089 case OP_dread: {
1090 DreadNode *dread = static_cast<DreadNode *>(x);
1091 MIRSymbol *symbol = func->GetLocalOrGlobalSymbol(dread->GetStIdx());
1092 MIRType *mirType = symbol->GetType();
1093 if (dread->GetFieldID() != 0) {
1094 MIRStructType *structty = static_cast<MIRStructType *>(mirType);
1095 FieldPair thepair = structty->TraverseToField(dread->GetFieldID());
1096 mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(thepair.second.first);
1097 }
1098 if (mirType->GetKind() == kTypePointer) {
1099 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1100 }
1101 break;
1102 }
1103 case OP_iread: {
1104 IreadNode *iread = static_cast<IreadNode *>(x);
1105 MIRPtrType *ptrType = static_cast<MIRPtrType *>(iread->GetType());
1106 MIRType *mirType = ptrType->GetPointedType();
1107 if (mirType->GetKind() == kTypeFunction) {
1108 res = static_cast<MIRFuncType *>(mirType);
1109 } else if (mirType->GetKind() == kTypePointer) {
1110 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1111 }
1112 break;
1113 }
1114 case OP_addroffunc: {
1115 AddroffuncNode *addrofFunc = static_cast<AddroffuncNode *>(x);
1116 PUIdx puIdx = addrofFunc->GetPUIdx();
1117 MIRFunction *f = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(puIdx);
1118 res = f->GetMIRFuncType();
1119 break;
1120 }
1121 case OP_retype: {
1122 MIRType *mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(static_cast<RetypeNode *>(x)->GetTyIdx());
1123 if (mirType->GetKind() == kTypePointer) {
1124 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1125 }
1126 if (res == nullptr) {
1127 res = FuncTypeFromFuncPtrExpr(x->Opnd(kNodeFirstOpnd));
1128 }
1129 break;
1130 }
1131 case OP_select: {
1132 res = FuncTypeFromFuncPtrExpr(x->Opnd(kNodeSecondOpnd));
1133 if (res == nullptr) {
1134 res = FuncTypeFromFuncPtrExpr(x->Opnd(kNodeThirdOpnd));
1135 }
1136 break;
1137 }
1138 default:
1139 CHECK_FATAL(false, "LMBCLowerer::FuncTypeFromFuncPtrExpr: NYI");
1140 }
1141 return res;
1142 }
1143
1144 const std::set<std::string> MIRLower::kSetArrayHotFunc = {};
1145
ShouldOptArrayMrt(const MIRFunction & func)1146 bool MIRLower::ShouldOptArrayMrt(const MIRFunction &func)
1147 {
1148 return (MIRLower::kSetArrayHotFunc.find(func.GetName()) != MIRLower::kSetArrayHotFunc.end());
1149 }
1150 } // namespace maple
1151