• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "mir_lower.h"
17 #include "constantfold.h"
18 #include "ext_constantfold.h"
19 #include "me_option.h"
20 
21 #define DO_LT_0_CHECK 1
22 
23 namespace maple {
24 
RoundUpConst(uint64 offset,uint32 align)25 static constexpr uint64 RoundUpConst(uint64 offset, uint32 align)
26 {
27     DEBUG_ASSERT(offset <= UINT64_MAX - align, "must not be zero");
28     return (-align) & (offset + align - 1);
29 }
30 
RoundUp(uint64 offset,uint32 align)31 static inline uint64 RoundUp(uint64 offset, uint32 align)
32 {
33     if (align == 0) {
34         return offset;
35     }
36     return RoundUpConst(offset, align);
37 }
38 
39 // Remove intrinsicop __builtin_expect and record likely info to brStmt
40 // Target condExpr example:
41 //  ne u1 i64 (
42 //    intrinsicop i64 C___builtin_expect (
43 //     cvt i64 i32 (dread i32 %levVar_9354), cvt i64 i32 (constval i32 0)),
44 //    constval i64 0)
LowerCondGotoStmtWithBuiltinExpect(CondGotoNode & brStmt)45 void LowerCondGotoStmtWithBuiltinExpect(CondGotoNode &brStmt)
46 {
47     BaseNode *condExpr = brStmt.Opnd(0);
48     // Poke ne for dread shortCircuit
49     // Example:
50     //  dassign %shortCircuit 0 (ne u1 i64 (
51     //    intrinsicop i64 C___builtin_expect (
52     //      cvt i64 i32 (dread i32 %levVar_32349),
53     //      cvt i64 i32 (constval i32 0)),
54     //    constval i64 0))
55     //  dassign %shortCircuit 0 (ne u1 u32 (dread u32 %shortCircuit, constval u1 0))
56     if (condExpr->GetOpCode() == OP_ne && condExpr->Opnd(0)->GetOpCode() == OP_dread &&
57         condExpr->Opnd(1)->GetOpCode() == OP_constval) {
58         auto *constVal = static_cast<ConstvalNode *>(condExpr->Opnd(1))->GetConstVal();
59         if (constVal->GetKind() == kConstInt && static_cast<MIRIntConst *>(constVal)->GetValue() == 0) {
60             condExpr = condExpr->Opnd(0);
61         }
62     }
63     if (condExpr->GetOpCode() == OP_dread) {
64         // Example:
65         //    dassign %shortCircuit 0 (ne u1 i64 (
66         //      intrinsicop i64 C___builtin_expect (
67         //        cvt i64 i32 (dread i32 %levVar_9488),
68         //        cvt i64 i32 (constval i32 1)),
69         //      constval i64 0))
70         //    brfalse @shortCircuit_label_13351 (dread u32 %shortCircuit)
71         StIdx stIdx = static_cast<DreadNode *>(condExpr)->GetStIdx();
72         FieldID fieldId = static_cast<DreadNode *>(condExpr)->GetFieldID();
73         if (fieldId != 0) {
74             return;
75         }
76         if (brStmt.GetPrev() == nullptr || brStmt.GetPrev()->GetOpCode() != OP_dassign) {
77             return;  // prev stmt may be a label, we skip it too
78         }
79         auto *dassign = static_cast<DassignNode *>(brStmt.GetPrev());
80         if (stIdx != dassign->GetStIdx() || dassign->GetFieldID() != 0) {
81             return;
82         }
83         condExpr = dassign->GetRHS();
84     }
85     if (condExpr->GetOpCode() == OP_ne) {
86         // opnd1 must be int const 0
87         BaseNode *opnd1 = condExpr->Opnd(1);
88         if (opnd1->GetOpCode() != OP_constval) {
89             return;
90         }
91         auto *constVal = static_cast<ConstvalNode *>(opnd1)->GetConstVal();
92         if (constVal->GetKind() != kConstInt || static_cast<MIRIntConst *>(constVal)->GetValue() != 0) {
93             return;
94         }
95         // opnd0 must be intrinsicop C___builtin_expect
96         BaseNode *opnd0 = condExpr->Opnd(0);
97         if (opnd0->GetOpCode() != OP_intrinsicop ||
98             static_cast<IntrinsicopNode *>(opnd0)->GetIntrinsic() != INTRN_C___builtin_expect) {
99             return;
100         }
101         // We trust constant fold
102         auto *expectedConstExpr = opnd0->Opnd(1);
103         if (expectedConstExpr->GetOpCode() == OP_cvt) {
104             expectedConstExpr = expectedConstExpr->Opnd(0);
105         }
106         if (expectedConstExpr->GetOpCode() != OP_constval) {
107             return;
108         }
109         auto *expectedConstNode = static_cast<ConstvalNode *>(expectedConstExpr)->GetConstVal();
110         CHECK_FATAL(expectedConstNode->GetKind() == kConstInt, "must be");
111         auto expectedVal = static_cast<MIRIntConst *>(expectedConstNode)->GetValue();
112         if (expectedVal != 0 && expectedVal != 1) {
113             return;
114         }
115         bool likelyTrue = (expectedVal == 1);  // The condition is likely to be true
116         bool likelyBranch = (brStmt.GetOpCode() == OP_brtrue ? likelyTrue : !likelyTrue);  // High probability jump
117         if (likelyBranch) {
118             brStmt.SetBranchProb(kProbLikely);
119         } else {
120             brStmt.SetBranchProb(kProbUnlikely);
121         }
122         // Remove __builtin_expect
123         condExpr->SetOpnd(opnd0->Opnd(0), 0);
124     }
125 }
126 
LowerBuiltinExpect(BlockNode & block)127 void MIRLower::LowerBuiltinExpect(BlockNode &block)
128 {
129     auto *stmt = block.GetFirst();
130     auto *last = block.GetLast();
131     while (stmt != last) {
132         if (stmt->GetOpCode() == OP_brtrue || stmt->GetOpCode() == OP_brfalse) {
133             LowerCondGotoStmtWithBuiltinExpect(*static_cast<CondGotoNode *>(stmt));
134         }
135         stmt = stmt->GetNext();
136     }
137 }
138 
CreateCondGotoStmt(Opcode op,BlockNode & blk,const IfStmtNode & ifStmt)139 LabelIdx MIRLower::CreateCondGotoStmt(Opcode op, BlockNode &blk, const IfStmtNode &ifStmt)
140 {
141     auto *brStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(op);
142     brStmt->SetOpnd(ifStmt.Opnd(), 0);
143     brStmt->SetSrcPos(ifStmt.GetSrcPos());
144     DEBUG_ASSERT(mirModule.CurFunction() != nullptr, "mirModule.CurFunction() should not be nullptr");
145     LabelIdx lableIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
146     mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lableIdx);
147     brStmt->SetOffset(lableIdx);
148     blk.AddStatement(brStmt);
149     if (GetFuncProfData()) {
150         GetFuncProfData()->CopyStmtFreq(brStmt->GetStmtID(), ifStmt.GetStmtID());
151     }
152     bool thenEmpty = (ifStmt.GetThenPart() == nullptr) || (ifStmt.GetThenPart()->GetFirst() == nullptr);
153     if (thenEmpty) {
154         blk.AppendStatementsFromBlock(*ifStmt.GetElsePart());
155     } else {
156         blk.AppendStatementsFromBlock(*ifStmt.GetThenPart());
157     }
158     return lableIdx;
159 }
160 
CreateBrFalseStmt(BlockNode & blk,const IfStmtNode & ifStmt)161 void MIRLower::CreateBrFalseStmt(BlockNode &blk, const IfStmtNode &ifStmt)
162 {
163     LabelIdx labelIdx = CreateCondGotoStmt(OP_brfalse, blk, ifStmt);
164     auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
165     lableStmt->SetLabelIdx(labelIdx);
166     blk.AddStatement(lableStmt);
167     // set stmtfreqs
168     if (GetFuncProfData()) {
169         DEBUG_ASSERT(GetFuncProfData()->GetStmtFreq(ifStmt.GetThenPart()->GetStmtID()) >= 0, "sanity check");
170         int64_t freq = GetFuncProfData()->GetStmtFreq(ifStmt.GetStmtID()) -
171                        GetFuncProfData()->GetStmtFreq(ifStmt.GetThenPart()->GetStmtID());
172         GetFuncProfData()->SetStmtFreq(lableStmt->GetStmtID(), freq);
173     }
174 }
175 
CreateBrTrueStmt(BlockNode & blk,const IfStmtNode & ifStmt)176 void MIRLower::CreateBrTrueStmt(BlockNode &blk, const IfStmtNode &ifStmt)
177 {
178     LabelIdx labelIdx = CreateCondGotoStmt(OP_brtrue, blk, ifStmt);
179     auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
180     lableStmt->SetLabelIdx(labelIdx);
181     blk.AddStatement(lableStmt);
182     // set stmtfreqs
183     if (GetFuncProfData()) {
184         DEBUG_ASSERT(GetFuncProfData()->GetStmtFreq(ifStmt.GetElsePart()->GetStmtID()) >= 0, "sanity check");
185         int64_t freq = GetFuncProfData()->GetStmtFreq(ifStmt.GetStmtID()) -
186                        GetFuncProfData()->GetStmtFreq(ifStmt.GetElsePart()->GetStmtID());
187         GetFuncProfData()->SetStmtFreq(lableStmt->GetStmtID(), freq);
188     }
189 }
190 
CreateBrFalseAndGotoStmt(BlockNode & blk,const IfStmtNode & ifStmt)191 void MIRLower::CreateBrFalseAndGotoStmt(BlockNode &blk, const IfStmtNode &ifStmt)
192 {
193     LabelIdx labelIdx = CreateCondGotoStmt(OP_brfalse, blk, ifStmt);
194     bool fallThroughFromThen = !IfStmtNoFallThrough(ifStmt);
195     LabelIdx gotoLableIdx = 0;
196     if (fallThroughFromThen) {
197         auto *gotoStmt = mirModule.CurFuncCodeMemPool()->New<GotoNode>(OP_goto);
198         DEBUG_ASSERT(mirModule.CurFunction() != nullptr, "mirModule.CurFunction() should not be nullptr");
199         gotoLableIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
200         mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(gotoLableIdx);
201         gotoStmt->SetOffset(gotoLableIdx);
202         blk.AddStatement(gotoStmt);
203         // set stmtfreqs
204         if (GetFuncProfData()) {
205             GetFuncProfData()->CopyStmtFreq(gotoStmt->GetStmtID(), ifStmt.GetThenPart()->GetStmtID());
206         }
207     }
208     auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
209     lableStmt->SetLabelIdx(labelIdx);
210     blk.AddStatement(lableStmt);
211     blk.AppendStatementsFromBlock(*ifStmt.GetElsePart());
212     // set stmtfreqs
213     if (GetFuncProfData()) {
214         GetFuncProfData()->CopyStmtFreq(lableStmt->GetStmtID(), ifStmt.GetElsePart()->GetStmtID());
215     }
216     if (fallThroughFromThen) {
217         lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
218         lableStmt->SetLabelIdx(gotoLableIdx);
219         blk.AddStatement(lableStmt);
220         // set endlabel stmtfreqs
221         if (GetFuncProfData()) {
222             GetFuncProfData()->CopyStmtFreq(lableStmt->GetStmtID(), ifStmt.GetStmtID());
223         }
224     }
225 }
226 
LowerIfStmt(IfStmtNode & ifStmt,bool recursive)227 BlockNode *MIRLower::LowerIfStmt(IfStmtNode &ifStmt, bool recursive)
228 {
229     bool thenEmpty = (ifStmt.GetThenPart() == nullptr) || (ifStmt.GetThenPart()->GetFirst() == nullptr);
230     bool elseEmpty = (ifStmt.GetElsePart() == nullptr) || (ifStmt.GetElsePart()->GetFirst() == nullptr);
231     if (recursive) {
232         if (!thenEmpty) {
233             ifStmt.SetThenPart(LowerBlock(*ifStmt.GetThenPart()));
234         }
235         if (!elseEmpty) {
236             ifStmt.SetElsePart(LowerBlock(*ifStmt.GetElsePart()));
237         }
238     }
239     auto *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
240     if (thenEmpty && elseEmpty) {
241         // generate EVAL <cond> statement
242         auto *evalStmt = mirModule.CurFuncCodeMemPool()->New<UnaryStmtNode>(OP_eval);
243         evalStmt->SetOpnd(ifStmt.Opnd(), 0);
244         evalStmt->SetSrcPos(ifStmt.GetSrcPos());
245         blk->AddStatement(evalStmt);
246         if (GetFuncProfData()) {
247             GetFuncProfData()->CopyStmtFreq(evalStmt->GetStmtID(), ifStmt.GetStmtID());
248         }
249     } else if (elseEmpty) {
250         // brfalse <cond> <endlabel>
251         // <thenPart>
252         // label <endlabel>
253         CreateBrFalseStmt(*blk, ifStmt);
254     } else if (thenEmpty) {
255         // brtrue <cond> <endlabel>
256         // <elsePart>
257         // label <endlabel>
258         CreateBrTrueStmt(*blk, ifStmt);
259     } else {
260         // brfalse <cond> <elselabel>
261         // <thenPart>
262         // goto <endlabel>
263         // label <elselabel>
264         // <elsePart>
265         // label <endlabel>
266         CreateBrFalseAndGotoStmt(*blk, ifStmt);
267     }
268     return blk;
269 }
270 
ConsecutiveCaseValsAndSameTarget(const CaseVector * switchTable)271 static bool ConsecutiveCaseValsAndSameTarget(const CaseVector *switchTable)
272 {
273     size_t caseNum = switchTable->size();
274     int lastVal = static_cast<int>((*switchTable)[0].first);
275     LabelIdx lblIdx = (*switchTable)[0].second;
276     for (size_t id = 1; id < caseNum; id++) {
277         lastVal++;
278         if (lastVal != (*switchTable)[id].first) {
279             return false;
280         }
281         if (lblIdx != (*switchTable)[id].second) {
282             return false;
283         }
284     }
285     return true;
286 }
287 
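// If the switch has no cases (it then just jumps to the default label), or all of its
// cases form one consecutive value range that jumps to a single label, replace it with
// conditional branch(es) and a goto and return the new statements; otherwise return nullptr.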
LowerSwitchStmt(SwitchNode * switchNode)290 BlockNode *MIRLower::LowerSwitchStmt(SwitchNode *switchNode)
291 {
292     CaseVector *switchTable = &switchNode->GetSwitchTable();
293     if (switchTable->empty()) {  // goto @defaultLabel
294         BlockNode *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
295         LabelIdx defaultLabel = switchNode->GetDefaultLabel();
296         MIRBuilder *builder = mirModule.GetMIRBuilder();
297         GotoNode *gotoStmt = builder->CreateStmtGoto(OP_goto, defaultLabel);
298         blk->AddStatement(gotoStmt);
299         return blk;
300     }
301     if (!ConsecutiveCaseValsAndSameTarget(switchTable)) {
302         return nullptr;
303     }
304     BlockNode *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
305     LabelIdx caseGotoLabel = switchTable->front().second;
306     LabelIdx defaultLabel = switchNode->GetDefaultLabel();
307     int64 minCaseVal = switchTable->front().first;
308     int64 maxCaseVal = switchTable->back().first;
309     BaseNode *switchOpnd = switchNode->Opnd(0);
310     MIRBuilder *builder = mirModule.GetMIRBuilder();
311     ConstvalNode *minCaseNode = builder->CreateIntConst(minCaseVal, switchOpnd->GetPrimType());
312     ConstvalNode *maxCaseNode = builder->CreateIntConst(maxCaseVal, switchOpnd->GetPrimType());
313     if (minCaseVal == maxCaseVal) {
314         // brtrue (x == minCaseVal) @case_goto_label
315         // goto @default_label
316         CompareNode *eqNode = builder->CreateExprCompare(
317             OP_eq, *GlobalTables::GetTypeTable().GetInt32(),
318             *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(switchOpnd->GetPrimType())), switchOpnd, minCaseNode);
319         CondGotoNode *condGoto = builder->CreateStmtCondGoto(eqNode, OP_brtrue, caseGotoLabel);
320         blk->AddStatement(condGoto);
321         GotoNode *gotoStmt = builder->CreateStmtGoto(OP_goto, defaultLabel);
322         blk->AddStatement(gotoStmt);
323     } else {
324         // brtrue (x < minCaseVal) @default_label
325         // brtrue (x > maxCaseVal) @default_label
326         // goto @case_goto_label
327         CompareNode *ltNode = builder->CreateExprCompare(
328             OP_lt, *GlobalTables::GetTypeTable().GetInt32(),
329             *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(switchOpnd->GetPrimType())), switchOpnd, minCaseNode);
330         CondGotoNode *condGoto = builder->CreateStmtCondGoto(ltNode, OP_brtrue, defaultLabel);
331         blk->AddStatement(condGoto);
332         CompareNode *gtNode = builder->CreateExprCompare(
333             OP_gt, *GlobalTables::GetTypeTable().GetInt32(),
334             *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(switchOpnd->GetPrimType())), switchOpnd, maxCaseNode);
335         condGoto = builder->CreateStmtCondGoto(gtNode, OP_brtrue, defaultLabel);
336         blk->AddStatement(condGoto);
337         GotoNode *gotoStmt = builder->CreateStmtGoto(OP_goto, caseGotoLabel);
338         blk->AddStatement(gotoStmt);
339     }
340     return blk;
341 }
342 
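//     while <cond> <body>
// is lowered to:
//     brfalse <cond> <endlabel>
//   label <bodylabel>
//     <body>
//     brtrue <cond> <bodylabel>
//   label <endlabel>
// With MeOption::optForSize the condition is kept in one place instead:
//   label <whilelabel>
//     brfalse <cond> <endlabel>
//     <body>
//     goto <whilelabel>
//   label <endlabel>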
LowerWhileStmt(WhileStmtNode & whileStmt)350 BlockNode *MIRLower::LowerWhileStmt(WhileStmtNode &whileStmt)
351 {
352     DEBUG_ASSERT(whileStmt.GetBody() != nullptr, "nullptr check");
353     whileStmt.SetBody(LowerBlock(*whileStmt.GetBody()));
354     auto *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
355     auto *brFalseStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brfalse);
356     brFalseStmt->SetOpnd(whileStmt.Opnd(0), 0);
357     brFalseStmt->SetSrcPos(whileStmt.GetSrcPos());
358     LabelIdx lalbeIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
359     mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lalbeIdx);
360     brFalseStmt->SetOffset(lalbeIdx);
361     blk->AddStatement(brFalseStmt);
362     blk->AppendStatementsFromBlock(*whileStmt.GetBody());
363     if (MeOption::optForSize) {
364         // still keep while-do format to avoid coping too much condition-related stmt
365         LabelIdx whileLalbeIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
366         mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(whileLalbeIdx);
367         auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
368         lableStmt->SetLabelIdx(whileLalbeIdx);
369         blk->InsertBefore(brFalseStmt, lableStmt);
370         auto *whilegotonode = mirModule.CurFuncCodeMemPool()->New<GotoNode>(OP_goto, whileLalbeIdx);
371         if (GetFuncProfData() && blk->GetLast()) {
372             GetFuncProfData()->CopyStmtFreq(whilegotonode->GetStmtID(), blk->GetLast()->GetStmtID());
373         }
374         blk->AddStatement(whilegotonode);
375     } else {
376         LabelIdx bodyLableIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
377         mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(bodyLableIdx);
378         auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
379         lableStmt->SetLabelIdx(bodyLableIdx);
380         blk->InsertAfter(brFalseStmt, lableStmt);
381         // update frequency
382         if (GetFuncProfData()) {
383             GetFuncProfData()->CopyStmtFreq(lableStmt->GetStmtID(), whileStmt.GetStmtID());
384             GetFuncProfData()->CopyStmtFreq(brFalseStmt->GetStmtID(), whileStmt.GetStmtID());
385         }
386         auto *brTrueStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brtrue);
387         brTrueStmt->SetOpnd(whileStmt.Opnd(0)->CloneTree(mirModule.GetCurFuncCodeMPAllocator()), 0);
388         brTrueStmt->SetOffset(bodyLableIdx);
389         if (GetFuncProfData() && blk->GetLast()) {
390             GetFuncProfData()->CopyStmtFreq(brTrueStmt->GetStmtID(), whileStmt.GetBody()->GetStmtID());
391         }
392         blk->AddStatement(brTrueStmt);
393     }
394     auto *lableStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
395     lableStmt->SetLabelIdx(lalbeIdx);
396     blk->AddStatement(lableStmt);
397     if (GetFuncProfData()) {
398         int64_t freq = GetFuncProfData()->GetStmtFreq(whileStmt.GetStmtID()) -
399                        GetFuncProfData()->GetStmtFreq(blk->GetLast()->GetStmtID());
400         DEBUG_ASSERT(freq >= 0, "sanity check");
401         GetFuncProfData()->SetStmtFreq(lableStmt->GetStmtID(), freq);
402     }
403     return blk;
404 }
405 
//    doloop <do-var>(<start-expr>,<cond-expr>,<incr-amt>) {<body-stmts>}
// is lowered to:
//     dassign <do-var> (<start-expr>)
//     brfalse <cond-expr> <endlabel>
//   label <bodylabel>
//     <body-stmts>
//     dassign <do-var> (add <do-var> <incr-amt>)
//     brtrue <cond-expr>  <bodylabel>
//   label <endlabel>
// (regassign/regread are used instead of dassign/dread when <do-var> is a preg)
LowerDoloopStmt(DoloopNode & doloop)415 BlockNode *MIRLower::LowerDoloopStmt(DoloopNode &doloop)
416 {
417     DEBUG_ASSERT(doloop.GetDoBody() != nullptr, "nullptr check");
418     DEBUG_ASSERT(mirModule.CurFunction() != nullptr, "mirModule.CurFunction() should not be nullptr");
419     doloop.SetDoBody(LowerBlock(*doloop.GetDoBody()));
420     int64_t doloopnodeFreq = 0, bodynodeFreq = 0;
421     if (GetFuncProfData()) {
422         doloopnodeFreq = GetFuncProfData()->GetStmtFreq(doloop.GetStmtID());
423         bodynodeFreq = GetFuncProfData()->GetStmtFreq(doloop.GetDoBody()->GetStmtID());
424     }
425     auto *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
426     if (doloop.IsPreg()) {
427         PregIdx regIdx = static_cast<PregIdx>(doloop.GetDoVarStIdx().FullIdx());
428         MIRPreg *mirPreg = mirModule.CurFunction()->GetPregTab()->PregFromPregIdx(regIdx);
429         PrimType primType = mirPreg->GetPrimType();
430         DEBUG_ASSERT(primType != kPtyInvalid, "runtime check error");
431         auto *startRegassign = mirModule.CurFuncCodeMemPool()->New<RegassignNode>();
432         startRegassign->SetRegIdx(regIdx);
433         startRegassign->SetPrimType(primType);
434         startRegassign->SetOpnd(doloop.GetStartExpr(), 0);
435         startRegassign->SetSrcPos(doloop.GetSrcPos());
436         blk->AddStatement(startRegassign);
437     } else {
438         auto *startDassign = mirModule.CurFuncCodeMemPool()->New<DassignNode>();
439         startDassign->SetStIdx(doloop.GetDoVarStIdx());
440         startDassign->SetRHS(doloop.GetStartExpr());
441         startDassign->SetSrcPos(doloop.GetSrcPos());
442         blk->AddStatement(startDassign);
443     }
444     if (GetFuncProfData()) {
445         GetFuncProfData()->SetStmtFreq(blk->GetLast()->GetStmtID(), doloopnodeFreq - bodynodeFreq);
446     }
447     auto *brFalseStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brfalse);
448     brFalseStmt->SetOpnd(doloop.GetCondExpr(), 0);
449     LabelIdx lIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
450     mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lIdx);
451     brFalseStmt->SetOffset(lIdx);
452     blk->AddStatement(brFalseStmt);
453     // udpate stmtFreq
454     if (GetFuncProfData()) {
455         GetFuncProfData()->SetStmtFreq(brFalseStmt->GetStmtID(), (doloopnodeFreq - bodynodeFreq));
456     }
457     LabelIdx bodyLabelIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
458     mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(bodyLabelIdx);
459     auto *labelStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
460     labelStmt->SetLabelIdx(bodyLabelIdx);
461     blk->AddStatement(labelStmt);
462     // udpate stmtFreq
463     if (GetFuncProfData()) {
464         GetFuncProfData()->SetStmtFreq(labelStmt->GetStmtID(), bodynodeFreq);
465     }
466     blk->AppendStatementsFromBlock(*doloop.GetDoBody());
467     if (doloop.IsPreg()) {
468         PregIdx regIdx = (PregIdx)doloop.GetDoVarStIdx().FullIdx();
469         MIRPreg *mirPreg = mirModule.CurFunction()->GetPregTab()->PregFromPregIdx(regIdx);
470         PrimType doVarPType = mirPreg->GetPrimType();
471         DEBUG_ASSERT(doVarPType != kPtyInvalid, "runtime check error");
472         auto *readDoVar = mirModule.CurFuncCodeMemPool()->New<RegreadNode>();
473         readDoVar->SetRegIdx(regIdx);
474         readDoVar->SetPrimType(doVarPType);
475         auto *add =
476             mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add, doVarPType, doloop.GetIncrExpr(), readDoVar);
477         auto *endRegassign = mirModule.CurFuncCodeMemPool()->New<RegassignNode>();
478         endRegassign->SetRegIdx(regIdx);
479         endRegassign->SetPrimType(doVarPType);
480         endRegassign->SetOpnd(add, 0);
481         blk->AddStatement(endRegassign);
482     } else {
483         const MIRSymbol *doVarSym = mirModule.CurFunction()->GetLocalOrGlobalSymbol(doloop.GetDoVarStIdx());
484         DEBUG_ASSERT(doVarSym != nullptr, "nullptr check");
485         PrimType doVarPType = doVarSym->GetType()->GetPrimType();
486         auto *readDovar =
487             mirModule.CurFuncCodeMemPool()->New<DreadNode>(OP_dread, doVarPType, doloop.GetDoVarStIdx(), 0);
488         auto *add =
489             mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add, doVarPType, readDovar, doloop.GetIncrExpr());
490         auto *endDassign = mirModule.CurFuncCodeMemPool()->New<DassignNode>();
491         endDassign->SetStIdx(doloop.GetDoVarStIdx());
492         endDassign->SetRHS(add);
493         blk->AddStatement(endDassign);
494     }
495     auto *brTrueStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brtrue);
496     brTrueStmt->SetOpnd(doloop.GetCondExpr()->CloneTree(mirModule.GetCurFuncCodeMPAllocator()), 0);
497     brTrueStmt->SetOffset(bodyLabelIdx);
498     blk->AddStatement(brTrueStmt);
499     // udpate stmtFreq
500     if (GetFuncProfData()) {
501         GetFuncProfData()->SetStmtFreq(brTrueStmt->GetStmtID(), bodynodeFreq);
502     }
503     labelStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
504     labelStmt->SetLabelIdx(lIdx);
505     blk->AddStatement(labelStmt);
506     // udpate stmtFreq
507     if (GetFuncProfData()) {
508         GetFuncProfData()->SetStmtFreq(labelStmt->GetStmtID(), (doloopnodeFreq - bodynodeFreq));
509     }
510     return blk;
511 }
512 
513 //     dowhile <body> <cond>
514 // is lowered to:
515 //   label <bodylabel>
516 //     <body>
517 //     brtrue <cond> <bodylabel>
LowerDowhileStmt(WhileStmtNode & doWhileStmt)518 BlockNode *MIRLower::LowerDowhileStmt(WhileStmtNode &doWhileStmt)
519 {
520     DEBUG_ASSERT(doWhileStmt.GetBody() != nullptr, "nullptr check");
521     doWhileStmt.SetBody(LowerBlock(*doWhileStmt.GetBody()));
522     auto *blk = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
523     DEBUG_ASSERT(mirModule.CurFunction() != nullptr, "mirModule.CurFunction() should not be nullptr");
524     LabelIdx lIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
525     mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lIdx);
526     auto *labelStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
527     labelStmt->SetLabelIdx(lIdx);
528     blk->AddStatement(labelStmt);
529     blk->AppendStatementsFromBlock(*doWhileStmt.GetBody());
530     auto *brTrueStmt = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(OP_brtrue);
531     brTrueStmt->SetOpnd(doWhileStmt.Opnd(0), 0);
532     brTrueStmt->SetOffset(lIdx);
533     blk->AddStatement(brTrueStmt);
534     return blk;
535 }
536 
LowerBlock(BlockNode & block)537 BlockNode *MIRLower::LowerBlock(BlockNode &block)
538 {
539     auto *newBlock = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
540     BlockNode *tmp = nullptr;
541     if (block.GetFirst() == nullptr) {
542         newBlock->SetStmtID(block.GetStmtID());  // keep original block stmtid
543         return newBlock;
544     }
545     StmtNode *nextStmt = block.GetFirst();
546     DEBUG_ASSERT(nextStmt != nullptr, "nullptr check");
547     do {
548         StmtNode *stmt = nextStmt;
549         nextStmt = stmt->GetNext();
550         switch (stmt->GetOpCode()) {
551             case OP_if:
552                 tmp = LowerIfStmt(static_cast<IfStmtNode &>(*stmt), true);
553                 newBlock->AppendStatementsFromBlock(*tmp);
554                 break;
555             case OP_switch:
556                 tmp = LowerSwitchStmt(static_cast<SwitchNode *>(stmt));
557                 if (tmp != nullptr) {
558                     newBlock->AppendStatementsFromBlock(*tmp);
559                 } else {
560                     newBlock->AddStatement(stmt);
561                 }
562                 break;
563             case OP_while:
564                 newBlock->AppendStatementsFromBlock(*LowerWhileStmt(static_cast<WhileStmtNode &>(*stmt)));
565                 break;
566             case OP_dowhile:
567                 newBlock->AppendStatementsFromBlock(*LowerDowhileStmt(static_cast<WhileStmtNode &>(*stmt)));
568                 break;
569             case OP_doloop:
570                 newBlock->AppendStatementsFromBlock(*LowerDoloopStmt(static_cast<DoloopNode &>(*stmt)));
571                 break;
572             case OP_icallassigned:
573             case OP_icall: {
574                 if (mirModule.IsCModule()) {
575                     // convert to icallproto/icallprotoassigned
576                     IcallNode *ic = static_cast<IcallNode *>(stmt);
577                     ic->SetOpCode(stmt->GetOpCode() == OP_icall ? OP_icallproto : OP_icallprotoassigned);
578                     MIRFuncType *funcType = FuncTypeFromFuncPtrExpr(stmt->Opnd(0));
579                     CHECK_FATAL(funcType != nullptr, "MIRLower::LowerBlock: cannot find prototype for icall");
580                     ic->SetRetTyIdx(funcType->GetTypeIndex());
581                     MIRType *retType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(funcType->GetRetTyIdx());
582                     if (retType->GetPrimType() == PTY_agg && retType->GetSize() > k16BitSize) {
583                         funcType->funcAttrs.SetAttr(FUNCATTR_firstarg_return);
584                     }
585                 }
586                 newBlock->AddStatement(stmt);
587                 break;
588             }
589             case OP_block:
590                 tmp = LowerBlock(static_cast<BlockNode &>(*stmt));
591                 newBlock->AppendStatementsFromBlock(*tmp);
592                 break;
593             default:
594                 newBlock->AddStatement(stmt);
595                 break;
596         }
597     } while (nextStmt != nullptr);
598     newBlock->SetStmtID(block.GetStmtID());  // keep original block stmtid
599     return newBlock;
600 }
601 
602 // for lowering OP_cand and OP_cior embedded in the expression x which belongs
603 // to curstmt
LowerEmbeddedCandCior(BaseNode * x,StmtNode * curstmt,BlockNode * blk)604 BaseNode *MIRLower::LowerEmbeddedCandCior(BaseNode *x, StmtNode *curstmt, BlockNode *blk)
605 {
606     DEBUG_ASSERT(x != nullptr, "nullptr check");
607     if (x->GetOpCode() == OP_cand || x->GetOpCode() == OP_cior) {
608         MIRBuilder *builder = mirModule.GetMIRBuilder();
609         BinaryNode *bnode = static_cast<BinaryNode *>(x);
610         bnode->SetOpnd(LowerEmbeddedCandCior(bnode->Opnd(0), curstmt, blk), 0);
611         PregIdx pregIdx = mirFunc->GetPregTab()->CreatePreg(x->GetPrimType());
612         RegassignNode *regass = builder->CreateStmtRegassign(x->GetPrimType(), pregIdx, bnode->Opnd(0));
613         blk->InsertBefore(curstmt, regass);
614         LabelIdx labIdx = mirFunc->GetLabelTab()->CreateLabel();
615         mirFunc->GetLabelTab()->AddToStringLabelMap(labIdx);
616         BaseNode *cond = builder->CreateExprRegread(x->GetPrimType(), pregIdx);
617         CondGotoNode *cgoto =
618             mirFunc->GetCodeMempool()->New<CondGotoNode>(x->GetOpCode() == OP_cior ? OP_brtrue : OP_brfalse);
619         cgoto->SetOpnd(cond, 0);
620         cgoto->SetOffset(labIdx);
621         blk->InsertBefore(curstmt, cgoto);
622 
623         bnode->SetOpnd(LowerEmbeddedCandCior(bnode->Opnd(1), curstmt, blk), 1);
624         regass = builder->CreateStmtRegassign(x->GetPrimType(), pregIdx, bnode->Opnd(1));
625         blk->InsertBefore(curstmt, regass);
626         LabelNode *lbl = mirFunc->GetCodeMempool()->New<LabelNode>();
627         lbl->SetLabelIdx(labIdx);
628         blk->InsertBefore(curstmt, lbl);
629         return builder->CreateExprRegread(x->GetPrimType(), pregIdx);
630     } else {
631         for (size_t i = 0; i < x->GetNumOpnds(); i++) {
632             x->SetOpnd(LowerEmbeddedCandCior(x->Opnd(i), curstmt, blk), i);
633         }
634         return x;
635     }
636 }
637 
// for lowering all appearances of OP_cand and OP_cior associated with conditional
// branches in the block
LowerCandCior(BlockNode & block)640 void MIRLower::LowerCandCior(BlockNode &block)
641 {
642     if (block.GetFirst() == nullptr) {
643         return;
644     }
645     StmtNode *nextStmt = block.GetFirst();
646     do {
647         StmtNode *stmt = nextStmt;
648         nextStmt = stmt->GetNext();
649         if (stmt->IsCondBr() && (stmt->Opnd(0) != nullptr &&
650            (stmt->Opnd(0)->GetOpCode() == OP_cand || stmt->Opnd(0)->GetOpCode() == OP_cior))) {
651             CondGotoNode *condGoto = static_cast<CondGotoNode *>(stmt);
652             BinaryNode *cond = static_cast<BinaryNode *>(condGoto->Opnd(0));
653             if ((stmt->GetOpCode() == OP_brfalse && cond->GetOpCode() == OP_cand) ||
654                 (stmt->GetOpCode() == OP_brtrue && cond->GetOpCode() == OP_cior)) {
655                 // short-circuit target label is same as original condGoto stmt
656                 condGoto->SetOpnd(cond->GetBOpnd(0), 0);
657                 auto *newCondGoto = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(Opcode(stmt->GetOpCode()));
658                 newCondGoto->SetOpnd(cond->GetBOpnd(1), 0);
659                 newCondGoto->SetOffset(condGoto->GetOffset());
660                 block.InsertAfter(condGoto, newCondGoto);
661                 nextStmt = stmt;  // so it will be re-processed if another cand/cior
662             } else {              // short-circuit target is next statement
663                 LabelIdx lIdx;
664                 LabelNode *labelStmt = nullptr;
665                 if (nextStmt->GetOpCode() == OP_label) {
666                     labelStmt = static_cast<LabelNode *>(nextStmt);
667                     lIdx = labelStmt->GetLabelIdx();
668                 } else {
669                     DEBUG_ASSERT(mirModule.CurFunction() != nullptr, "mirModule.CurFunction() should not be nullptr");
670                     lIdx = mirModule.CurFunction()->GetLabelTab()->CreateLabel();
671                     mirModule.CurFunction()->GetLabelTab()->AddToStringLabelMap(lIdx);
672                     labelStmt = mirModule.CurFuncCodeMemPool()->New<LabelNode>();
673                     labelStmt->SetLabelIdx(lIdx);
674                     block.InsertAfter(condGoto, labelStmt);
675                 }
676                 auto *newCondGoto = mirModule.CurFuncCodeMemPool()->New<CondGotoNode>(
677                     stmt->GetOpCode() == OP_brfalse ? OP_brtrue : OP_brfalse);
678                 newCondGoto->SetOpnd(cond->GetBOpnd(0), 0);
679                 newCondGoto->SetOffset(lIdx);
680                 block.InsertBefore(condGoto, newCondGoto);
681                 condGoto->SetOpnd(cond->GetBOpnd(1), 0);
682                 nextStmt = newCondGoto;  // so it will be re-processed if another cand/cior
683             }
684         } else {  // call LowerEmbeddedCandCior() for all the expression operands
685             for (size_t i = 0; i < stmt->GetNumOpnds(); i++) {
686                 stmt->SetOpnd(LowerEmbeddedCandCior(stmt->Opnd(i), stmt, &block), i);
687             }
688         }
689     } while (nextStmt != nullptr);
690 }
691 
LowerFunc(MIRFunction & func)692 void MIRLower::LowerFunc(MIRFunction &func)
693 {
694     if (GetOptLevel() > 0) {
695         ExtConstantFold ecf(func.GetModule());
696         (void)ecf.ExtSimplify(func.GetBody());
697         ;
698     }
699 
700     mirModule.SetCurFunction(&func);
701     if (IsLowerExpandArray()) {
702         ExpandArrayMrt(func);
703     }
704     BlockNode *origBody = func.GetBody();
705     DEBUG_ASSERT(origBody != nullptr, "nullptr check");
706     BlockNode *newBody = LowerBlock(*origBody);
707     DEBUG_ASSERT(newBody != nullptr, "nullptr check");
708     LowerBuiltinExpect(*newBody);
709     if (!InLFO()) {
710         LowerCandCior(*newBody);
711     }
712     func.SetBody(newBody);
713 }
714 
LowerFarray(ArrayNode * array)715 BaseNode *MIRLower::LowerFarray(ArrayNode *array)
716 {
717     auto *farrayType = static_cast<MIRFarrayType *>(array->GetArrayType(GlobalTables::GetTypeTable()));
718     size_t eSize = GlobalTables::GetTypeTable().GetTypeFromTyIdx(farrayType->GetElemTyIdx())->GetSize();
719     MIRType &arrayType = *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType()));
720     /* how about multi-dimension array? */
721     if (array->GetIndex(0)->GetOpCode() == OP_constval) {
722         const ConstvalNode *constvalNode = static_cast<const ConstvalNode *>(array->GetIndex(0));
723         if (constvalNode->GetConstVal()->GetKind() == kConstInt) {
724             const MIRIntConst *pIntConst = static_cast<const MIRIntConst *>(constvalNode->GetConstVal());
725             CHECK_FATAL(!pIntConst->IsNegative(), "Array index should >= 0.");
726             int64 eleOffset = pIntConst->GetExtValue() * static_cast<int64>(eSize);
727 
728             BaseNode *baseNode = array->GetBase();
729             if (eleOffset == 0) {
730                 return baseNode;
731             }
732 
733             MIRIntConst *eleConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(eleOffset, arrayType);
734             BaseNode *offsetNode = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(eleConst);
735             offsetNode->SetPrimType(array->GetPrimType());
736 
737             BaseNode *rAdd = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add);
738             rAdd->SetPrimType(array->GetPrimType());
739             rAdd->SetOpnd(baseNode, 0);
740             rAdd->SetOpnd(offsetNode, 1);
741             return rAdd;
742         }
743     }
744 
745     BaseNode *rMul = nullptr;
746 
747     BaseNode *baseNode = array->GetBase();
748 
749     BaseNode *rAdd = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add);
750     rAdd->SetPrimType(array->GetPrimType());
751     rAdd->SetOpnd(baseNode, 0);
752     rAdd->SetOpnd(rMul, 1);
753     auto *newAdd = ConstantFold(mirModule).Fold(rAdd);
754     rAdd = (newAdd != nullptr ? newAdd : rAdd);
755     return rAdd;
756 }
757 
// Lower an OP_array node into explicit address arithmetic:  base + (linearized index) * esize.
// kTypeJArray arrays are returned unchanged and kTypeFArray arrays are delegated to LowerFarray.
// With more than one dimension (a multi-dim MIRArrayType, or 1-D arrays nested inside each other),
// each index i is scaled by mpyDim (the product of the sizes of the dimensions to its right) and the
// scaled indexes are summed. The final address expression is passed through ConstantFold.
LowerCArray(ArrayNode * array)758 BaseNode *MIRLower::LowerCArray(ArrayNode *array)
759 {
760     MIRType *aType = array->GetArrayType(GlobalTables::GetTypeTable());
761     if (aType->GetKind() == kTypeJArray) {
762         return array;
763     }
764     if (aType->GetKind() == kTypeFArray) {
765         return LowerFarray(array);
766     }
767 
768     MIRArrayType *arrayType = static_cast<MIRArrayType *>(aType);
769     /* There are two cases where dimension > 1.
770      * 1) arrayType->dim > 1.  Process the current arrayType. (nestedArray = false)
771      * 2) arrayType->dim == 1, but arraytype->eTyIdx is another array. (nestedArray = true)
772      * Assume at this time 1) and 2) cannot mix.
773      * Along with the array dimension, there is the array indexing.
774      * It is allowed to index arrays less than the dimension.
775      * This is dictated by the number of indexes.
776      */
777     bool nestedArray = false;
778     uint64 dim = arrayType->GetDim();
779     MIRType *innerType = nullptr;
780     MIRArrayType *innerArrayType = nullptr;
781     uint64 elemSize = 0;
782     if (dim == 1) {
783         innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(arrayType->GetElemTyIdx());
784         if (innerType->GetKind() == kTypeArray) {
785             nestedArray = true;
            // Walk down the nested array chain: count one extra dimension per level and keep
            // the element size of the innermost level, rounded up to the alignment of
            // arrayType's element type.
786             do {
787                 innerArrayType = static_cast<MIRArrayType *>(innerType);
788                 elemSize = RoundUp(innerArrayType->GetElemType()->GetSize(), arrayType->GetElemType()->GetAlign());
789                 dim++;
790                 innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(innerArrayType->GetElemTyIdx());
791             } while (innerType->GetKind() == kTypeArray);
792         }
793     }
794 
    // One operand is the base address; numIndex counts the index operands.
795     CHECK_FATAL(array->NumOpnds() > 0, "must not be zero");
796     size_t numIndex = array->NumOpnds() - 1;
797     MIRArrayType *curArrayType = arrayType;
798     BaseNode *resNode = array->GetIndex(0);
799     if (dim > 1) {
        // prevNode accumulates the linearized index, built left to right:
        //   sum over i of index[i] * mpyDim(i).
        // Only as many dimensions as there are index operands are processed.
800         BaseNode *prevNode = nullptr;
801         for (size_t i = 0; (i < dim) && (i < numIndex); ++i) {
802             uint32 mpyDim = 1;
803             if (nestedArray) {
                // NOTE(review): this checks the outermost dimension (arrayType, item 0) on every
                // iteration rather than the level currently being indexed.
804                 CHECK_FATAL(arrayType->GetSizeArrayItem(0) > 0, "Zero size array dimension");
                // Step one nesting level down, then multiply the sizes of every array level below it.
805                 innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(curArrayType->GetElemTyIdx());
806                 curArrayType = static_cast<MIRArrayType *>(innerType);
807                 while (innerType->GetKind() == kTypeArray) {
808                     innerArrayType = static_cast<MIRArrayType *>(innerType);
809                     mpyDim *= innerArrayType->GetSizeArrayItem(0);
810                     innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(innerArrayType->GetElemTyIdx());
811                 }
812             } else {
                // Multi-dim array type: mpyDim is the product of the sizes of dimensions i+1 .. dim-1.
813                 CHECK_FATAL(arrayType->GetSizeArrayItem(static_cast<uint32>(i)) > 0, "Zero size array dimension");
814                 for (size_t j = i + 1; j < dim; ++j) {
815                     mpyDim *= arrayType->GetSizeArrayItem(static_cast<uint32>(j));
816                 }
817             }
818 
819             BaseNode *index = static_cast<ConstvalNode *>(array->GetIndex(i));
820             bool isConst = false;
821             uint64 indexVal = 0;
822             if (index->op == OP_constval) {
                // Constant index: materialize indexVal * mpyDim as a constant.
                // NOTE(review): the constant is assumed to be an MIRIntConst; its kind is not checked here.
823                 ConstvalNode *constNode = static_cast<ConstvalNode *>(index);
824                 indexVal = static_cast<uint64>((static_cast<MIRIntConst *>(constNode->GetConstVal()))->GetExtValue());
825                 isConst = true;
                // For i > 0 this node is assigned to resNode but not used further in the iteration:
                // the isConst path below rebuilds the same value as mulConst.
826                 MIRIntConst *newConstNode = mirModule.GetMemPool()->New<MIRIntConst>(
827                     indexVal * mpyDim, *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())));
828                 BaseNode *newValNode = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(newConstNode);
829                 newValNode->SetPrimType(array->GetPrimType());
830                 if (i == 0) {
831                     prevNode = newValNode;
832                     continue;
833                 } else {
834                     resNode = newValNode;
835                 }
836             }
837             if (i > 0 && isConst == false) {
838                 resNode = array->GetIndex(i);
839             }
840 
841             BaseNode *mpyNode;
842             if (isConst) {
843                 MIRIntConst *mulConst = mirModule.GetMemPool()->New<MIRIntConst>(
844                     static_cast<int64>(mpyDim) * indexVal,
845                     *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())));
846                 BaseNode *mulSize = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(mulConst);
847                 mulSize->SetPrimType(array->GetPrimType());
848                 mpyNode = mulSize;
849             } else if (mpyDim == 1 && prevNode) {
                // Stride 1 with an existing partial sum: swap roles so the add below becomes
                // (partial sum) + index, with no multiply by 1.
850                 mpyNode = prevNode;
851                 prevNode = resNode;
852             } else {
                // General case: mpyNode = index * mpyDim, with the index first converted to the
                // address width (non-integer -> signed address type, other widths -> array prim type).
853                 mpyNode = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_mul);
854                 mpyNode->SetPrimType(array->GetPrimType());
855                 MIRIntConst *mulConst = mirModule.GetMemPool()->New<MIRIntConst>(
856                     mpyDim, *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())));
857                 BaseNode *mulSize = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(mulConst);
858                 mulSize->SetPrimType(array->GetPrimType());
859                 mpyNode->SetOpnd(mulSize, 1);
860                 PrimType signedInt4AddressCompute = GetSignedPrimType(array->GetPrimType());
861                 DEBUG_ASSERT(resNode != nullptr, "resNode should not be nullptr");
862                 if (!IsPrimitiveInteger(resNode->GetPrimType())) {
863                     resNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, signedInt4AddressCompute,
864                                                                                resNode->GetPrimType(), resNode);
865                 } else if (GetPrimTypeSize(resNode->GetPrimType()) != GetPrimTypeSize(array->GetPrimType())) {
866                     resNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(
867                         OP_cvt, array->GetPrimType(), GetRegPrimType(resNode->GetPrimType()), resNode);
868                 }
869                 mpyNode->SetOpnd(resNode, 0);
870             }
871             if (i == 0) {
872                 prevNode = mpyNode;
873                 continue;
874             }
            // Running sum: newResNode = mpyNode + prevNode (prevNode converted if its type differs).
875             BaseNode *newResNode = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_add);
876             newResNode->SetPrimType(array->GetPrimType());
877             newResNode->SetOpnd(mpyNode, 0);
878             if (prevNode != nullptr && NeedCvtOrRetype(prevNode->GetPrimType(), array->GetPrimType())) {
879                 prevNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(
880                     OP_cvt, array->GetPrimType(), GetRegPrimType(prevNode->GetPrimType()), prevNode);
881             }
882             newResNode->SetOpnd(prevNode, 1);
883             prevNode = newResNode;
884         }
885         resNode = prevNode;
886     }
887 
    // Scale the linearized index by the element size and add it to the base address.
888     BaseNode *rMul = nullptr;
889     // esize is the size of the array element (eg. int = 4 long = 8)
890     uint64 esize;
891     if (nestedArray) {
892         esize = elemSize;
893     } else {
894         esize = arrayType->GetElemType()->GetSize();
895     }
896     Opcode opadd = OP_add;
897     MIRIntConst *econst = mirModule.GetMemPool()->New<MIRIntConst>(
898         esize, *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())));
899     BaseNode *eSize = mirModule.CurFuncCodeMemPool()->New<ConstvalNode>(econst);
900     eSize->SetPrimType(array->GetPrimType());
901     rMul = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(OP_mul);
902     PrimType signedInt4AddressCompute = GetSignedPrimType(array->GetPrimType());
903     DEBUG_ASSERT(resNode != nullptr, "nullptr check");
    // Same index-width normalization as in the general case of the loop above.
904     if (!IsPrimitiveInteger(resNode->GetPrimType())) {
905         resNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, signedInt4AddressCompute,
906                                                                    resNode->GetPrimType(), resNode);
907     } else if (GetPrimTypeSize(resNode->GetPrimType()) != GetPrimTypeSize(array->GetPrimType())) {
908         resNode = mirModule.CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, array->GetPrimType(),
909                                                                    GetRegPrimType(resNode->GetPrimType()), resNode);
910     }
911     rMul->SetPrimType(resNode->GetPrimType());
912     rMul->SetOpnd(resNode, 0);
913     rMul->SetOpnd(eSize, 1);
914     BaseNode *baseNode = array->GetBase();
915     BaseNode *rAdd = mirModule.CurFuncCodeMemPool()->New<BinaryNode>(opadd);
916     rAdd->SetPrimType(array->GetPrimType());
917     rAdd->SetOpnd(baseNode, 0);
918     rAdd->SetOpnd(rMul, 1);
    // Return the folded expression when ConstantFold produces one, otherwise the unfolded add.
919     auto *newAdd = ConstantFold(mirModule).Fold(rAdd);
920     rAdd = (newAdd != nullptr ? newAdd : rAdd);
921     return rAdd;
922 }
923 
ExpandArrayMrtIfBlock(IfStmtNode & node)924 IfStmtNode *MIRLower::ExpandArrayMrtIfBlock(IfStmtNode &node)
925 {
926     if (node.GetThenPart() != nullptr) {
927         node.SetThenPart(ExpandArrayMrtBlock(*node.GetThenPart()));
928     }
929     if (node.GetElsePart() != nullptr) {
930         node.SetElsePart(ExpandArrayMrtBlock(*node.GetElsePart()));
931     }
932     return &node;
933 }
934 
ExpandArrayMrtWhileBlock(WhileStmtNode & node)935 WhileStmtNode *MIRLower::ExpandArrayMrtWhileBlock(WhileStmtNode &node)
936 {
937     if (node.GetBody() != nullptr) {
938         node.SetBody(ExpandArrayMrtBlock(*node.GetBody()));
939     }
940     return &node;
941 }
942 
ExpandArrayMrtDoloopBlock(DoloopNode & node)943 DoloopNode *MIRLower::ExpandArrayMrtDoloopBlock(DoloopNode &node)
944 {
945     if (node.GetDoBody() != nullptr) {
946         node.SetDoBody(ExpandArrayMrtBlock(*node.GetDoBody()));
947     }
948     return &node;
949 }
950 
ExpandArrayMrtForeachelemBlock(ForeachelemNode & node)951 ForeachelemNode *MIRLower::ExpandArrayMrtForeachelemBlock(ForeachelemNode &node)
952 {
953     if (node.GetLoopBody() != nullptr) {
954         node.SetLoopBody(ExpandArrayMrtBlock(*node.GetLoopBody()));
955     }
956     return &node;
957 }
958 
AddArrayMrtMpl(BaseNode & exp,BlockNode & newBlock)959 void MIRLower::AddArrayMrtMpl(BaseNode &exp, BlockNode &newBlock)
960 {
961     MIRModule &mod = mirModule;
962     MIRBuilder *builder = mod.GetMIRBuilder();
963     for (size_t i = 0; i < exp.NumOpnds(); ++i) {
964         DEBUG_ASSERT(exp.Opnd(i) != nullptr, "nullptr check");
965         AddArrayMrtMpl(*exp.Opnd(i), newBlock);
966     }
967     if (exp.GetOpCode() == OP_array) {
968         auto &arrayNode = static_cast<ArrayNode &>(exp);
969         if (arrayNode.GetBoundsCheck()) {
970             BaseNode *arrAddr = arrayNode.Opnd(0);
971             UnaryStmtNode *nullCheck = builder->CreateStmtUnary(OP_assertnonnull, arrAddr);
972             newBlock.AddStatement(nullCheck);
973         }
974     }
975 }
976 
ExpandArrayMrtBlock(BlockNode & block)977 BlockNode *MIRLower::ExpandArrayMrtBlock(BlockNode &block)
978 {
979     auto *newBlock = mirModule.CurFuncCodeMemPool()->New<BlockNode>();
980     if (block.GetFirst() == nullptr) {
981         return newBlock;
982     }
983     StmtNode *nextStmt = block.GetFirst();
984     do {
985         StmtNode *stmt = nextStmt;
986         DEBUG_ASSERT(stmt != nullptr, "nullptr check");
987         nextStmt = stmt->GetNext();
988         switch (stmt->GetOpCode()) {
989             case OP_if:
990                 newBlock->AddStatement(ExpandArrayMrtIfBlock(static_cast<IfStmtNode &>(*stmt)));
991                 break;
992             case OP_while:
993                 newBlock->AddStatement(ExpandArrayMrtWhileBlock(static_cast<WhileStmtNode &>(*stmt)));
994                 break;
995             case OP_dowhile:
996                 newBlock->AddStatement(ExpandArrayMrtWhileBlock(static_cast<WhileStmtNode &>(*stmt)));
997                 break;
998             case OP_doloop:
999                 newBlock->AddStatement(ExpandArrayMrtDoloopBlock(static_cast<DoloopNode &>(*stmt)));
1000                 break;
1001             case OP_foreachelem:
1002                 newBlock->AddStatement(ExpandArrayMrtForeachelemBlock(static_cast<ForeachelemNode &>(*stmt)));
1003                 break;
1004             case OP_block:
1005                 newBlock->AddStatement(ExpandArrayMrtBlock(static_cast<BlockNode &>(*stmt)));
1006                 break;
1007             default:
1008                 AddArrayMrtMpl(*stmt, *newBlock);
1009                 newBlock->AddStatement(stmt);
1010                 break;
1011         }
1012     } while (nextStmt != nullptr);
1013     return newBlock;
1014 }
1015 
ExpandArrayMrt(MIRFunction & func)1016 void MIRLower::ExpandArrayMrt(MIRFunction &func)
1017 {
1018     if (ShouldOptArrayMrt(func)) {
1019         BlockNode *origBody = func.GetBody();
1020         DEBUG_ASSERT(origBody != nullptr, "nullptr check");
1021         BlockNode *newBody = ExpandArrayMrtBlock(*origBody);
1022         func.SetBody(newBody);
1023     }
1024 }
1025 
FuncTypeFromFuncPtrExpr(BaseNode * x)1026 MIRFuncType *MIRLower::FuncTypeFromFuncPtrExpr(BaseNode *x)
1027 {
1028     DEBUG_ASSERT(x != nullptr, "nullptr check");
1029     MIRFuncType *res = nullptr;
1030     MIRFunction *func = mirModule.CurFunction();
1031     switch (x->GetOpCode()) {
1032         case OP_regread: {
1033             RegreadNode *regread = static_cast<RegreadNode *>(x);
1034             MIRPreg *preg = func->GetPregTab()->PregFromPregIdx(regread->GetRegIdx());
1035             // see if it is promoted from a symbol
1036             if (preg->GetOp() == OP_dread) {
1037                 const MIRSymbol *symbol = preg->rematInfo.sym;
1038                 MIRType *mirType = symbol->GetType();
1039                 if (preg->fieldID != 0) {
1040                     MIRStructType *structty = static_cast<MIRStructType *>(mirType);
1041                     FieldPair thepair = structty->TraverseToField(preg->fieldID);
1042                     mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(thepair.second.first);
1043                 }
1044 
1045                 if (mirType->GetKind() == kTypePointer) {
1046                     res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1047                 }
1048                 if (res != nullptr) {
1049                     break;
1050                 }
1051             }
1052             // check if a formal promoted to preg
1053             for (FormalDef &formalDef : func->GetFormalDefVec()) {
1054                 if (!formalDef.formalSym->IsPreg()) {
1055                     continue;
1056                 }
1057                 if (formalDef.formalSym->GetPreg() == preg) {
1058                     MIRType *mirType = formalDef.formalSym->GetType();
1059                     if (mirType->GetKind() == kTypePointer) {
1060                         res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1061                     }
1062                     break;
1063                 }
1064             }
1065             break;
1066         }
1067         case OP_dread: {
1068             DreadNode *dread = static_cast<DreadNode *>(x);
1069             MIRSymbol *symbol = func->GetLocalOrGlobalSymbol(dread->GetStIdx());
1070             MIRType *mirType = symbol->GetType();
1071             if (dread->GetFieldID() != 0) {
1072                 MIRStructType *structty = static_cast<MIRStructType *>(mirType);
1073                 FieldPair thepair = structty->TraverseToField(dread->GetFieldID());
1074                 mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(thepair.second.first);
1075             }
1076             if (mirType->GetKind() == kTypePointer) {
1077                 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1078             }
1079             break;
1080         }
1081         case OP_iread: {
1082             IreadNode *iread = static_cast<IreadNode *>(x);
1083             MIRPtrType *ptrType = static_cast<MIRPtrType *>(iread->GetType());
1084             MIRType *mirType = ptrType->GetPointedType();
1085             if (mirType->GetKind() == kTypeFunction) {
1086                 res = static_cast<MIRFuncType *>(mirType);
1087             } else if (mirType->GetKind() == kTypePointer) {
1088                 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1089             }
1090             break;
1091         }
1092         case OP_addroffunc: {
1093             AddroffuncNode *addrofFunc = static_cast<AddroffuncNode *>(x);
1094             PUIdx puIdx = addrofFunc->GetPUIdx();
1095             MIRFunction *f = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(puIdx);
1096             res = f->GetMIRFuncType();
1097             break;
1098         }
1099         case OP_retype: {
1100             MIRType *mirType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(static_cast<RetypeNode *>(x)->GetTyIdx());
1101             if (mirType->GetKind() == kTypePointer) {
1102                 res = static_cast<MIRPtrType *>(mirType)->GetPointedFuncType();
1103             }
1104             if (res == nullptr) {
1105                 res = FuncTypeFromFuncPtrExpr(x->Opnd(kNodeFirstOpnd));
1106             }
1107             break;
1108         }
1109         case OP_select: {
1110             res = FuncTypeFromFuncPtrExpr(x->Opnd(kNodeSecondOpnd));
1111             if (res == nullptr) {
1112                 res = FuncTypeFromFuncPtrExpr(x->Opnd(kNodeThirdOpnd));
1113             }
1114             break;
1115         }
1116         default:
1117             CHECK_FATAL(false, "LMBCLowerer::FuncTypeFromFuncPtrExpr: NYI");
1118     }
1119     return res;
1120 }
1121 
1122 const std::set<std::string> MIRLower::kSetArrayHotFunc = {};
1123 
ShouldOptArrayMrt(const MIRFunction & func)1124 bool MIRLower::ShouldOptArrayMrt(const MIRFunction &func)
1125 {
1126     return (MIRLower::kSetArrayHotFunc.find(func.GetName()) != MIRLower::kSetArrayHotFunc.end());
1127 }
1128 }  // namespace maple
1129