• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "call_graph.h"
17 
18 #include <algorithm>
19 #include <iostream>
20 #include <queue>
21 #include <unordered_set>
22 #include <fstream>
23 
24 #include "option.h"
25 #include "string_utils.h"
26 #include "maple_phase_manager.h"
27 
//                   Call Graph Analysis
// This phase is a foundation phase of compilation. It builds the call
// graph not only for this module but also for the modules it depends on
// when the phase is running for IPA.
// The main procedure is as follows.
// A. Devirtualize virtual calls of private, final, static and non-static
//    variables. This step aims to reduce the callee set for each call,
//    which benefits IPA analysis.
// B. Build the call graph.
//    i)  For IPA, rebuild the call graphs of all the modules this
//        module depends on. All necessary information is stored in mplt.
//    ii) Analyze each function in this module. For each call statement,
//        create a CGNode and collect the potential callee functions to
//        generate the call graph.
// C. Find all root nodes of the call graph.
// D. Construct SCCs based on Tarjan's algorithm.
// E. Set the compilation order to the bottom-up order of the call graph, so
//    a callee is always compiled before its caller. This benefits the
//    optimizations that need interprocedural information, such as escape
//    analysis.
47 namespace maple {
GetCallTypeName() const48 const std::string CallInfo::GetCallTypeName() const
49 {
50     switch (cType) {
51         case kCallTypeCall:
52             return "c";
53         case kCallTypeVirtualCall:
54             return "v";
55         case kCallTypeSuperCall:
56             return "s";
57         case kCallTypeInterfaceCall:
58             return "i";
59         case kCallTypeIcall:
60             return "icall";
61         case kCallTypeIntrinsicCall:
62             return "intrinsiccall";
63         case kCallTypeXinitrinsicCall:
64             return "xintrinsiccall";
65         case kCallTypeIntrinsicCallWithType:
66             return "intrinsiccallwithtype";
67         case kCallTypeFakeThreadStartRun:
68             return "fakecallstartrun";
69         case kCallTypeCustomCall:
70             return "customcall";
71         case kCallTypePolymorphicCall:
72             return "polymorphiccall";
73         default:
74             CHECK_FATAL(false, "unsupport CALL type");
75             return "";
76     }
77 }
78 
GetCalleeName() const79 const std::string CallInfo::GetCalleeName() const
80 {
81     if ((cType >= kCallTypeCall) && (cType <= kCallTypeInterfaceCall)) {
82         MIRFunction &mirf = *mirFunc;
83         return mirf.GetName();
84     } else if (cType == kCallTypeIcall) {
85         return "IcallUnknown";
86     } else if ((cType >= kCallTypeIntrinsicCall) && (cType <= kCallTypeIntrinsicCallWithType)) {
87         return "IntrinsicCall";
88     } else if (cType == kCallTypeCustomCall) {
89         return "CustomCall";
90     } else if (cType == kCallTypePolymorphicCall) {
91         return "PolymorphicCall";
92     }
93     CHECK_FATAL(false, "should not be here");
94     return "";
95 }
96 
DumpDetail() const97 void CGNode::DumpDetail() const
98 {
99     LogInfo::MapleLogger() << "---CGNode  @" << this << ": " << mirFunc->GetName() << "\t";
100     if (Options::profileUse && mirFunc->GetFuncProfData()) {
101         LogInfo::MapleLogger() << " Ffreq " << GetFuncFrequency() << "\t";
102     }
103     if (HasOneCandidate() != nullptr) {
104         LogInfo::MapleLogger() << "@One Candidate\n";
105     } else {
106         LogInfo::MapleLogger() << std::endl;
107     }
108     if (HasSetVCallCandidates()) {
109         for (uint32 i = 0; i < vcallCands.size(); ++i) {
110             LogInfo::MapleLogger() << "   virtual call candidates: " << vcallCands[i]->GetName() << "\n";
111         }
112     }
113     for (auto &callSite : callees) {
114         for (auto &cgIt : *callSite.second) {
115             CallInfo *ci = callSite.first;
116             CGNode *node = cgIt;
117             MIRFunction *mf = node->GetMIRFunction();
118             if (mf != nullptr) {
119                 LogInfo::MapleLogger() << "\tcallee in module : " << mf->GetName() << "  ";
120             } else {
121                 LogInfo::MapleLogger() << "\tcallee external: " << ci->GetCalleeName();
122             }
123             if (Options::profileUse) {
124                 LogInfo::MapleLogger() << " callsite freq" << GetCallsiteFrequency(ci->GetCallStmt());
125             }
126             LogInfo::MapleLogger() << "\n";
127         }
128     }
129     // dump caller
130     for (auto const &callerNode : GetCaller()) {
131         CHECK_NULL_FATAL(callerNode.first);
132         CHECK_NULL_FATAL(callerNode.first->mirFunc);
133         LogInfo::MapleLogger() << "\tcaller : " << callerNode.first->mirFunc->GetName() << std::endl;
134     }
135 }
136 
Dump(std::ofstream & fout) const137 void CGNode::Dump(std::ofstream &fout) const
138 {
139     // if dumpall == 1, dump whole call graph
140     // else dump callgraph with function defined in same module
141     CHECK_NULL_FATAL(mirFunc);
142     constexpr size_t withoutRingNodeSize = 1;
143     if (callees.empty()) {
144         fout << "\"" << mirFunc->GetName() << "\";\n";
145         return;
146     }
147     for (auto &callSite : callees) {
148         for (auto &cgIt : *callSite.second) {
149             CallInfo *ci = callSite.first;
150             CGNode *node = cgIt;
151             if (node == nullptr) {
152                 continue;
153             }
154             MIRFunction *func = node->GetMIRFunction();
155             fout << "\"" << mirFunc->GetName() << "\" -> ";
156             if (func != nullptr) {
157                 if (node->GetSCCNode() != nullptr && node->GetSCCNode()->GetNodes().size() > withoutRingNodeSize) {
158                     fout << "\"" << func->GetName() << "\"[label=" << node->GetSCCNode()->GetID() << " color=red];\n";
159                 } else {
160                     fout << "\"" << func->GetName() << "\"[label=" << 0 << " color=blue];\n";
161                 }
162             } else {
163                 // unknown / external function with empty function body
164                 fout << "\"" << ci->GetCalleeName() << "\"[label=" << ci->GetCallTypeName() << " color=blue];\n";
165             }
166         }
167     }
168 }
169 
AddCallsite(CallInfo * ci,MapleSet<CGNode *,Comparator<CGNode>> * callee)170 void CGNode::AddCallsite(CallInfo *ci, MapleSet<CGNode *, Comparator<CGNode>> *callee)
171 {
172     (void)callees.insert(std::pair<CallInfo *, MapleSet<CGNode *, Comparator<CGNode>> *>(ci, callee));
173 }
174 
AddCallsite(CallInfo & ci,CGNode * node)175 void CGNode::AddCallsite(CallInfo &ci, CGNode *node)
176 {
177     CHECK_FATAL(ci.GetCallType() != kCallTypeInterfaceCall, "must be true");
178     CHECK_FATAL(ci.GetCallType() != kCallTypeVirtualCall, "must be true");
179     auto *cgVector = alloc->GetMemPool()->New<MapleSet<CGNode *, Comparator<CGNode>>>(alloc->Adapter());
180     if (node != nullptr) {
181         node->AddNumRefs();
182         cgVector->insert(node);
183     }
184     (void)callees.emplace(&ci, cgVector);
185 }
186 
RemoveCallsite(const CallInfo * ci,CGNode * node)187 void CGNode::RemoveCallsite(const CallInfo *ci, CGNode *node)
188 {
189     for (Callsite callSite : GetCallee()) {
190         if (callSite.first == ci) {
191             auto cgIt = callSite.second->find(node);
192             if (cgIt != callSite.second->end()) {
193                 callSite.second->erase(cgIt);
194                 return;
195             }
196             CHECK_FATAL(false, "node isn't in ci");
197         }
198     }
199 }
200 
IsCalleeOf(CGNode * func) const201 bool CGNode::IsCalleeOf(CGNode *func) const
202 {
203     return callers.find(func) != callers.end();
204 }
205 
GetCallsiteFrequency(const StmtNode * callstmt) const206 int64_t CGNode::GetCallsiteFrequency(const StmtNode *callstmt) const
207 {
208     GcovFuncInfo *funcInfo = mirFunc->GetFuncProfData();
209     if (funcInfo->stmtFreqs.count(callstmt->GetStmtID()) > 0) {
210         return funcInfo->stmtFreqs[callstmt->GetStmtID()];
211     }
212     DEBUG_ASSERT(0, "should not be here");
213     return -1;
214 }
215 
GetFuncFrequency() const216 int64_t CGNode::GetFuncFrequency() const
217 {
218     GcovFuncInfo *funcInfo = mirFunc->GetFuncProfData();
219     if (funcInfo) {
220         return funcInfo->GetFuncFrequency();
221     }
222     DEBUG_ASSERT(0, "should not be here");
223     return 0;
224 }
225 
ClearFunctionList()226 void CallGraph::ClearFunctionList()
227 {
228     for (auto iter = mirModule->GetFunctionList().begin(); iter != mirModule->GetFunctionList().end();) {
229         if (GlobalTables::GetFunctionTable().GetFunctionFromPuidx((*iter)->GetPuidx()) == nullptr) {
230             (*iter)->ReleaseCodeMemory();
231             (*iter)->ReleaseMemory();
232             iter = mirModule->GetFunctionList().erase(iter);
233         } else {
234             ++iter;
235         }
236     }
237 }
238 
DelNode(CGNode & node)239 void CallGraph::DelNode(CGNode &node)
240 {
241     if (node.GetMIRFunction() == nullptr) {
242         return;
243     }
244     for (auto &callSite : node.GetCallee()) {
245         for (auto &cgIt : *callSite.second) {
246             cgIt->DelCaller(&node);
247             node.DelCallee(callSite.first, cgIt);
248             if (!cgIt->HasCaller() && cgIt->GetMIRFunction()->IsStatic() && !cgIt->IsAddrTaken()) {
249                 DelNode(*cgIt);
250             }
251         }
252     }
253     MIRFunction *func = node.GetMIRFunction();
254     // Delete the method of class info
255     if (func->GetClassTyIdx() != 0u) {
256         MIRType *classType = GlobalTables::GetTypeTable().GetTypeTable().at(func->GetClassTyIdx());
257         auto *mirStructType = static_cast<MIRStructType *>(classType);
258         size_t j = 0;
259         for (; j < mirStructType->GetMethods().size(); ++j) {
260             if (mirStructType->GetMethodsElement(j).first == func->GetStIdx()) {
261                 mirStructType->GetMethods().erase(mirStructType->GetMethods().begin() + static_cast<ssize_t>(j));
262                 break;
263             }
264         }
265     }
266     GlobalTables::GetFunctionTable().SetFunctionItem(func->GetPuidx(), nullptr);
267     // func will be erased, so the coressponding symbol should be set as Deleted
268     func->GetFuncSymbol()->SetIsDeleted();
269     nodesMap.erase(func);
270     // Update Klass info as it has been built
271     if (klassh->GetKlassFromFunc(func) != nullptr) {
272         klassh->GetKlassFromFunc(func)->DelMethod(*func);
273     }
274 }
275 
CallGraph(MIRModule & m,MemPool & memPool,MemPool & templPool,KlassHierarchy & kh,const std::string & fn)276 CallGraph::CallGraph(MIRModule &m, MemPool &memPool, MemPool &templPool, KlassHierarchy &kh, const std::string &fn)
277     : AnalysisResult(&memPool),
278       mirModule(&m),
279       cgAlloc(&memPool),
280       tempAlloc(&templPool),
281       mirBuilder(cgAlloc.GetMemPool()->New<MIRBuilder>(&m)),
282       entryNode(nullptr),
283       rootNodes(cgAlloc.Adapter()),
284       fileName(fn, &memPool),
285       klassh(&kh),
286       nodesMap(cgAlloc.Adapter()),
287       sccTopologicalVec(cgAlloc.Adapter()),
288       localConstValueMap(tempAlloc.Adapter()),
289       icallToFix(tempAlloc.Adapter()),
290       addressTakenPuidxs(tempAlloc.Adapter()),
291       numOfNodes(0)
292 {
293 }
294 
GetCallType(Opcode op) const295 CallType CallGraph::GetCallType(Opcode op) const
296 {
297     CallType typeTemp = kCallTypeInvalid;
298     switch (op) {
299         case OP_call:
300         case OP_callassigned:
301             typeTemp = kCallTypeCall;
302             break;
303         case OP_virtualcall:
304         case OP_virtualcallassigned:
305             typeTemp = kCallTypeVirtualCall;
306             break;
307         case OP_superclasscall:
308         case OP_superclasscallassigned:
309             typeTemp = kCallTypeSuperCall;
310             break;
311         case OP_interfacecall:
312         case OP_interfacecallassigned:
313             typeTemp = kCallTypeInterfaceCall;
314             break;
315         case OP_icall:
316         case OP_icallassigned:
317             typeTemp = kCallTypeIcall;
318             break;
319         case OP_intrinsiccall:
320         case OP_intrinsiccallassigned:
321             typeTemp = kCallTypeIntrinsicCall;
322             break;
323         case OP_xintrinsiccall:
324         case OP_xintrinsiccallassigned:
325             typeTemp = kCallTypeXinitrinsicCall;
326             break;
327         case OP_intrinsiccallwithtype:
328         case OP_intrinsiccallwithtypeassigned:
329             typeTemp = kCallTypeIntrinsicCallWithType;
330             break;
331         case OP_customcall:
332         case OP_customcallassigned:
333             typeTemp = kCallTypeCustomCall;
334             break;
335         case OP_polymorphiccall:
336         case OP_polymorphiccallassigned:
337             typeTemp = kCallTypePolymorphicCall;
338             break;
339         default:
340             break;
341     }
342     return typeTemp;
343 }
344 
GetCGNode(MIRFunction * func) const345 CGNode *CallGraph::GetCGNode(MIRFunction *func) const
346 {
347     if (nodesMap.find(func) != nodesMap.end()) {
348         return nodesMap.at(func);
349     }
350     return nullptr;
351 }
352 
GetCGNode(PUIdx puIdx) const353 CGNode *CallGraph::GetCGNode(PUIdx puIdx) const
354 {
355     return GetCGNode(GlobalTables::GetFunctionTable().GetFunctionFromPuidx(puIdx));
356 }
357 
GetSCCNode(MIRFunction * func) const358 SCCNode<CGNode> *CallGraph::GetSCCNode(MIRFunction *func) const
359 {
360     CGNode *cgNode = GetCGNode(func);
361     return (cgNode != nullptr) ? cgNode->GetSCCNode() : nullptr;
362 }
363 
UpdateCaleeCandidate(PUIdx callerPuIdx,IcallNode * icall,std::set<PUIdx> & candidate)364 void CallGraph::UpdateCaleeCandidate(PUIdx callerPuIdx, IcallNode *icall, std::set<PUIdx> &candidate)
365 {
366     CGNode *caller = GetCGNode(callerPuIdx);
367     for (auto &pair : caller->GetCallee()) {
368         auto *callsite = pair.first;
369         if (callsite->GetCallStmt() == icall) {
370             auto *calleeSet = pair.second;
371             calleeSet->clear();
372             std::for_each(candidate.begin(), candidate.end(), [this, &calleeSet](PUIdx idx) {
373                 CGNode *callee = GetCGNode(idx);
374                 calleeSet->insert(callee);
375             });
376             return;
377         }
378     }
379 }
380 
UpdateCaleeCandidate(PUIdx callerPuIdx,IcallNode * icall,PUIdx calleePuidx,CallNode * call)381 void CallGraph::UpdateCaleeCandidate(PUIdx callerPuIdx, IcallNode *icall, PUIdx calleePuidx, CallNode *call)
382 {
383     CGNode *caller = GetCGNode(callerPuIdx);
384     for (auto &pair : caller->GetCallee()) {
385         auto *callsite = pair.first;
386         if (callsite->GetCallStmt() == icall) {
387             callsite->SetCallStmt(call);
388             auto *calleeSet = pair.second;
389             calleeSet->clear();
390             calleeSet->insert(GetCGNode(calleePuidx));
391         }
392     }
393 }
394 
IsRootNode(MIRFunction * func) const395 bool CallGraph::IsRootNode(MIRFunction *func) const
396 {
397     if (GetCGNode(func) != nullptr) {
398         return (!GetCGNode(func)->HasCaller());
399     } else {
400         return false;
401     }
402 }
403 
GetOrGenCGNode(PUIdx puIdx,bool isVcall,bool isIcall)404 CGNode *CallGraph::GetOrGenCGNode(PUIdx puIdx, bool isVcall, bool isIcall)
405 {
406     CGNode *node = GetCGNode(puIdx);
407     if (node == nullptr) {
408         MIRFunction *mirFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(puIdx);
409         node = cgAlloc.GetMemPool()->New<CGNode>(mirFunc, cgAlloc, numOfNodes++);
410         (void)nodesMap.insert(std::make_pair(mirFunc, node));
411     }
412     if (isVcall && !node->IsVcallCandidatesValid()) {
413         MIRFunction *mirFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(puIdx);
414         Klass *klass = nullptr;
415         CHECK_NULL_FATAL(mirFunc);
416         if (StringUtils::StartsWith(mirFunc->GetBaseClassName(), JARRAY_PREFIX_STR)) {  // Array
417             klass = klassh->GetKlassFromName(namemangler::kJavaLangObjectStr);
418         } else {
419             klass = klassh->GetKlassFromStrIdx(mirFunc->GetBaseClassNameStrIdx());
420         }
421         if (klass == nullptr) {  // Incomplete
422             node->SetVcallCandidatesValid();
423             return node;
424         }
425         // Traverse all subclasses
426         std::vector<Klass *> klassVector;
427         klassVector.push_back(klass);
428         GStrIdx calleeFuncStrIdx = mirFunc->GetBaseFuncNameWithTypeStrIdx();
429         for (Klass *currKlass : klassVector) {
430             const MIRFunction *method = currKlass->GetMethod(calleeFuncStrIdx);
431             if (method != nullptr) {
432                 node->AddVcallCandidate(GetOrGenCGNode(method->GetPuidx()));
433             }
434             // add subclass of currKlass into vector
435             for (Klass *subKlass : currKlass->GetSubKlasses()) {
436                 klassVector.push_back(subKlass);
437             }
438         }
439         if (klass->IsClass() && !klass->GetMIRClassType()->IsAbstract()) {
440             // If klass.foo does not exist, search superclass and find the nearest one
441             // klass.foo does not exist
442             auto &klassMethods = klass->GetMethods();
443             if (std::find(klassMethods.begin(), klassMethods.end(), mirFunc) == klassMethods.end()) {
444                 Klass *superKlass = klass->GetSuperKlass();
445                 while (superKlass != nullptr) {
446                     const MIRFunction *method = superKlass->GetMethod(calleeFuncStrIdx);
447                     if (method != nullptr) {
448                         node->AddVcallCandidate(GetOrGenCGNode(method->GetPuidx()));
449                         break;
450                     }
451                     superKlass = superKlass->GetSuperKlass();
452                 }
453             }
454         }
455         node->SetVcallCandidatesValid();
456     }
457     if (isIcall && !node->IsIcallCandidatesValid()) {
458         Klass *CallerKlass = nullptr;
459         if (StringUtils::StartsWith(CurFunction()->GetBaseClassName(), JARRAY_PREFIX_STR)) {  // Array
460             CallerKlass = klassh->GetKlassFromName(namemangler::kJavaLangObjectStr);
461         } else {
462             CallerKlass = klassh->GetKlassFromStrIdx(CurFunction()->GetBaseClassNameStrIdx());
463         }
464         if (CallerKlass == nullptr) {  // Incomplete
465             CHECK_FATAL(false, "class is incomplete, impossible.");
466             return node;
467         }
468         MIRFunction *mirFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(puIdx);
469         Klass *klass = nullptr;
470         if (StringUtils::StartsWith(mirFunc->GetBaseClassName(), JARRAY_PREFIX_STR)) {  // Array
471             klass = klassh->GetKlassFromName(namemangler::kJavaLangObjectStr);
472         } else {
473             klass = klassh->GetKlassFromStrIdx(mirFunc->GetBaseClassNameStrIdx());
474         }
475         if (klass == nullptr) {  // Incomplete
476             node->SetIcallCandidatesValid();
477             return node;
478         }
479         GStrIdx calleeFuncStrIdx = mirFunc->GetBaseFuncNameWithTypeStrIdx();
480         // Traverse all classes which implement the interface
481         for (Klass *implKlass : klass->GetImplKlasses()) {
482             const MIRFunction *method = implKlass->GetMethod(calleeFuncStrIdx);
483             if (method != nullptr) {
484                 node->AddIcallCandidate(GetOrGenCGNode(method->GetPuidx()));
485             } else if (!implKlass->GetMIRClassType()->IsAbstract()) {
486                 // Search in its parent class
487                 Klass *superKlass = implKlass->GetSuperKlass();
488                 while (superKlass != nullptr) {
489                     const MIRFunction *methodT = superKlass->GetMethod(calleeFuncStrIdx);
490                     if (methodT != nullptr) {
491                         node->AddIcallCandidate(GetOrGenCGNode(methodT->GetPuidx()));
492                         break;
493                     }
494                     superKlass = superKlass->GetSuperKlass();
495                 }
496             }
497         }
498         node->SetIcallCandidatesValid();
499     }
500     return node;
501 }
502 
503 // if expr has addroffunc expr as its opnd, store all the addroffunc puidx into result
CollectAddroffuncFromExpr(const BaseNode * expr)504 void CallGraph::CollectAddroffuncFromExpr(const BaseNode *expr)
505 {
506     if (expr->GetOpCode() == OP_addroffunc) {
507         addressTakenPuidxs.insert(static_cast<const AddroffuncNode *>(expr)->GetPUIdx());
508         return;
509     }
510     for (size_t i = 0; i < expr->GetNumOpnds(); ++i) {
511         CollectAddroffuncFromExpr(expr->Opnd(i));
512     }
513 }
514 
CollectAddroffuncFromStmt(const StmtNode * stmt)515 void CallGraph::CollectAddroffuncFromStmt(const StmtNode *stmt)
516 {
517     for (size_t i = 0; i < stmt->NumOpnds(); ++i) {
518         CollectAddroffuncFromExpr(stmt->Opnd(i));
519     }
520 }
521 
CollectAddroffuncFromConst(MIRConst * mirConst)522 void CallGraph::CollectAddroffuncFromConst(MIRConst *mirConst)
523 {
524     if (mirConst->GetKind() == kConstAddrofFunc) {
525         addressTakenPuidxs.insert(static_cast<MIRAddroffuncConst *>(mirConst)->GetValue());
526     } else if (mirConst->GetKind() == kConstAggConst) {
527         auto &constVec = static_cast<MIRAggConst *>(mirConst)->GetConstVec();
528         for (auto &cst : constVec) {
529             CollectAddroffuncFromConst(cst);
530         }
531     }
532 }
533 
RecordLocalConstValue(const StmtNode * stmt)534 void CallGraph::RecordLocalConstValue(const StmtNode *stmt)
535 {
536     if (stmt->GetOpCode() != OP_dassign) {
537         return;
538     }
539     auto *dassign = static_cast<const DassignNode *>(stmt);
540     MIRSymbol *lhs = CurFunction()->GetLocalOrGlobalSymbol(dassign->GetStIdx());
541     if (!lhs->IsLocal() || !lhs->GetAttr(ATTR_const) || dassign->GetFieldID() != 0) {
542         return;
543     }
544     if (localConstValueMap.find(lhs->GetStIdx()) != localConstValueMap.end()) {
545         // Multi def found, put nullptr to indicate that we cannot handle this.
546         localConstValueMap[lhs->GetStIdx()] = nullptr;
547         return;
548     }
549     localConstValueMap[lhs->GetStIdx()] = dassign->GetRHS();
550 }
551 
ReplaceIcallToCall(BlockNode & body,IcallNode * icall,PUIdx newPUIdx)552 CallNode *CallGraph::ReplaceIcallToCall(BlockNode &body, IcallNode *icall, PUIdx newPUIdx)
553 {
554     MapleVector<BaseNode *> opnds(icall->GetNopnd().begin() + 1, icall->GetNopnd().end(),
555                                   CurFunction()->GetCodeMPAllocator().Adapter());
556     CallNode *newCall = nullptr;
557     if (icall->GetOpCode() == OP_icall) {
558         newCall = mirBuilder->CreateStmtCall(newPUIdx, opnds, OP_call);
559     } else if (icall->GetOpCode() == OP_icallassigned) {
560         newCall =
561             mirBuilder->CreateStmtCallAssigned(newPUIdx, opnds, icall->GetCallReturnSymbol(mirBuilder->GetMirModule()),
562                                                OP_callassigned, icall->GetRetTyIdx());
563     } else {
564         CHECK_FATAL(false, "NYI");
565     }
566     body.ReplaceStmt1WithStmt2(icall, newCall);
567     newCall->SetSrcPos(icall->GetSrcPos());
568     if (debugFlag) {
569         icall->Dump(0);
570         newCall->Dump(0);
571         LogInfo::MapleLogger() << "replace icall successfully!\n";
572     }
573     return newCall;
574 }
575 
HandleCall(CGNode & node,StmtNode * stmt,uint32 loopDepth)576 void CallGraph::HandleCall(CGNode &node, StmtNode *stmt, uint32 loopDepth)
577 {
578     PUIdx calleePUIdx = (static_cast<CallNode *>(stmt))->GetPUIdx();
579     MIRFunction *calleeFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(calleePUIdx);
580     // Ignore clinit
581     if (!calleeFunc->IsClinit()) {
582         CallInfo *callInfo = GenCallInfo(kCallTypeCall, calleeFunc, stmt, loopDepth, stmt->GetStmtID());
583         CGNode *calleeNode = GetOrGenCGNode(calleeFunc->GetPuidx());
584         DEBUG_ASSERT(calleeNode != nullptr, "calleenode is null");
585         calleeNode->AddCaller(&node, stmt);
586         node.AddCallsite(*callInfo, calleeNode);
587     }
588 }
589 
GetFuncTypeFromFuncAddr(const BaseNode * base)590 MIRType *CallGraph::GetFuncTypeFromFuncAddr(const BaseNode *base)
591 {
592     MIRType *funcType = nullptr;
593     switch (base->GetOpCode()) {
594         case OP_dread: {
595             auto *dread = static_cast<const DreadNode *>(base);
596             const MIRSymbol *st = CurFunction()->GetLocalOrGlobalSymbol(dread->GetStIdx());
597             funcType = st->GetType();
598             if (funcType->IsStructType()) {
599                 funcType = static_cast<MIRStructType *>(funcType)->GetFieldType(dread->GetFieldID());
600             }
601             break;
602         }
603         case OP_iread: {
604             auto *iread = static_cast<const IreadNode *>(base);
605             funcType = iread->GetType();
606             break;
607         }
608         case OP_select: {
609             auto *select = static_cast<const TernaryNode *>(base);
610             funcType = GetFuncTypeFromFuncAddr(select->Opnd(kSecondOpnd));
611             if (funcType == nullptr) {
612                 funcType = GetFuncTypeFromFuncAddr(select->Opnd(kThirdOpnd));
613             }
614             break;
615         }
616         case OP_addroffunc: {
617             auto *funcNode = static_cast<const AddroffuncNode *>(base);
618             auto *func = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(funcNode->GetPUIdx());
619             funcType = func->GetMIRFuncType();
620             break;
621         }
622         default: {
623             CHECK_FATAL(false, "NYI");
624             break;
625         }
626     }
627     CHECK_FATAL(funcType != nullptr, "Error");
628     return funcType;
629 }
630 
HandleICall(BlockNode & body,CGNode & node,StmtNode * stmt,uint32 loopDepth)631 void CallGraph::HandleICall(BlockNode &body, CGNode &node, StmtNode *stmt, uint32 loopDepth)
632 {
633     IcallNode *icall = static_cast<IcallNode *>(stmt);
634     auto *funcAddr = icall->GetNopndAt(0);
635     MIRType *funcType = nullptr;
636     CHECK_FATAL(IsPrimitivePoint(funcAddr->GetPrimType()), "Error");
637     switch (funcAddr->GetOpCode()) {
638         case OP_dread: {
639             auto *dread = static_cast<DreadNode *>(funcAddr);
640             MIRSymbol *symbol = CurFunction()->GetLocalOrGlobalSymbol(dread->GetStIdx());
641             funcType = symbol->GetType();
642             if (funcType->IsStructType()) {
643                 funcType = static_cast<MIRStructType *>(funcType)->GetFieldType(dread->GetFieldID());
644             }
645             if (symbol->IsGlobal()) {
646                 // Global symbol
647                 if (!symbol->GetAttr(ATTR_const) || !symbol->IsConst()) {
648                     break;
649                 }
650                 if (symbol->GetKonst()->GetKind() == kConstAddrofFunc) {
651                     auto *addrofFuncConst = static_cast<MIRAddroffuncConst *>(symbol->GetKonst());
652                     stmt = ReplaceIcallToCall(body, icall, addrofFuncConst->GetValue());
653                     HandleCall(node, stmt, loopDepth);
654                     return;
655                 }
656                 if (symbol->GetKonst()->GetKind() == kConstAggConst) {
657                     auto *aggConst = static_cast<MIRAggConst *>(symbol->GetKonst());
658                     auto *elem = aggConst->GetAggConstElement(dread->GetFieldID());
659                     if (elem->GetKind() == kConstAddrofFunc) {
660                         auto *addrofFuncConst = static_cast<MIRAddroffuncConst *>(elem);
661                         stmt = ReplaceIcallToCall(body, icall, addrofFuncConst->GetValue());
662                         HandleCall(node, stmt, loopDepth);
663                         return;
664                     }
665                 }
666             } else {
667                 // Local symbol
668                 if (!symbol->GetAttr(ATTR_const)) {
669                     break;
670                 }
671                 if (localConstValueMap.find(symbol->GetStIdx()) != localConstValueMap.end()) {
672                     auto *rhsNode = localConstValueMap[symbol->GetStIdx()];
673                     if (rhsNode != nullptr && rhsNode->GetOpCode() == OP_addroffunc) {
674                         auto *funcNode = static_cast<AddroffuncNode *>(rhsNode);
675                         stmt = ReplaceIcallToCall(body, icall, funcNode->GetPUIdx());
676                         HandleCall(node, stmt, loopDepth);
677                         return;
678                     }
679                 }
680             }
681             break;
682         }
683         case OP_iread: {
684             auto *iread = static_cast<IreadNode *>(funcAddr);
685             funcType = iread->GetType();
686             if (iread->Opnd(0)->GetOpCode() != OP_array) {
687                 break;
688             }
689             auto *arrayNode = static_cast<ArrayNode *>(iread->Opnd(0));
690             if (arrayNode->GetBase()->GetOpCode() != OP_addrof) {
691                 break;
692             }
693             bool hasVarIndex = false;
694             for (size_t i = 1; i < arrayNode->numOpnds; ++i) {
695                 if (!arrayNode->GetNopndAt(i)->IsConstval() ||
696                     static_cast<ConstvalNode *>(arrayNode->GetNopndAt(i))->GetConstVal()->GetKind() != kConstInt) {
697                     hasVarIndex = true;
698                     break;
699                 }
700             }
701             if (hasVarIndex) {
702                 break;
703             }
704             MIRSymbol *symbol =
705                 CurFunction()->GetLocalOrGlobalSymbol(static_cast<AddrofNode *>(arrayNode->GetBase())->GetStIdx());
706             if (symbol->IsGlobal()) {
707                 // Global array.
708                 if (!symbol->GetAttr(ATTR_const) || !symbol->IsConst()) {
709                     break;
710                 }
711                 // Solve multi-dim array.
712                 if (symbol->GetKonst()->GetKind() == kConstAggConst) {
713                     auto *aggConst = static_cast<MIRAggConst *>(symbol->GetKonst());
714                     MIRConst *result = aggConst;
715                     for (size_t i = 1; i < arrayNode->GetNumOpnds(); ++i) {
716                         auto *konst = static_cast<ConstvalNode *>(arrayNode->GetNopndAt(i))->GetConstVal();
717                         auto index = static_cast<MIRIntConst *>(konst)->GetExtValue();
718                         if (result->GetKind() == kConstAggConst) {
719                             result = static_cast<MIRAggConst *>(result)->GetConstVecItem(index);
720                         }
721                     }
722                     CHECK_FATAL(result->GetKind() == kConstAddrofFunc, "Must be");
723                     auto *constValue = static_cast<MIRAddroffuncConst *>(result);
724                     stmt = ReplaceIcallToCall(body, icall, constValue->GetValue());
725                     HandleCall(node, stmt, loopDepth);
726                     return;
727                 }
728             }
729             break;
730         }
731         case OP_select: {
732             auto *select = static_cast<const TernaryNode *>(funcAddr);
733             auto *leftValue = select->Opnd(1);
734             auto *rightValue = select->Opnd(2);
735             if (leftValue->GetOpCode() == OP_addroffunc && rightValue->GetOpCode() == OP_addroffunc) {
736                 auto *funcNode1 = static_cast<AddroffuncNode *>(leftValue);
737                 auto *funcNode2 = static_cast<AddroffuncNode *>(rightValue);
738                 CallInfo *callInfo = GenCallInfo(kCallTypeIcall, nullptr, stmt, loopDepth, stmt->GetStmtID());
739                 CGNode *calleeNode1 = GetOrGenCGNode(funcNode1->GetPUIdx());
740                 CGNode *calleeNode2 = GetOrGenCGNode(funcNode2->GetPUIdx());
741                 CHECK_FATAL(calleeNode1 != nullptr && calleeNode2 != nullptr, "calleenode is null");
742                 auto *cgVector = cgAlloc.GetMemPool()->New<MapleSet<CGNode *, Comparator<CGNode>>>(cgAlloc.Adapter());
743                 cgVector->insert(calleeNode1);
744                 cgVector->insert(calleeNode2);
745                 calleeNode1->AddCaller(&node, stmt);
746                 calleeNode2->AddCaller(&node, stmt);
747                 node.AddCallsite(callInfo, cgVector);
748                 return;
749             } else {
750                 funcType = GetFuncTypeFromFuncAddr(funcAddr);
751             }
752             break;
753         }
754         case OP_addroffunc: {
755             auto *funcNode = static_cast<AddroffuncNode *>(funcAddr);
756             stmt = ReplaceIcallToCall(body, icall, funcNode->GetPUIdx());
757             HandleCall(node, stmt, loopDepth);
758             return;
759         }
760         default: {
761             break;
762         }
763     }
764     if (!Options::wpaa) {
765         // Do not handle icall if Options::wpaa is false.
766         return;
767     }
768     CHECK_FATAL(funcType != nullptr, "Failed to get the function type.");
769     while (funcType != nullptr && funcType->IsMIRPtrType()) {
770         funcType = static_cast<MIRPtrType *>(funcType)->GetPointedType();
771     }
772 
773     // Add a fake callsite here, need to fix it after all the function is visited.
774     CallInfo *callInfo = GenCallInfo(kCallTypeIcall, nullptr, stmt, loopDepth, stmt->GetStmtID());
775     node.AddCallsite(*callInfo, nullptr);
776     if (icallToFix.find(funcType->GetTypeIndex()) == icallToFix.end()) {
777         auto *tempSet = tempAlloc.GetMemPool()->New<MapleSet<Caller2Cands>>(tempAlloc.Adapter());
778         icallToFix.insert({funcType->GetTypeIndex(), tempSet});
779     }
780     CHECK_FATAL(CurFunction()->GetPuidx() == node.GetPuIdx(), "Error");
781     Callsite callSite = {callInfo, node.GetCallee().at(callInfo)};
782     icallToFix.at(funcType->GetTypeIndex())->insert({node.GetPuIdx(), callSite});
783 }
784 
HandleBody(MIRFunction & func,BlockNode & body,CGNode & node,uint32 loopDepth)785 void CallGraph::HandleBody(MIRFunction &func, BlockNode &body, CGNode &node, uint32 loopDepth)
786 {
787     StmtNode *stmtNext = nullptr;
788     for (StmtNode *stmt = body.GetFirst(); stmt != nullptr; stmt = stmtNext) {
789         CollectAddroffuncFromStmt(stmt);
790         RecordLocalConstValue(stmt);
791         stmtNext = static_cast<StmtNode *>(stmt)->GetNext();
792         Opcode op = stmt->GetOpCode();
793         if (op == OP_comment) {
794             continue;
795         } else if (op == OP_doloop) {
796             DoloopNode *doloopNode = static_cast<DoloopNode *>(stmt);
797             HandleBody(func, *doloopNode->GetDoBody(), node, loopDepth + 1);
798         } else if (op == OP_dowhile || op == OP_while) {
799             WhileStmtNode *whileStmt = static_cast<WhileStmtNode *>(stmt);
800             HandleBody(func, *whileStmt->GetBody(), node, loopDepth + 1);
801         } else if (op == OP_if) {
802             IfStmtNode *n = static_cast<IfStmtNode *>(stmt);
803             HandleBody(func, *n->GetThenPart(), node, loopDepth);
804             if (n->GetElsePart() != nullptr) {
805                 HandleBody(func, *n->GetElsePart(), node, loopDepth);
806             }
807         } else {
808             node.IncrStmtCount();
809             CallType ct = GetCallType(op);
810             switch (ct) {
811                 case kCallTypeVirtualCall: {
812                     PUIdx calleePUIdx = (static_cast<CallNode *>(stmt))->GetPUIdx();
813                     MIRFunction *calleeFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(calleePUIdx);
814                     CallInfo *callInfo =
815                         GenCallInfo(kCallTypeVirtualCall, calleeFunc, stmt, loopDepth, stmt->GetStmtID());
816                     // Retype makes object type more inaccurate.
817                     StmtNode *stmtPrev = static_cast<StmtNode *>(stmt)->GetPrev();
818                     if (stmtPrev != nullptr && stmtPrev->GetOpCode() == OP_dassign) {
819                         DassignNode *dassignNode = static_cast<DassignNode *>(stmtPrev);
820                         if (dassignNode->GetRHS()->GetOpCode() == OP_retype) {
821                             CallNode *callNode = static_cast<CallNode *>(stmt);
822                             CHECK_FATAL(callNode->Opnd(0)->GetOpCode() == OP_dread, "Must be dread.");
823                             AddrofNode *dread = static_cast<AddrofNode *>(callNode->Opnd(0));
824                             if (dassignNode->GetStIdx() == dread->GetStIdx()) {
825                                 RetypeNode *retypeNode = static_cast<RetypeNode *>(dassignNode->GetRHS());
826                                 CHECK_FATAL(retypeNode->Opnd(0)->GetOpCode() == OP_dread, "Must be dread.");
827                                 AddrofNode *dreadT = static_cast<AddrofNode *>(retypeNode->Opnd(0));
828                                 MIRType *type = func.GetLocalOrGlobalSymbol(dreadT->GetStIdx())->GetType();
829                                 CHECK_FATAL(type->IsMIRPtrType(), "Must be ptr type.");
830                                 MIRPtrType *ptrType = static_cast<MIRPtrType *>(type);
831                                 MIRType *targetType = ptrType->GetPointedType();
832                                 MIRFunction *calleeFuncT =
833                                     GlobalTables::GetFunctionTable().GetFunctionFromPuidx(calleePUIdx);
834                                 GStrIdx calleeFuncStrIdx = calleeFuncT->GetBaseFuncNameWithTypeStrIdx();
835                                 Klass *klass = klassh->GetKlassFromTyIdx(targetType->GetTypeIndex());
836                                 if (klass != nullptr) {
837                                     const MIRFunction *method = klass->GetMethod(calleeFuncStrIdx);
838                                     if (method != nullptr) {
839                                         calleePUIdx = method->GetPuidx();
840                                     } else {
841                                         std::string funcName = klass->GetKlassName();
842                                         funcName.append((namemangler::kNameSplitterStr));
843                                         funcName.append(calleeFuncT->GetBaseFuncNameWithType());
844                                         MIRFunction *methodT =
845                                             mirBuilder->GetOrCreateFunction(funcName, static_cast<TyIdx>(PTY_void));
846                                         methodT->SetBaseClassNameStrIdx(klass->GetKlassNameStrIdx());
847                                         methodT->SetBaseFuncNameWithTypeStrIdx(calleeFuncStrIdx);
848                                         calleePUIdx = methodT->GetPuidx();
849                                     }
850                                 }
851                             }
852                         }
853                     }
854                     // Add a call node whether or not the calleeFunc has its body
855                     CGNode *calleeNode = GetOrGenCGNode(calleePUIdx, true);
856                     CHECK_FATAL(calleeNode != nullptr, "calleenode is null");
857                     CHECK_FATAL(calleeNode->IsVcallCandidatesValid(), "vcall candidate must be valid");
858                     node.AddCallsite(callInfo, &calleeNode->GetVcallCandidates());
859                     for (auto &cgIt : calleeNode->GetVcallCandidates()) {
860                         CGNode *calleeNodeT = cgIt;
861                         calleeNodeT->AddCaller(&node, stmt);
862                     }
863                     break;
864                 }
865                 case kCallTypeInterfaceCall: {
866                     PUIdx calleePUIdx = (static_cast<CallNode *>(stmt))->GetPUIdx();
867                     MIRFunction *calleeFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(calleePUIdx);
868                     CallInfo *callInfo =
869                         GenCallInfo(kCallTypeInterfaceCall, calleeFunc, stmt, loopDepth, stmt->GetStmtID());
870                     // Add a call node whether or not the calleeFunc has its body
871                     CGNode *calleeNode = GetOrGenCGNode(calleeFunc->GetPuidx(), false, true);
872                     CHECK_FATAL(calleeNode != nullptr, "calleenode is null");
873                     CHECK_FATAL(calleeNode->IsIcallCandidatesValid(), "icall candidate must be valid");
874                     node.AddCallsite(callInfo, &calleeNode->GetIcallCandidates());
875                     for (auto &cgIt : calleeNode->GetIcallCandidates()) {
876                         CGNode *calleeNodeT = cgIt;
877                         calleeNodeT->AddCaller(&node, stmt);
878                     }
879                     break;
880                 }
881                 case kCallTypeIcall: {
882                     if (mirModule->IsCModule()) {
883                         HandleICall(body, node, stmt, loopDepth);
884                     }
885                     break;
886                 }
887                 case kCallTypeCall: {
888                     HandleCall(node, stmt, loopDepth);
889                     break;
890                 }
891                 case kCallTypeSuperCall: {
892                     PUIdx calleePUIdx = (static_cast<CallNode *>(stmt))->GetPUIdx();
893                     MIRFunction *calleeFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(calleePUIdx);
894                     Klass *klass = klassh->GetKlassFromFunc(calleeFunc);
895                     if (klass == nullptr) {  // Fix CI
896                         continue;
897                     }
898                     MapleVector<MIRFunction *> *cands =
899                         klass->GetCandidates(calleeFunc->GetBaseFuncNameWithTypeStrIdx());
900                     // continue to search its implinterfaces
901                     if (cands == nullptr) {
902                         for (Klass *implInterface : klass->GetImplInterfaces()) {
903                             cands = implInterface->GetCandidates(calleeFunc->GetBaseFuncNameWithTypeStrIdx());
904                             if (cands != nullptr && !cands->empty()) {
905                                 break;
906                             }
907                         }
908                     }
909                     if (cands == nullptr || cands->empty()) {
910                         continue;  // Fix CI
911                     }
912                     MIRFunction *actualMirfunc = cands->at(0);
913                     CallInfo *callInfo = GenCallInfo(kCallTypeCall, actualMirfunc, stmt, loopDepth, stmt->GetStmtID());
914                     CGNode *calleeNode = GetOrGenCGNode(actualMirfunc->GetPuidx());
915                     DEBUG_ASSERT(calleeNode != nullptr, "calleenode is null");
916                     calleeNode->AddCaller(&node, stmt);
917                     (static_cast<CallNode *>(stmt))->SetPUIdx(actualMirfunc->GetPuidx());
918                     node.AddCallsite(*callInfo, calleeNode);
919                     break;
920                 }
921                 case kCallTypeIntrinsicCall:
922                 case kCallTypeIntrinsicCallWithType:
923                 case kCallTypeCustomCall:
924                 case kCallTypePolymorphicCall:
925                 case kCallTypeXinitrinsicCall:
926                 case kCallTypeInvalid: {
927                     break;
928                 }
929                 default: {
930                     CHECK_FATAL(false, "NYI::unsupport call type");
931                 }
932             }
933         }
934     }
935 }
936 
UpdateCallGraphNode(CGNode & node)937 void CallGraph::UpdateCallGraphNode(CGNode &node)
938 {
939     node.Reset();
940     MIRFunction *func = node.GetMIRFunction();
941     CHECK_NULL_FATAL(func);
942     BlockNode *body = func->GetBody();
943     HandleBody(*func, *body, node, 0);
944 }
945 
ResetInferredType(std::vector<MIRSymbol * > & inferredSymbols)946 static void ResetInferredType(std::vector<MIRSymbol *> &inferredSymbols)
947 {
948     for (size_t i = 0; i < inferredSymbols.size(); ++i) {
949         inferredSymbols[i]->SetInferredTyIdx(TyIdx());
950     }
951     inferredSymbols.clear();
952 }
953 
ResetInferredType(std::vector<MIRSymbol * > & inferredSymbols,MIRSymbol * symbol)954 static void ResetInferredType(std::vector<MIRSymbol *> &inferredSymbols, MIRSymbol *symbol)
955 {
956     if (symbol == nullptr) {
957         return;
958     }
959     if (symbol->GetInferredTyIdx() == kInitTyIdx || symbol->GetInferredTyIdx() == kNoneTyIdx) {
960         return;
961     }
962     size_t i = 0;
963     for (; i < inferredSymbols.size(); ++i) {
964         if (inferredSymbols[i] == symbol) {
965             symbol->SetInferredTyIdx(TyIdx());
966             inferredSymbols.erase(inferredSymbols.begin() + static_cast<ssize_t>(i));
967             break;
968         }
969     }
970 }
971 
SetInferredType(std::vector<MIRSymbol * > & inferredSymbols,MIRSymbol & symbol,const TyIdx & idx)972 static void SetInferredType(std::vector<MIRSymbol *> &inferredSymbols, MIRSymbol &symbol, const TyIdx &idx)
973 {
974     symbol.SetInferredTyIdx(idx);
975     size_t i = 0;
976     for (; i < inferredSymbols.size(); ++i) {
977         if (inferredSymbols[i] == &symbol) {
978             break;
979         }
980     }
981     if (i == inferredSymbols.size()) {
982         inferredSymbols.push_back(&symbol);
983     }
984 }
985 
SearchDefInClinit(const Klass & klass)986 void IPODevirtulize::SearchDefInClinit(const Klass &klass)
987 {
988     MIRClassType *classType = static_cast<MIRClassType *>(klass.GetMIRStructType());
989     std::vector<MIRSymbol *> staticFinalPrivateSymbols;
990     for (uint32 i = 0; i < classType->GetStaticFields().size(); ++i) {
991         FieldAttrs attribute = classType->GetStaticFieldsPair(i).second.second;
992         if (attribute.GetAttr(FLDATTR_final)) {
993             staticFinalPrivateSymbols.push_back(
994                 GlobalTables::GetGsymTable().GetSymbolFromStrIdx(classType->GetStaticFieldsGStrIdx(i)));
995         }
996     }
997     std::string typeName = klass.GetKlassName();
998     typeName.append(namemangler::kClinitSuffix);
999     GStrIdx clinitFuncGstrIdx =
1000         GlobalTables::GetStrTable().GetStrIdxFromName(namemangler::GetInternalNameLiteral(typeName));
1001     if (clinitFuncGstrIdx == 0u) {
1002         return;
1003     }
1004     MIRFunction *func = GlobalTables::GetGsymTable().GetSymbolFromStrIdx(clinitFuncGstrIdx)->GetFunction();
1005     if (func->GetBody() == nullptr) {
1006         return;
1007     }
1008     StmtNode *stmtNext = nullptr;
1009     std::vector<MIRSymbol *> gcmallocSymbols;
1010     for (StmtNode *stmt = func->GetBody()->GetFirst(); stmt != nullptr; stmt = stmtNext) {
1011         stmtNext = stmt->GetNext();
1012         Opcode op = stmt->GetOpCode();
1013         switch (op) {
1014             case OP_comment:
1015                 break;
1016             case OP_dassign: {
1017                 DassignNode *dassignNode = static_cast<DassignNode *>(stmt);
1018                 MIRSymbol *leftSymbol = func->GetLocalOrGlobalSymbol(dassignNode->GetStIdx());
1019                 size_t i = 0;
1020                 for (; i < staticFinalPrivateSymbols.size(); ++i) {
1021                     if (staticFinalPrivateSymbols[i] == leftSymbol) {
1022                         break;
1023                     }
1024                 }
1025                 if (i < staticFinalPrivateSymbols.size()) {
1026                     if (dassignNode->GetRHS()->GetOpCode() == OP_dread) {
1027                         DreadNode *dreadNode = static_cast<DreadNode *>(dassignNode->GetRHS());
1028                         MIRSymbol *rightSymbol = func->GetLocalOrGlobalSymbol(dreadNode->GetStIdx());
1029                         if (rightSymbol->GetInferredTyIdx() != kInitTyIdx &&
1030                             rightSymbol->GetInferredTyIdx() != kNoneTyIdx &&
1031                             (staticFinalPrivateSymbols[i]->GetInferredTyIdx() == kInitTyIdx ||
1032                              (staticFinalPrivateSymbols[i]->GetInferredTyIdx() == rightSymbol->GetInferredTyIdx()))) {
1033                             staticFinalPrivateSymbols[i]->SetInferredTyIdx(rightSymbol->GetInferredTyIdx());
1034                         } else {
1035                             staticFinalPrivateSymbols[i]->SetInferredTyIdx(kInitTyIdx);
1036                             staticFinalPrivateSymbols.erase(staticFinalPrivateSymbols.begin() +
1037                                                             static_cast<ssize_t>(i));
1038                         }
1039                     } else {
1040                         staticFinalPrivateSymbols[i]->SetInferredTyIdx(kInitTyIdx);
1041                         staticFinalPrivateSymbols.erase(staticFinalPrivateSymbols.begin() + static_cast<ssize_t>(i));
1042                     }
1043                 } else if (dassignNode->GetRHS()->GetOpCode() == OP_gcmalloc) {
1044                     GCMallocNode *gcmallocNode = static_cast<GCMallocNode *>(dassignNode->GetRHS());
1045                     TyIdx inferredTypeIdx = gcmallocNode->GetTyIdx();
1046                     SetInferredType(gcmallocSymbols, *leftSymbol, inferredTypeIdx);
1047                 } else if (dassignNode->GetRHS()->GetOpCode() == OP_retype) {
1048                     if (dassignNode->GetRHS()->Opnd(0)->GetOpCode() == OP_dread) {
1049                         DreadNode *dreadNode = static_cast<DreadNode *>(dassignNode->GetRHS()->Opnd(0));
1050                         MIRSymbol *rightSymbol = func->GetLocalOrGlobalSymbol(dreadNode->GetStIdx());
1051                         if (rightSymbol->GetInferredTyIdx() != kInitTyIdx &&
1052                             rightSymbol->GetInferredTyIdx() != kNoneTyIdx) {
1053                             SetInferredType(gcmallocSymbols, *leftSymbol, rightSymbol->GetInferredTyIdx());
1054                         }
1055                     }
1056                 } else {
1057                     ResetInferredType(gcmallocSymbols, leftSymbol);
1058                 }
1059                 break;
1060             }
1061             case OP_call:
1062             case OP_callassigned: {
1063                 CallNode *callNode = static_cast<CallNode *>(stmt);
1064                 MIRFunction *calleeFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(callNode->GetPUIdx());
1065                 if (calleeFunc->GetName().find(namemangler::kClinitSubStr, 0) != std::string::npos ||
1066                     calleeFunc->GetName().find("MCC_", 0) == 0) {
1067                     // ignore all side effect of initizlizor
1068                     continue;
1069                 }
1070                 for (size_t i = 0; i < callNode->GetReturnVec().size(); ++i) {
1071                     StIdx stIdx = callNode->GetReturnPair(i).first;
1072                     MIRSymbol *tmpSymbol = func->GetLocalOrGlobalSymbol(stIdx);
1073                     ResetInferredType(gcmallocSymbols, tmpSymbol);
1074                 }
1075                 for (size_t i = 0; i < callNode->GetNopndSize(); ++i) {
1076                     BaseNode *node = callNode->GetNopndAt(i);
1077                     CHECK_NULL_FATAL(node);
1078                     if (node->GetOpCode() != OP_dread) {
1079                         continue;
1080                     }
1081                     DreadNode *dreadNode = static_cast<DreadNode *>(node);
1082                     MIRSymbol *tmpSymbol = func->GetLocalOrGlobalSymbol(dreadNode->GetStIdx());
1083                     ResetInferredType(gcmallocSymbols, tmpSymbol);
1084                 }
1085                 break;
1086             }
1087             case OP_intrinsiccallwithtype: {
1088                 IntrinsiccallNode *callNode = static_cast<IntrinsiccallNode *>(stmt);
1089                 if (callNode->GetIntrinsic() != INTRN_JAVA_CLINIT_CHECK) {
1090                     ResetInferredType(gcmallocSymbols);
1091                 }
1092                 break;
1093             }
1094             default:
1095                 ResetInferredType(gcmallocSymbols);
1096                 break;
1097         }
1098     }
1099 }
1100 
SearchDefInMemberMethods(const Klass & klass)1101 void IPODevirtulize::SearchDefInMemberMethods(const Klass &klass)
1102 {
1103     SearchDefInClinit(klass);
1104     MIRClassType *classType = static_cast<MIRClassType *>(klass.GetMIRStructType());
1105     std::vector<FieldID> finalPrivateFieldID;
1106     for (size_t i = 0; i < classType->GetFieldsSize(); ++i) {
1107         FieldAttrs attribute = classType->GetFieldsElemt(i).second.second;
1108         if (attribute.GetAttr(FLDATTR_final)) {
1109             FieldID id = mirBuilder->GetStructFieldIDFromFieldNameParentFirst(
1110                 classType, GlobalTables::GetStrTable().GetStringFromStrIdx(classType->GetFieldsElemt(i).first));
1111             finalPrivateFieldID.push_back(id);
1112         }
1113     }
1114     std::vector<MIRFunction *> initMethods;
1115     std::string typeName = klass.GetKlassName();
1116     typeName.append(namemangler::kCinitStr);
1117     for (MIRFunction *const &method : klass.GetMethods()) {
1118         if (strncmp(method->GetName().c_str(), typeName.c_str(), typeName.length()) == 0) {
1119             initMethods.push_back(method);
1120         }
1121     }
1122     if (initMethods.empty()) {
1123         return;
1124     }
1125     DEBUG_ASSERT(!initMethods.empty(), "Must have initializor");
1126     StmtNode *stmtNext = nullptr;
1127     for (size_t i = 0; i < initMethods.size(); ++i) {
1128         MIRFunction *func = initMethods[i];
1129         if (func->GetBody() == nullptr) {
1130             continue;
1131         }
1132         std::vector<MIRSymbol *> gcmallocSymbols;
1133         for (StmtNode *stmt = func->GetBody()->GetFirst(); stmt != nullptr; stmt = stmtNext) {
1134             stmtNext = stmt->GetNext();
1135             Opcode op = stmt->GetOpCode();
1136             switch (op) {
1137                 case OP_comment:
1138                     break;
1139                 case OP_dassign: {
1140                     DassignNode *dassignNode = static_cast<DassignNode *>(stmt);
1141                     MIRSymbol *leftSymbol = func->GetLocalOrGlobalSymbol(dassignNode->GetStIdx());
1142                     if (dassignNode->GetRHS()->GetOpCode() == OP_gcmalloc) {
1143                         GCMallocNode *gcmallocNode = static_cast<GCMallocNode *>(dassignNode->GetRHS());
1144                         SetInferredType(gcmallocSymbols, *leftSymbol, gcmallocNode->GetTyIdx());
1145                     } else if (dassignNode->GetRHS()->GetOpCode() == OP_retype) {
1146                         RetypeNode *retyStmt = static_cast<RetypeNode *>(dassignNode->GetRHS());
1147                         BaseNode *fromNode = retyStmt->Opnd(0);
1148                         if (fromNode->GetOpCode() == OP_dread) {
1149                             DreadNode *dreadNode = static_cast<DreadNode *>(fromNode);
1150                             MIRSymbol *fromSymbol = func->GetLocalOrGlobalSymbol(dreadNode->GetStIdx());
1151                             SetInferredType(gcmallocSymbols, *leftSymbol, fromSymbol->GetInferredTyIdx());
1152                         } else {
1153                             ResetInferredType(gcmallocSymbols, leftSymbol);
1154                         }
1155                     } else {
1156                         ResetInferredType(gcmallocSymbols, leftSymbol);
1157                     }
1158                     break;
1159                 }
1160                 case OP_call:
1161                 case OP_callassigned: {
1162                     CallNode *callNode = static_cast<CallNode *>(stmt);
1163                     MIRFunction *calleeFunc =
1164                         GlobalTables::GetFunctionTable().GetFunctionFromPuidx(callNode->GetPUIdx());
1165                     if (calleeFunc->GetName().find(namemangler::kClinitSubStr, 0) != std::string::npos) {
1166                         // ignore all side effect of initizlizor
1167                         continue;
1168                     }
1169                     for (size_t j = 0; j < callNode->GetReturnVec().size(); ++j) {
1170                         StIdx stIdx = callNode->GetReturnPair(j).first;
1171                         MIRSymbol *tmpSymbol = func->GetLocalOrGlobalSymbol(stIdx);
1172                         ResetInferredType(gcmallocSymbols, tmpSymbol);
1173                     }
1174                     for (size_t j = 0; j < callNode->GetNopndSize(); ++j) {
1175                         BaseNode *node = callNode->GetNopndAt(j);
1176                         if (node->GetOpCode() != OP_dread) {
1177                             continue;
1178                         }
1179                         DreadNode *dreadNode = static_cast<DreadNode *>(node);
1180                         MIRSymbol *tmpSymbol = func->GetLocalOrGlobalSymbol(dreadNode->GetStIdx());
1181                         ResetInferredType(gcmallocSymbols, tmpSymbol);
1182                     }
1183                     break;
1184                 }
1185                 case OP_intrinsiccallwithtype: {
1186                     IntrinsiccallNode *callNode = static_cast<IntrinsiccallNode *>(stmt);
1187                     if (callNode->GetIntrinsic() != INTRN_JAVA_CLINIT_CHECK) {
1188                         ResetInferredType(gcmallocSymbols);
1189                     }
1190                     break;
1191                 }
1192                 case OP_iassign: {
1193                     IassignNode *iassignNode = static_cast<IassignNode *>(stmt);
1194                     MIRType *type = GlobalTables::GetTypeTable().GetTypeFromTyIdx(iassignNode->GetTyIdx());
1195                     DEBUG_ASSERT(type->GetKind() == kTypePointer, "Must be pointer type");
1196                     MIRPtrType *pointedType = static_cast<MIRPtrType *>(type);
1197                     if (pointedType->GetPointedTyIdx() == classType->GetTypeIndex()) {
1198                         // set field of current class
1199                         FieldID fieldID = iassignNode->GetFieldID();
1200                         size_t j = 0;
1201                         for (; j < finalPrivateFieldID.size(); ++j) {
1202                             if (finalPrivateFieldID[j] == fieldID) {
1203                                 break;
1204                             }
1205                         }
1206                         if (j < finalPrivateFieldID.size()) {
1207                             if (iassignNode->GetRHS()->GetOpCode() == OP_dread) {
1208                                 DreadNode *dreadNode = static_cast<DreadNode *>(iassignNode->GetRHS());
1209                                 CHECK_FATAL(dreadNode != nullptr, "Impossible");
1210                                 MIRSymbol *rightSymbol = func->GetLocalOrGlobalSymbol(dreadNode->GetStIdx());
1211                                 if (rightSymbol->GetInferredTyIdx() != kInitTyIdx &&
1212                                     rightSymbol->GetInferredTyIdx() != kNoneTyIdx &&
1213                                     (classType->GetElemInferredTyIdx(fieldID) == kInitTyIdx ||
1214                                      (classType->GetElemInferredTyIdx(fieldID) == rightSymbol->GetInferredTyIdx()))) {
1215                                     classType->SetElemInferredTyIdx(fieldID, rightSymbol->GetInferredTyIdx());
1216                                 } else {
1217                                     classType->SetElemInferredTyIdx(fieldID, kInitTyIdx);
1218                                     finalPrivateFieldID.erase(finalPrivateFieldID.begin() + static_cast<ssize_t>(j));
1219                                 }
1220                             } else {
1221                                 classType->SetElemInferredTyIdx(fieldID, kInitTyIdx);
1222                                 finalPrivateFieldID.erase(finalPrivateFieldID.begin() + static_cast<ssize_t>(j));
1223                             }
1224                         }
1225                     }
1226                     break;
1227                 }
1228                 default:
1229                     ResetInferredType(gcmallocSymbols);
1230                     break;
1231             }
1232         }
1233     }
1234 }
1235 
DoDevirtual(const Klass & klass,const KlassHierarchy & klassh)1236 void DoDevirtual(const Klass &klass, const KlassHierarchy &klassh)
1237 {
1238     MIRClassType *classType = static_cast<MIRClassType *>(klass.GetMIRStructType());
1239     for (auto &func : klass.GetMethods()) {
1240         if (func->GetBody() == nullptr) {
1241             continue;
1242         }
1243         StmtNode *stmtNext = nullptr;
1244         std::vector<MIRSymbol *> inferredSymbols;
1245         for (StmtNode *stmt = func->GetBody()->GetFirst(); stmt != nullptr; stmt = stmtNext) {
1246             stmtNext = stmt->GetNext();
1247             Opcode op = stmt->GetOpCode();
1248             switch (op) {
1249                 case OP_comment:
1250                     CASE_OP_ASSERT_NONNULL
1251                 case OP_brtrue:
1252                 case OP_brfalse:
1253                 case OP_try:
1254                 case OP_endtry:
1255                     break;
1256                 case OP_dassign: {
1257                     DassignNode *dassignNode = static_cast<DassignNode *>(stmt);
1258                     MIRSymbol *leftSymbol = func->GetLocalOrGlobalSymbol(dassignNode->GetStIdx());
1259                     if (dassignNode->GetRHS()->GetOpCode() == OP_dread) {
1260                         DreadNode *dreadNode = static_cast<DreadNode *>(dassignNode->GetRHS());
1261                         if (func->GetLocalOrGlobalSymbol(dreadNode->GetStIdx())->GetInferredTyIdx() != kInitTyIdx) {
1262                             SetInferredType(inferredSymbols, *leftSymbol,
1263                                             func->GetLocalOrGlobalSymbol(dreadNode->GetStIdx())->GetInferredTyIdx());
1264                         }
1265                     } else if (dassignNode->GetRHS()->GetOpCode() == OP_iread) {
1266                         IreadNode *ireadNode = static_cast<IreadNode *>(dassignNode->GetRHS());
1267                         MIRType *type = GlobalTables::GetTypeTable().GetTypeFromTyIdx(ireadNode->GetTyIdx());
1268                         DEBUG_ASSERT(type->GetKind() == kTypePointer, "Must be pointer type");
1269                         MIRPtrType *pointedType = static_cast<MIRPtrType *>(type);
1270                         if (pointedType->GetPointedTyIdx() == classType->GetTypeIndex()) {
1271                             FieldID fieldID = ireadNode->GetFieldID();
1272                             FieldID tmpID = fieldID;
1273                             TyIdx tmpTyIdx = classType->GetElemInferredTyIdx(static_cast<size_t>(tmpID));
1274                             if (tmpTyIdx != kInitTyIdx && tmpTyIdx != kNoneTyIdx) {
1275                                 SetInferredType(inferredSymbols, *leftSymbol,
1276                                                 classType->GetElemInferredTyIdx(static_cast<size_t>(fieldID)));
1277                             }
1278                         }
1279                     } else {
1280                         ResetInferredType(inferredSymbols, leftSymbol);
1281                     }
1282                     break;
1283                 }
1284                 case OP_interfacecall:
1285                 case OP_interfacecallassigned:
1286                 case OP_virtualcall:
1287                 case OP_virtualcallassigned: {
1288                     CallNode *calleeNode = static_cast<CallNode *>(stmt);
1289                     MIRFunction *calleeFunc =
1290                         GlobalTables::GetFunctionTable().GetFunctionFromPuidx(calleeNode->GetPUIdx());
1291                     if (calleeNode->GetNopndAt(0)->GetOpCode() == OP_dread) {
1292                         DreadNode *dreadNode = static_cast<DreadNode *>(calleeNode->GetNopndAt(0));
1293                         MIRSymbol *rightSymbol = func->GetLocalOrGlobalSymbol(dreadNode->GetStIdx());
1294                         if (rightSymbol->GetInferredTyIdx() != kInitTyIdx &&
1295                             rightSymbol->GetInferredTyIdx() != kNoneTyIdx) {
1296                             // Devirtual
1297                             Klass *currKlass = klassh.GetKlassFromTyIdx(rightSymbol->GetInferredTyIdx());
1298                             if (op == OP_interfacecall || op == OP_interfacecallassigned || op == OP_virtualcall ||
1299                                 op == OP_virtualcallassigned) {
1300                                 std::vector<Klass *> klassVector;
1301                                 klassVector.push_back(currKlass);
1302                                 bool hasDevirtualed = false;
1303                                 for (size_t index = 0; index < klassVector.size(); ++index) {
1304                                     Klass *tmpKlass = klassVector[index];
1305                                     for (MIRFunction *const &method : tmpKlass->GetMethods()) {
1306                                         if (calleeFunc->GetBaseFuncNameWithTypeStrIdx() ==
1307                                             method->GetBaseFuncNameWithTypeStrIdx()) {
1308                                             calleeNode->SetPUIdx(method->GetPuidx());
1309                                             if (op == OP_virtualcall || op == OP_interfacecall) {
1310                                                 calleeNode->SetOpCode(OP_call);
1311                                             }
1312                                             if (op == OP_virtualcallassigned || op == OP_interfacecallassigned) {
1313                                                 calleeNode->SetOpCode(OP_callassigned);
1314                                             }
1315                                             hasDevirtualed = true;
1316                                             if (false) {
1317                                                 LogInfo::MapleLogger()
1318                                                     << ("Devirtualize In function:" + func->GetName()) << '\n';
1319                                                 LogInfo::MapleLogger() << calleeNode->GetOpCode() << '\n';
1320                                                 LogInfo::MapleLogger() << "    From:" << calleeFunc->GetName() << '\n';
1321                                                 LogInfo::MapleLogger() << "    To  :"
1322                                                     << GlobalTables::GetFunctionTable().GetFunctionFromPuidx(
1323                                                         calleeNode->GetPUIdx())->GetName()
1324                                                     << '\n';
1325                                             }
1326                                             break;
1327                                         }
1328                                     }
1329                                     if (hasDevirtualed) {
1330                                         break;
1331                                     }
1332                                     // add subclass of currKlass into vecotr
1333                                     for (Klass *superKlass : tmpKlass->GetSuperKlasses()) {
1334                                         klassVector.push_back(superKlass);
1335                                     }
1336                                 }
1337                                 if (hasDevirtualed) {
1338                                     for (size_t i = 0; i < calleeNode->GetNopndSize(); ++i) {
1339                                         BaseNode *node = calleeNode->GetNopndAt(i);
1340                                         CHECK_NULL_FATAL(node);
1341                                         if (node->GetOpCode() != OP_dread) {
1342                                             continue;
1343                                         }
1344                                         dreadNode = static_cast<DreadNode *>(node);
1345                                         MIRSymbol *tmpSymbol = func->GetLocalOrGlobalSymbol(dreadNode->GetStIdx());
1346                                         ResetInferredType(inferredSymbols, tmpSymbol);
1347                                     }
1348                                     if (op == OP_interfacecallassigned || op == OP_virtualcallassigned) {
1349                                         CallNode *callNode = static_cast<CallNode *>(stmt);
1350                                         for (size_t i = 0; i < callNode->GetReturnVec().size(); ++i) {
1351                                             StIdx stIdx = callNode->GetReturnPair(i).first;
1352                                             MIRSymbol *tmpSymbol = func->GetLocalOrGlobalSymbol(stIdx);
1353                                             ResetInferredType(inferredSymbols, tmpSymbol);
1354                                         }
1355                                     }
1356                                     break;
1357                                 }
1358                                 // Search default function in interfaces
1359                                 Klass *tmpInterface = nullptr;
1360                                 MIRFunction *tmpMethod = nullptr;
1361                                 for (Klass *iKlass : currKlass->GetImplInterfaces()) {
1362                                     for (MIRFunction *const &method : iKlass->GetMethods()) {
1363                                         if (calleeFunc->GetBaseFuncNameWithTypeStrIdx() ==
1364                                                 method->GetBaseFuncNameWithTypeStrIdx() &&
1365                                             !method->GetFuncAttrs().GetAttr(FUNCATTR_abstract)) {
1366                                             if (tmpInterface == nullptr ||
1367                                                 klassh.IsSuperKlassForInterface(tmpInterface, iKlass)) {
1368                                                 tmpInterface = iKlass;
1369                                                 tmpMethod = method;
1370                                             }
1371                                             break;
1372                                         }
1373                                     }
1374                                 }
1375                                 // Add this check for the thirdparty APP compile
1376                                 if (tmpMethod == nullptr) {
1377                                     if (Options::deferredVisit) {
1378                                         return;
1379                                     }
1380                                     Klass *parentKlass = klassh.GetKlassFromName(calleeFunc->GetBaseClassName());
1381                                     CHECK_FATAL(parentKlass != nullptr, "null ptr check");
1382                                     bool flag = false;
1383                                     if (parentKlass->GetKlassName() == currKlass->GetKlassName()) {
1384                                         flag = true;
1385                                     } else {
1386                                         for (Klass *const &superclass : currKlass->GetSuperKlasses()) {
1387                                             if (parentKlass->GetKlassName() == superclass->GetKlassName()) {
1388                                                 flag = true;
1389                                                 break;
1390                                             }
1391                                         }
1392                                         if (!flag && parentKlass->IsInterface()) {
1393                                             for (Klass *const &implClass : currKlass->GetImplKlasses()) {
1394                                                 if (parentKlass->GetKlassName() == implClass->GetKlassName()) {
1395                                                     flag = true;
1396                                                     break;
1397                                                 }
1398                                             }
1399                                         }
1400                                     }
1401                                     if (!flag) {
1402                                         LogInfo::MapleLogger() << "warning: func " << calleeFunc->GetName()
1403                                                                << " is not found in DeVirtual!" << std::endl;
1404                                         LogInfo::MapleLogger()
1405                                             << "warning: " << calleeFunc->GetBaseClassName() << " is not the parent of "
1406                                             << currKlass->GetKlassName() << std::endl;
1407                                     }
1408                                 }
1409                                 if (tmpMethod == nullptr) {  // SearchWithoutRettype, search only in current class now.
1410                                     MIRType *retType =
1411                                         GlobalTables::GetTypeTable().GetTypeFromTyIdx(calleeFunc->GetReturnTyIdx());
1412                                     Klass *targetKlass = nullptr;
1413                                     bool isCalleeScalar = false;
1414                                     if (retType->GetKind() == kTypePointer && retType->GetPrimType() == PTY_ref) {
1415                                         MIRType *ptrType = (static_cast<MIRPtrType *>(retType))->GetPointedType();
1416                                         targetKlass = klassh.GetKlassFromTyIdx(ptrType->GetTypeIndex());
1417                                     } else if (retType->GetKind() == kTypeScalar) {
1418                                         isCalleeScalar = true;
1419                                     } else {
1420                                         targetKlass = klassh.GetKlassFromTyIdx(retType->GetTypeIndex());
1421                                     }
1422                                     if (targetKlass == nullptr && !isCalleeScalar) {
1423                                         CHECK_FATAL(false, "null ptr check");
1424                                     }
1425                                     Klass *curRetKlass = nullptr;
1426                                     bool isCurrVtabScalar = false;
1427                                     bool isFindMethod = false;
1428                                     for (MIRFunction *const &method : currKlass->GetMethods()) {
1429                                         if (calleeFunc->GetBaseFuncSigStrIdx() == method->GetBaseFuncSigStrIdx()) {
1430                                             Klass *tmpKlass = nullptr;
1431                                             MIRType *tmpType =
1432                                                 GlobalTables::GetTypeTable().GetTypeFromTyIdx(method->GetReturnTyIdx());
1433                                             if (tmpType->GetKind() == kTypePointer &&
1434                                                 tmpType->GetPrimType() == PTY_ref) {
1435                                                 MIRType *ptrType =
1436                                                     (static_cast<MIRPtrType *>(tmpType))->GetPointedType();
1437                                                 tmpKlass = klassh.GetKlassFromTyIdx(ptrType->GetTypeIndex());
1438                                             } else if (tmpType->GetKind() == kTypeScalar) {
1439                                                 isCurrVtabScalar = true;
1440                                             } else {
1441                                                 tmpKlass = klassh.GetKlassFromTyIdx(tmpType->GetTypeIndex());
1442                                             }
1443                                             if (tmpKlass == nullptr && !isCurrVtabScalar) {
1444                                                 CHECK_FATAL(false, "null ptr check");
1445                                             }
1446                                             if (isCalleeScalar || isCurrVtabScalar) {
1447                                                 if (isFindMethod) {
1448                                                     LogInfo::MapleLogger()
1449                                                         << "warning: this " << currKlass->GetKlassName()
1450                                                         << " has mult methods with the same function name but with "
1451                                                            "different return type!"
1452                                                         << std::endl;
1453                                                     break;
1454                                                 }
1455                                                 tmpMethod = method;
1456                                                 isFindMethod = true;
1457                                                 continue;
1458                                             }
1459                                             if (targetKlass->IsClass() && klassh.IsSuperKlass(tmpKlass, targetKlass) &&
1460                                                 (curRetKlass == nullptr ||
1461                                                  klassh.IsSuperKlass(curRetKlass, tmpKlass))) {
1462                                                 curRetKlass = tmpKlass;
1463                                                 tmpMethod = method;
1464                                             }
1465                                             if (targetKlass->IsClass() &&
1466                                                 klassh.IsInterfaceImplemented(tmpKlass, targetKlass)) {
1467                                                 tmpMethod = method;
1468                                                 break;
1469                                             }
1470                                             if (!targetKlass->IsClass()) {
1471                                                 CHECK_FATAL(tmpKlass != nullptr, "Klass null ptr check");
1472                                                 if (tmpKlass->IsClass() &&
1473                                                     klassh.IsInterfaceImplemented(targetKlass, tmpKlass) &&
1474                                                     (curRetKlass == nullptr ||
1475                                                      klassh.IsSuperKlass(curRetKlass, tmpKlass))) {
1476                                                     curRetKlass = tmpKlass;
1477                                                     tmpMethod = method;
1478                                                 }
1479                                                 if (!tmpKlass->IsClass() &&
1480                                                     klassh.IsSuperKlassForInterface(tmpKlass, targetKlass) &&
1481                                                     (curRetKlass == nullptr ||
1482                                                      klassh.IsSuperKlass(curRetKlass, tmpKlass))) {
1483                                                     curRetKlass = tmpKlass;
1484                                                     tmpMethod = method;
1485                                                 }
1486                                             }
1487                                         }
1488                                     }
1489                                 }
1490                                 if (tmpMethod == nullptr && (currKlass->IsClass() || currKlass->IsInterface())) {
1491                                     LogInfo::MapleLogger() << "warning: func " << calleeFunc->GetName()
1492                                                            << " is not found in DeVirtual!" << std::endl;
1493                                     stmt->SetOpCode(OP_callassigned);
1494                                     break;
1495                                 } else if (tmpMethod == nullptr) {
1496                                     LogInfo::MapleLogger()
1497                                         << "Error: func " << calleeFunc->GetName() << " is not found!" << std::endl;
1498                                     DEBUG_ASSERT(tmpMethod != nullptr, "Must not be null");
1499                                 }
1500                                 calleeNode->SetPUIdx(tmpMethod->GetPuidx());
1501                                 if (op == OP_virtualcall || op == OP_interfacecall) {
1502                                     calleeNode->SetOpCode(OP_call);
1503                                 }
1504                                 if (op == OP_virtualcallassigned || op == OP_interfacecallassigned) {
1505                                     calleeNode->SetOpCode(OP_callassigned);
1506                                 }
1507                                 if (false) {
1508                                     LogInfo::MapleLogger() << ("Devirtualize In function:" + func->GetName()) << '\n';
1509                                     LogInfo::MapleLogger() << calleeNode->GetOpCode() << '\n';
1510                                     LogInfo::MapleLogger() << "    From:" << calleeFunc->GetName() << '\n';
1511                                     LogInfo::MapleLogger() << "    To  :" << GlobalTables::GetFunctionTable()
1512                                         .GetFunctionFromPuidx(calleeNode->GetPUIdx())->GetName() << '\n';
1513                                 }
1514                                 for (size_t i = 0; i < calleeNode->GetNopndSize(); ++i) {
1515                                     BaseNode *node = calleeNode->GetNopndAt(i);
1516                                     if (node->GetOpCode() != OP_dread) {
1517                                         continue;
1518                                     }
1519                                     dreadNode = static_cast<DreadNode *>(node);
1520                                     MIRSymbol *tmpSymbol = func->GetLocalOrGlobalSymbol(dreadNode->GetStIdx());
1521                                     ResetInferredType(inferredSymbols, tmpSymbol);
1522                                 }
1523                                 if (op == OP_interfacecallassigned || op == OP_virtualcallassigned) {
1524                                     CallNode *callNode = static_cast<CallNode *>(stmt);
1525                                     for (size_t i = 0; i < callNode->GetReturnVec().size(); ++i) {
1526                                         StIdx stIdx = callNode->GetReturnPair(i).first;
1527                                         MIRSymbol *tmpSymbol = func->GetLocalOrGlobalSymbol(stIdx);
1528                                         ResetInferredType(inferredSymbols, tmpSymbol);
1529                                     }
1530                                 }
1531                                 break;
1532                             }
1533                         }
1534                     }
1535                 }
1536                     [[clang::fallthrough]];
1537                 case OP_call:
1538                 case OP_callassigned: {
1539                     CallNode *callNode = static_cast<CallNode *>(stmt);
1540                     for (size_t i = 0; i < callNode->GetReturnVec().size(); ++i) {
1541                         StIdx stIdx = callNode->GetReturnPair(i).first;
1542                         MIRSymbol *tmpSymbol = func->GetLocalOrGlobalSymbol(stIdx);
1543                         ResetInferredType(inferredSymbols, tmpSymbol);
1544                     }
1545                     for (size_t i = 0; i < callNode->GetNopndSize(); ++i) {
1546                         BaseNode *node = callNode->GetNopndAt(i);
1547                         if (node->GetOpCode() != OP_dread) {
1548                             continue;
1549                         }
1550                         DreadNode *dreadNode = static_cast<DreadNode *>(node);
1551                         MIRSymbol *tmpSymbol = func->GetLocalOrGlobalSymbol(dreadNode->GetStIdx());
1552                         ResetInferredType(inferredSymbols, tmpSymbol);
1553                     }
1554                     break;
1555                 }
1556                 default:
1557                     ResetInferredType(inferredSymbols);
1558                     break;
1559             }
1560         }
1561     }
1562 }
1563 
DevirtualFinal()1564 void IPODevirtulize::DevirtualFinal()
1565 {
1566     // Search all klass in order to find final variables
1567     MapleMap<GStrIdx, Klass *>::const_iterator it = klassh->GetKlasses().begin();
1568     for (; it != klassh->GetKlasses().end(); ++it) {
1569         Klass *klass = it->second;
1570         if (klass->IsClass()) {
1571             MIRClassType *classType = static_cast<MIRClassType *>(klass->GetMIRStructType());
1572             // Initialize inferred type of member fileds as kInitTyidx
1573             for (size_t i = 0; i < classType->GetFieldsSize(); ++i) {  // Don't include parent's field
1574                 classType->SetElemInferredTyIdx(i, kInitTyIdx);
1575             }
1576             SearchDefInMemberMethods(*klass);
1577             for (size_t i = 0; i < classType->GetFieldInferredTyIdx().size(); ++i) {
1578                 if (classType->GetElemInferredTyIdx(i) != kInitTyIdx &&
1579                     classType->GetElemInferredTyIdx(i) != kNoneTyIdx && debugFlag) {
1580                     FieldID tmpID = static_cast<FieldID>(i);
1581                     FieldPair pair = classType->TraverseToFieldRef(tmpID);
1582                     LogInfo::MapleLogger() << ("Inferred Final Private None-Static Variable:" + klass->GetKlassName() +
1583                                                ":" + GlobalTables::GetStrTable().GetStringFromStrIdx(pair.first))
1584                                            << '\n';
1585                 }
1586             }
1587             for (size_t i = 0; i < classType->GetStaticFields().size(); ++i) {
1588                 FieldAttrs attribute = classType->GetStaticFieldsPair(i).second.second;
1589                 if (GlobalTables::GetGsymTable().GetSymbolFromStrIdx(classType->GetStaticFieldsGStrIdx(i)) == nullptr) {
1590                     continue;
1591                 }
1592                 TyIdx tyIdx = GlobalTables::GetGsymTable().GetSymbolFromStrIdx(
1593                     classType->GetStaticFieldsPair(i).first)->GetInferredTyIdx();
1594                 if (tyIdx != kInitTyIdx && tyIdx != kNoneTyIdx) {
1595                     CHECK_FATAL(attribute.GetAttr(FLDATTR_final), "Must be final private");
1596                     if (debugFlag) {
1597                         LogInfo::MapleLogger() << ("Final Private Static Variable:" +
1598                             GlobalTables::GetStrTable().GetStringFromStrIdx(classType->GetStaticFieldsPair(i).first))
1599                             << '\n';
1600                     }
1601                 }
1602             }
1603             DoDevirtual(*klass, *GetKlassh());
1604         }
1605     }
1606 }
1607 
ReadCallGraphFromMplt()1608 void CallGraph::ReadCallGraphFromMplt()
1609 {
1610     // Read existing call graph from mplt, std::map<PUIdx, std::vector<CallInfo*> >
1611     // caller_PUIdx and all call site info are needed. Rebuild all other info of CGNode using CHA
1612     for (auto const &it : mirModule->GetMethod2TargetMap()) {
1613         CGNode *node = GetOrGenCGNode(it.first);
1614         CHECK_FATAL(node != nullptr, "node is null");
1615         std::vector<CallInfo *> callees = it.second;
1616         for (auto itInner = callees.begin(); itInner != callees.end(); ++itInner) {
1617             CallInfo *info = *itInner;
1618             CGNode *calleeNode =
1619                 GetOrGenCGNode(info->GetFunc()->GetPuidx(), info->GetCallType() == kCallTypeVirtualCall,
1620                                info->GetCallType() == kCallTypeInterfaceCall);
1621             CHECK_FATAL(calleeNode != nullptr, "calleeNode is null");
1622             if (info->GetCallType() == kCallTypeVirtualCall) {
1623                 node->AddCallsite(*itInner, &calleeNode->GetVcallCandidates());
1624             } else if (info->GetCallType() == kCallTypeInterfaceCall) {
1625                 node->AddCallsite(*itInner, &calleeNode->GetIcallCandidates());
1626             } else if (info->GetCallType() == kCallTypeCall) {
1627                 node->AddCallsite(**itInner, calleeNode);
1628             } else if (info->GetCallType() == kCallTypeSuperCall) {
1629                 const MIRFunction *calleeFunc = info->GetFunc();
1630                 Klass *klass = klassh->GetKlassFromFunc(calleeFunc);
1631                 if (klass == nullptr) {  // Fix CI
1632                     continue;
1633                 }
1634                 MapleVector<MIRFunction *> *cands = klass->GetCandidates(calleeFunc->GetBaseFuncNameWithTypeStrIdx());
1635                 // continue to search its implinterfaces
1636                 if (cands == nullptr) {
1637                     for (Klass *implInterface : klass->GetImplInterfaces()) {
1638                         cands = implInterface->GetCandidates(calleeFunc->GetBaseFuncNameWithTypeStrIdx());
1639                         if (cands != nullptr && !cands->empty()) {
1640                             break;
1641                         }
1642                     }
1643                 }
1644                 if (cands == nullptr || cands->empty()) {
1645                     continue;  // Fix CI
1646                 }
1647                 MIRFunction *actualMirFunc = cands->at(0);
1648                 CGNode *tempNode = GetOrGenCGNode(actualMirFunc->GetPuidx());
1649                 DEBUG_ASSERT(tempNode != nullptr, "calleenode is null in CallGraph::HandleBody");
1650                 node->AddCallsite(*info, tempNode);
1651             }
1652             for (auto &callSite : node->GetCallee()) {
1653                 if (callSite.first == info) {
1654                     for (auto &cgIt : *callSite.second) {
1655                         CGNode *tempNode = cgIt;
1656                         tempNode->AddCaller(node, info->GetCallStmt());
1657                     }
1658                     break;
1659                 }
1660             }
1661         }
1662     }
1663 }
1664 
GetMatchedCGNode(TyIdx idx,std::vector<CGNode * > & result)1665 void CallGraph::GetMatchedCGNode(TyIdx idx, std::vector<CGNode *> &result)
1666 {
1667     auto *funcType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(idx);
1668     for (auto puidx : addressTakenPuidxs) {
1669         auto *func = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(puidx);
1670         if (func == nullptr) {
1671             continue;
1672         }
1673         if (func->GetMIRFuncType()->CompatibleWith(*funcType)) {
1674             result.push_back(GetOrGenCGNode(puidx));
1675         }
1676     }
1677 }
1678 
FixIcallCallee()1679 void CallGraph::FixIcallCallee()
1680 {
1681     for (auto &pair : icallToFix) {
1682         std::vector<CGNode *> funcs;
1683         GetMatchedCGNode(pair.first, funcs);
1684         for (auto &callerCallee : *pair.second) {
1685             auto puidx = callerCallee.first;
1686             auto candidate = callerCallee.second;
1687             CHECK_FATAL(candidate.second->empty(), "Error");
1688             candidate.second->insert(funcs.begin(), funcs.end());
1689             StmtNode *stmt = candidate.first->GetCallStmt();
1690             std::for_each(funcs.begin(), funcs.end(), [puidx, stmt, this](CGNode *elem) {
1691                 CGNode *callerNode = GetOrGenCGNode(puidx);
1692                 elem->AddCaller(callerNode, stmt);
1693             });
1694         }
1695     }
1696 }
1697 
FindRootNodes()1698 void CallGraph::FindRootNodes()
1699 {
1700     if (!rootNodes.empty()) {
1701         CHECK_FATAL(false, "rootNodes has already been set");
1702     }
1703     for (auto const &it : nodesMap) {
1704         CGNode *node = it.second;
1705         if (!node->HasCaller()) {
1706             rootNodes.push_back(node);
1707         }
1708     }
1709 }
1710 
RemoveFileStaticRootNodes()1711 void CallGraph::RemoveFileStaticRootNodes()
1712 {
1713     std::vector<CGNode *> staticRoots;
1714     std::copy_if(
1715         rootNodes.begin(), rootNodes.end(), std::inserter(staticRoots, staticRoots.begin()), [](const CGNode *root) {
1716             // root means no caller, we should also make sure that root is not be used in addroffunc
1717             auto mirFunc = root->GetMIRFunction();
1718             return root != nullptr &&
1719                    mirFunc != nullptr &&  // remove before
1720                                           // if static functions or inline but not extern modified functions are not
1721                                           // used anymore, they can be removed safely.
1722                    !root->IsAddrTaken() && (mirFunc->IsStatic() || (mirFunc->IsInline() && !mirFunc->IsExtern()));
1723         });
1724     for (auto *root : staticRoots) {
1725         // DFS delete root and its callee that is static and have no caller after root is deleted
1726         DelNode(*root);
1727     }
1728     ClearFunctionList();
1729     // rebuild rootNodes
1730     rootNodes.clear();
1731     FindRootNodes();
1732 }
1733 
RemoveFileStaticSCC()1734 void CallGraph::RemoveFileStaticSCC()
1735 {
1736     for (size_t idx = 0; idx < sccTopologicalVec.size();) {
1737         SCCNode<CGNode> *sccNode = sccTopologicalVec[idx];
1738         if (sccNode->HasInScc() || sccNode == nullptr) {
1739             ++idx;
1740             continue;
1741         }
1742         bool canBeDel = true;
1743         for (auto *node : sccNode->GetNodes()) {
1744             // If the function is not static, it may be referred in other module;
1745             // If the function is taken address, we should deal with this situation conservatively,
1746             // because we are not sure whether the func pointer may escape from this SCC
1747             if (!node->GetMIRFunction()->IsStatic() || node->IsAddrTaken()) {
1748                 canBeDel = false;
1749                 break;
1750             }
1751         }
1752         if (canBeDel) {
1753             sccTopologicalVec.erase(sccTopologicalVec.begin() + static_cast<ssize_t>(idx));
1754             for (auto *calleeSCC : sccNode->GetOutScc()) {
1755                 calleeSCC->RemoveInScc(sccNode);
1756             }
1757             for (auto *cgnode : sccNode->GetNodes()) {
1758                 DelNode(*cgnode);
1759             }
1760             // this sccnode is deleted from sccTopologicalVec, so we don't inc idx here
1761             continue;
1762         }
1763         ++idx;
1764     }
1765     ClearFunctionList();
1766 }
1767 
Dump() const1768 void CallGraph::Dump() const
1769 {
1770     for (auto const &it : nodesMap) {
1771         CGNode *node = it.second;
1772         node->DumpDetail();
1773     }
1774 }
1775 
1776 // Sort CGNode within an SCC. Try best to arrange callee appears before
1777 // its (direct) caller, so that caller can benefit from summary info.
1778 // If we have iterative inter-procedure analysis, then would not bother
1779 // do this.
CGNodeCompare(CGNode * left,CGNode * right)1780 static bool CGNodeCompare(CGNode *left, CGNode *right)
1781 {
1782     // special case: left calls right and right calls left, then compare by id
1783     if (left->IsCalleeOf(right) && right->IsCalleeOf(left)) {
1784         return left->GetID() < right->GetID();
1785     }
1786     // left is right's direct callee, then make left appears first
1787     if (left->IsCalleeOf(right)) {
1788         return true;
1789     } else if (right->IsCalleeOf(left)) {
1790         return false;
1791     }
1792     return left->GetID() < right->GetID();
1793 }
1794 
1795 // Set compilation order as the bottom-up order of callgraph. So callee
1796 // is always compiled before caller. This benifits thoses optimizations
1797 // need interprocedure information like escape analysis.
SetCompilationFunclist() const1798 void CallGraph::SetCompilationFunclist() const
1799 {
1800     mirModule->GetFunctionList().clear();
1801     const MapleVector<SCCNode<CGNode> *> &sccTopVec = GetSCCTopVec();
1802     for (MapleVector<SCCNode<CGNode> *>::const_reverse_iterator it = sccTopVec.rbegin(); it != sccTopVec.rend(); ++it) {
1803         SCCNode<CGNode> *sccNode = *it;
1804         std::sort(sccNode->GetNodes().begin(), sccNode->GetNodes().end(), CGNodeCompare);
1805         for (auto const kIt : sccNode->GetNodes()) {
1806             CGNode *node = kIt;
1807             MIRFunction *func = node->GetMIRFunction();
1808             if ((func != nullptr && func->GetBody() != nullptr && !IsInIPA()) ||
1809                 (func != nullptr && !func->IsNative())) {
1810                 mirModule->GetFunctionList().push_back(func);
1811             }
1812         }
1813     }
1814 }
1815 
AddCandsForCallNode(const KlassHierarchy & kh)1816 void CGNode::AddCandsForCallNode(const KlassHierarchy &kh)
1817 {
1818     // already set vcall candidates information
1819     if (HasSetVCallCandidates()) {
1820         return;
1821     }
1822     CHECK_NULL_FATAL(mirFunc);
1823     Klass *klass = kh.GetKlassFromFunc(mirFunc);
1824     if (klass != nullptr) {
1825         MapleVector<MIRFunction *> *vec = klass->GetCandidates(mirFunc->GetBaseFuncNameWithTypeStrIdx());
1826         if (vec != nullptr) {
1827             vcallCands = *vec;  // Vector copy
1828         }
1829     }
1830 }
1831 
HasOneCandidate() const1832 MIRFunction *CGNode::HasOneCandidate() const
1833 {
1834     int count = 0;
1835     MIRFunction *cand = nullptr;
1836     if (!mirFunc->IsEmpty()) {
1837         ++count;
1838         cand = mirFunc;
1839     }
1840     // scan candidates
1841     for (size_t i = 0; i < vcallCands.size(); ++i) {
1842         if (vcallCands[i] == nullptr) {
1843             CHECK_FATAL(false, "must not be nullptr");
1844         }
1845         if (!vcallCands[i]->IsEmpty()) {
1846             ++count;
1847             if (cand == nullptr) {
1848                 cand = vcallCands[i];
1849             }
1850         }
1851     }
1852     return (count == 1) ? cand : nullptr;
1853 }
1854 
1855 }  // namespace maple
1856