• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "java_eh_lower.h"
17 #include "mir_function.h"
18 #include "mir_builder.h"
19 #include "global_tables.h"
20 #include "option.h"
21 
namespace {
// Name prefixes for the local temporaries that DoLowerDiv creates when it does not
// use pseudo-registers; a per-function counter (divSTIndex) is appended to each.
// "opnd1" refers to operand 1 of the division, i.e. the divisor.
const std::string strDivOpnd{"__div_opnd1"};
const std::string strDivRes{"__div_res"};

// Names of functions that the lowered code calls; they are looked up (or created)
// through MIRBuilder::GetOrCreateFunction.
const std::string strMCCThrowArrayIndexOutOfBoundsException{"MCC_ThrowArrayIndexOutOfBoundsException"};
const std::string strMCCThrowNullPointerException{"MCC_ThrowNullPointerException"};
}  // namespace
28 
// Java exception-handling lowering: scan the entire function body once, looking for
// expressions that could raise a runtime exception (such as integer division), and
// insert an explicit runtime check in front of each one.
// For example, given
//   x = a / b
// where the value of b is not known at compile time, the lowering emits:
//   if (b == 0) {
//     throw ArithmeticException   // emitted as intrinsic INTRN_JAVA_THROW_ARITHMETIC
//   }
//   x = a / b
39 namespace maple {
DoLowerDiv(BinaryNode & expr,BlockNode & blknode)40 BaseNode *JavaEHLowerer::DoLowerDiv(BinaryNode &expr, BlockNode &blknode)
41 {
42     PrimType ptype = expr.GetPrimType();
43     MIRBuilder *mirBuilder = GetMIRModule().GetMIRBuilder();
44     MIRFunction *func = GetMIRModule().CurFunction();
45     if (!IsPrimitiveInteger(ptype)) {
46         return &expr;
47     }
48 
49     // Store divopnd to a tmp st if not a leaf node.
50     BaseNode *divOpnd = expr.Opnd(1);
51     if (!divOpnd->IsLeaf()) {
52         std::string opnd1name(strDivOpnd);
53         opnd1name.append(std::to_string(divSTIndex));
54         if (useRegTmp) {
55             PregIdx pregIdx = func->GetPregTab()->CreatePreg(ptype);
56             RegassignNode *regassDivnode = mirBuilder->CreateStmtRegassign(ptype, pregIdx, divOpnd);
57             blknode.AddStatement(regassDivnode);
58             divOpnd = mirBuilder->CreateExprRegread(ptype, pregIdx);
59         } else {
60             MIRSymbol *divOpndSymbol = mirBuilder->CreateSymbol(TyIdx(ptype), opnd1name, kStVar, kScAuto,
61                                                                 GetMIRModule().CurFunction(), kScopeLocal);
62             DassignNode *dssDivNode = mirBuilder->CreateStmtDassign(*divOpndSymbol, 0, divOpnd);
63             blknode.AddStatement(dssDivNode);
64             divOpnd = mirBuilder->CreateExprDread(*divOpndSymbol);
65         }
66         expr.SetBOpnd(divOpnd, 1);
67     }
68     BaseNode *retExprNode = nullptr;
69     StmtNode *divStmt = nullptr;
70     if (useRegTmp) {
71         PregIdx resPregIdx = func->GetPregTab()->CreatePreg(ptype);
72         divStmt = mirBuilder->CreateStmtRegassign(ptype, resPregIdx, &expr);
73         retExprNode = GetMIRModule().GetMIRBuilder()->CreateExprRegread(ptype, resPregIdx);
74     } else {
75         std::string resName(strDivRes);
76         resName.append(std::to_string(divSTIndex++));
77         MIRSymbol *divResSymbol =
78             mirBuilder->CreateSymbol(TyIdx(ptype), resName, kStVar, kScAuto, GetMIRModule().CurFunction(), kScopeLocal);
79         // Put expr result to dssnode.
80         divStmt = mirBuilder->CreateStmtDassign(*divResSymbol, 0, &expr);
81         retExprNode = GetMIRModule().GetMIRBuilder()->CreateExprDread(*divResSymbol, 0);
82     }
83     // Check if the second operand of the div expression is 0.
84     // Inser if statement for high level ir.
85     CompareNode *cmpNode = mirBuilder->CreateExprCompare(OP_eq, *GlobalTables::GetTypeTable().GetInt32(),
86                                                          *GlobalTables::GetTypeTable().GetTypeFromTyIdx((TyIdx)ptype),
87                                                          divOpnd, mirBuilder->CreateIntConst(0, ptype));
88     IfStmtNode *ifStmtNode = mirBuilder->CreateStmtIf(cmpNode);
89     blknode.AddStatement(ifStmtNode);
90     // Call the MCC_ThrowArithmeticException() that will never return.
91     MapleVector<BaseNode *> args(GetMIRModule().GetMIRBuilder()->GetCurrentFuncCodeMpAllocator()->Adapter());
92     IntrinsiccallNode *intrinCallNode = mirBuilder->CreateStmtIntrinsicCall(INTRN_JAVA_THROW_ARITHMETIC, args);
93     ifStmtNode->GetThenPart()->AddStatement(intrinCallNode);
94     blknode.AddStatement(divStmt);
95     // Make dread from the divresst and return it as new expression for this function.
96     return retExprNode;
97 }
98 
DoLowerExpr(BaseNode & expr,BlockNode & curblk)99 BaseNode *JavaEHLowerer::DoLowerExpr(BaseNode &expr, BlockNode &curblk)
100 {
101     for (size_t i = 0; i < expr.NumOpnds(); ++i) {
102         expr.SetOpnd(DoLowerExpr(*(expr.Opnd(i)), curblk), i);
103     }
104     switch (expr.GetOpCode()) {
105         case OP_div: {
106             return DoLowerDiv(static_cast<BinaryNode &>(expr), curblk);
107         }
108         case OP_rem: {
109             return DoLowerRem(static_cast<BinaryNode &>(expr), curblk);
110         }
111         default:
112             return &expr;
113     }
114 }
115 
DoLowerBoundaryCheck(IntrinsiccallNode & intrincall,BlockNode & newblk)116 void JavaEHLowerer::DoLowerBoundaryCheck(IntrinsiccallNode &intrincall, BlockNode &newblk)
117 {
118     const size_t intrincallNopndSize = intrincall.GetNopndSize();
119     CHECK_FATAL(intrincallNopndSize > 0, "null ptr check");
120     CondGotoNode *brFalseStmt = GetMIRModule().CurFuncCodeMemPool()->New<CondGotoNode>(OP_brfalse);
121     brFalseStmt->SetOpnd(DoLowerExpr(*(intrincall.GetNopndAt(0)), newblk), 0);
122     brFalseStmt->SetSrcPos(intrincall.GetSrcPos());
123     LabelIdx lbidx = GetMIRModule().CurFunction()->GetLabelTab()->CreateLabel();
124     GetMIRModule().CurFunction()->GetLabelTab()->AddToStringLabelMap(lbidx);
125     brFalseStmt->SetOffset(lbidx);
126     newblk.AddStatement(brFalseStmt);
127     LabelNode *labStmt = GetMIRModule().CurFuncCodeMemPool()->New<LabelNode>();
128     labStmt->SetLabelIdx(lbidx);
129     MIRFunction *func =
130         GetMIRModule().GetMIRBuilder()->GetOrCreateFunction(strMCCThrowArrayIndexOutOfBoundsException, TyIdx(PTY_void));
131     MapleVector<BaseNode *> args(GetMIRModule().GetMIRBuilder()->GetCurrentFuncCodeMpAllocator()->Adapter());
132     CallNode *callStmt = GetMIRModule().GetMIRBuilder()->CreateStmtCall(func->GetPuidx(), args);
133     newblk.AddStatement(callStmt);
134     newblk.AddStatement(labStmt);
135 }
136 
DoLowerBlock(BlockNode & block)137 BlockNode *JavaEHLowerer::DoLowerBlock(BlockNode &block)
138 {
139     BlockNode *newBlock = GetMIRModule().CurFuncCodeMemPool()->New<BlockNode>();
140     StmtNode *nextStmt = block.GetFirst();
141     if (nextStmt == nullptr) {
142         return newBlock;
143     }
144 
145     do {
146         StmtNode *stmt = nextStmt;
147         nextStmt = stmt->GetNext();
148         stmt->SetNext(nullptr);
149 
150         switch (stmt->GetOpCode()) {
151             case OP_switch: {
152                 auto *switchNode = static_cast<SwitchNode *>(stmt);
153                 switchNode->SetSwitchOpnd(DoLowerExpr(*(switchNode->GetSwitchOpnd()), *newBlock));
154                 newBlock->AddStatement(switchNode);
155                 break;
156             }
157             case OP_if: {
158                 auto *ifStmtNode = static_cast<IfStmtNode *>(stmt);
159                 BlockNode *thenPart = ifStmtNode->GetThenPart();
160                 BlockNode *elsePart = ifStmtNode->GetElsePart();
161                 DEBUG_ASSERT(ifStmtNode->Opnd() != nullptr, "null ptr check!");
162                 ifStmtNode->SetOpnd(DoLowerExpr(*(ifStmtNode->Opnd()), *newBlock), 0);
163                 ifStmtNode->SetThenPart(DoLowerBlock(*thenPart));
164                 if (elsePart != nullptr) {
165                     ifStmtNode->SetElsePart(DoLowerBlock(*elsePart));
166                 }
167                 newBlock->AddStatement(ifStmtNode);
168                 break;
169             }
170             case OP_while:
171             case OP_dowhile: {
172                 auto *whileStmtNode = static_cast<WhileStmtNode *>(stmt);
173                 BaseNode *testOpnd = whileStmtNode->Opnd(0);
174                 whileStmtNode->SetOpnd(DoLowerExpr(*testOpnd, *newBlock), 0);
175                 whileStmtNode->SetBody(DoLowerBlock(*(whileStmtNode->GetBody())));
176                 newBlock->AddStatement(whileStmtNode);
177                 break;
178             }
179             case OP_doloop: {
180                 auto *doLoopNode = static_cast<DoloopNode *>(stmt);
181                 doLoopNode->SetStartExpr(DoLowerExpr(*(doLoopNode->GetStartExpr()), *newBlock));
182                 doLoopNode->SetContExpr(DoLowerExpr(*(doLoopNode->GetCondExpr()), *newBlock));
183                 doLoopNode->SetIncrExpr(DoLowerExpr(*(doLoopNode->GetIncrExpr()), *newBlock));
184                 doLoopNode->SetDoBody(DoLowerBlock(*(doLoopNode->GetDoBody())));
185                 newBlock->AddStatement(doLoopNode);
186                 break;
187             }
188             case OP_block: {
189                 auto *tmp = DoLowerBlock(*(static_cast<BlockNode *>(stmt)));
190                 CHECK_FATAL(tmp != nullptr, "null ptr check");
191                 newBlock->AddStatement(tmp);
192                 break;
193             }
194             case OP_throw: {
195                 auto *tstmt = static_cast<UnaryStmtNode *>(stmt);
196                 BaseNode *opnd0 = DoLowerExpr(*(tstmt->Opnd(0)), *newBlock);
197                 if (opnd0->GetOpCode() == OP_constval) {
198                     CHECK_FATAL(IsPrimitiveInteger(opnd0->GetPrimType()), "must be integer or something wrong");
199                     auto *intConst = safe_cast<MIRIntConst>(static_cast<ConstvalNode *>(opnd0)->GetConstVal());
200                     CHECK_FATAL(intConst->IsZero(), "can only be zero");
201                     MIRFunction *func = GetMIRModule().GetMIRBuilder()->GetOrCreateFunction(
202                         strMCCThrowNullPointerException, TyIdx(PTY_void));
203                     func->SetNoReturn();
204                     MapleVector<BaseNode *> args(
205                         GetMIRModule().GetMIRBuilder()->GetCurrentFuncCodeMpAllocator()->Adapter());
206                     CallNode *callStmt = GetMIRModule().GetMIRBuilder()->CreateStmtCall(func->GetPuidx(), args);
207                     newBlock->AddStatement(callStmt);
208                 } else {
209                     tstmt->SetOpnd(opnd0, 0);
210                     newBlock->AddStatement(tstmt);
211                 }
212                 break;
213             }
214             case OP_intrinsiccall: {
215                 auto *intrinCall = static_cast<IntrinsiccallNode *>(stmt);
216                 if (intrinCall->GetIntrinsic() == INTRN_MPL_BOUNDARY_CHECK) {
217                     DoLowerBoundaryCheck(*intrinCall, *newBlock);
218                     break;
219                 }
220             }
221                 [[clang::fallthrough]];
222             default: {
223                 for (size_t i = 0; i < stmt->NumOpnds(); ++i) {
224                     stmt->SetOpnd(DoLowerExpr(*(stmt->Opnd(i)), *newBlock), i);
225                 }
226                 newBlock->AddStatement(stmt);
227                 break;
228             }
229         }
230     } while (nextStmt != nullptr);
231     return newBlock;
232 }
233 
ProcessFunc(MIRFunction * func)234 void JavaEHLowerer::ProcessFunc(MIRFunction *func)
235 {
236     GetMIRModule().SetCurFunction(func);
237     if (func->GetBody() == nullptr) {
238         return;
239     }
240     divSTIndex = 0;  // Init it to 0.
241     BlockNode *newBody = DoLowerBlock(*(func->GetBody()));
242     func->SetBody(newBody);
243 }
244 
PhaseRun(maple::MIRModule & m)245 bool M2MJavaEHLowerer::PhaseRun(maple::MIRModule &m)
246 {
247     OPT_TEMPLATE_NEWPM(JavaEHLowerer, m);
248     return true;
249 }
250 
GetAnalysisDependence(maple::AnalysisDep & aDep) const251 void M2MJavaEHLowerer::GetAnalysisDependence(maple::AnalysisDep &aDep) const
252 {
253     aDep.AddRequired<M2MKlassHierarchy>();
254     aDep.SetPreservedAll();
255 }
256 }  // namespace maple
257