• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "class_hierarchy.h"
17 #include <iostream>
18 #include <fstream>
19 #include "option.h"
20 
// Class Hierarchy Analysis
// This is a foundational phase of compilation. It builds the class
// hierarchy for this module and for all modules it depends on. Many
// later phases rely on its results, such as call-graph construction,
// SSA, and so on.
// The main procedure is as follows.
// A. Based on the information read from mplts, collect all classes
//    declared in the modules and create a Klass for each class.
// B. Fill in class method info. Connect superclass<->subclass and
//    interface->implementation edges.
// C. Tag all Throwable classes and their subclasses.
// D. In the case of "class C implements B; interface B extends A;",
//    we need to add a link between C and A. So we recursively traverse
//    each Klass and collect all the interfaces it implements.
// E. Topological sort.
// F. Following the topological order, for each virtual method in a class,
//    collect all of its potential implementations. If there is exactly
//    one potential implementation, all virtual calls to this method can
//    easily be devirtualized.
40 namespace maple {
41 bool KlassHierarchy::traceFlag = false;
42 
// Reports whether className belongs to a system-preloaded class set.
// This implementation has no such set, so every name yields false.
bool IsSystemPreloadedClass(const std::string &className)
{
    static_cast<void>(className);  // deliberately unused here
    constexpr bool kPreloaded = false;
    return kPreloaded;
}
47 
Klass(MIRStructType * type,MapleAllocator * alc)48 Klass::Klass(MIRStructType *type, MapleAllocator *alc)
49     : structType(type),
50       alloc(alc),
51       superKlasses(alloc->Adapter()),
52       subKlasses(alloc->Adapter()),
53       implKlasses(alloc->Adapter()),
54       implInterfaces(alloc->Adapter()),
55       methods(alloc->Adapter()),
56       strIdx2Method(alloc->Adapter()),
57       strIdx2CandidateMap(alloc->Adapter())
58 {
59     DEBUG_ASSERT(type != nullptr, "type is nullptr in Klass::Klass!");
60     DEBUG_ASSERT(type->GetKind() == kTypeClass || type->GetKind() == kTypeInterface || type->IsIncomplete(),
61                  "runtime check error");
62 }
63 
DumpKlassMethods() const64 void Klass::DumpKlassMethods() const
65 {
66     if (methods.empty()) {
67         return;
68     }
69     LogInfo::MapleLogger() << "   class member methods:\n";
70     for (MIRFunction *method : methods) {
71         LogInfo::MapleLogger() << "   \t" << method->GetName() << " , ";
72         method->GetFuncSymbol()->GetAttrs().DumpAttributes();
73         LogInfo::MapleLogger() << "\n";
74     }
75     for (const auto &m2cPair : strIdx2CandidateMap) {
76         LogInfo::MapleLogger() << "   \t" << GlobalTables::GetStrTable().GetStringFromStrIdx(m2cPair.first)
77                                << "   , # of target:" << m2cPair.second->size() << "\n";
78     }
79 }
80 
DumpKlassImplKlasses() const81 void Klass::DumpKlassImplKlasses() const
82 {
83     if (implKlasses.empty()) {
84         return;
85     }
86     LogInfo::MapleLogger() << "  implemented by:\n";
87     for (Klass *implKlass : implKlasses) {
88         LogInfo::MapleLogger() << "   \t@implbyclass_idx " << implKlass->structType->GetTypeIndex() << "\n";
89     }
90 }
91 
DumpKlassImplInterfaces() const92 void Klass::DumpKlassImplInterfaces() const
93 {
94     if (implInterfaces.empty()) {
95         return;
96     }
97     LogInfo::MapleLogger() << "  implements:\n";
98     for (Klass *interface : implInterfaces) {
99         LogInfo::MapleLogger() << "   \t@implinterface_idx " << interface->structType->GetTypeIndex() << "\n";
100     }
101 }
102 
DumpKlassSuperKlasses() const103 void Klass::DumpKlassSuperKlasses() const
104 {
105     if (superKlasses.empty()) {
106         return;
107     }
108     LogInfo::MapleLogger() << "   superclasses:\n";
109     for (Klass *superKlass : superKlasses) {
110         LogInfo::MapleLogger() << "   \t@superclass_idx " << superKlass->structType->GetTypeIndex() << "\n";
111     }
112 }
113 
DumpKlassSubKlasses() const114 void Klass::DumpKlassSubKlasses() const
115 {
116     if (subKlasses.empty()) {
117         return;
118     }
119     LogInfo::MapleLogger() << "   subclasses:\n";
120     for (Klass *subKlass : subKlasses) {
121         LogInfo::MapleLogger() << "   \t@subclass_idx " << subKlass->structType->GetTypeIndex() << "\n";
122     }
123 }
124 
Dump() const125 void Klass::Dump() const
126 {
127     // Dump detailed class info
128     LogInfo::MapleLogger() << "class \" " << GetKlassName() << " \" @class_id " << structType->GetTypeIndex() << "\n";
129     DumpKlassSuperKlasses();
130     DumpKlassSubKlasses();
131     DumpKlassImplInterfaces();
132     DumpKlassImplKlasses();
133     DumpKlassMethods();
134 }
135 
GetClosestMethod(GStrIdx funcName) const136 MIRFunction *Klass::GetClosestMethod(GStrIdx funcName) const
137 {
138     MapleVector<MIRFunction *> *cands = GetCandidates(funcName);
139     if (cands != nullptr && !cands->empty()) {
140         return cands->at(0);
141     }
142     return nullptr;
143 }
144 
DelMethod(const MIRFunction & func)145 void Klass::DelMethod(const MIRFunction &func)
146 {
147     if (GetMethod(func.GetBaseFuncNameWithTypeStrIdx()) == &func) {
148         strIdx2Method.erase(func.GetBaseFuncNameWithTypeStrIdx());
149     }
150     for (auto it = methods.begin(); it != methods.end(); ++it) {
151         if (*it == &func) {
152             methods.erase(it--);
153             return;
154         }
155     }
156 }
157 
158 // This for class only, which only has 0 or 1 super class
GetSuperKlass() const159 Klass *Klass::GetSuperKlass() const
160 {
161     switch (superKlasses.size()) {
162         case 0:
163             return nullptr;
164         case 1:
165             return *superKlasses.begin();
166         default:
167             LogInfo::MapleLogger() << GetKlassName() << "\n";
168             CHECK_FATAL(false, "GetSuperKlass expects less than one super class");
169     }
170 }
171 
IsKlassMethod(const MIRFunction * func) const172 bool Klass::IsKlassMethod(const MIRFunction *func) const
173 {
174     if (func == nullptr) {
175         return false;
176     }
177     for (MIRFunction *method : methods) {
178         if (method == func) {
179             return true;
180         }
181     }
182     return false;
183 }
184 
ImplementsKlass() const185 bool Klass::ImplementsKlass() const
186 {
187     if (IsInterface() || IsInterfaceIncomplete()) {
188         return false;
189     }
190     MIRClassType *classType = GetMIRClassType();
191     DEBUG_ASSERT(classType != nullptr, "null ptr check");
192     return (!classType->GetInterfaceImplemented().empty());
193 }
194 
GetCandidates(GStrIdx mnameNoklassStrIdx) const195 MapleVector<MIRFunction *> *Klass::GetCandidates(GStrIdx mnameNoklassStrIdx) const
196 {
197     auto it = strIdx2CandidateMap.find(mnameNoklassStrIdx);
198     return ((it != strIdx2CandidateMap.end()) ? (it->second) : nullptr);
199 }
200 
GetUniqueMethod(GStrIdx mnameNoklassStrIdx) const201 MIRFunction *Klass::GetUniqueMethod(GStrIdx mnameNoklassStrIdx) const
202 {
203     if (structType->IsIncomplete()) {
204         return nullptr;
205     }
206     auto it = strIdx2CandidateMap.find(mnameNoklassStrIdx);
207     if (it != strIdx2CandidateMap.end()) {
208         MapleVector<MIRFunction *> *fv = it->second;
209         if (fv != nullptr && fv->size() == 1) {
210             return fv->at(0);
211         }
212     }
213     return nullptr;
214 }
215 
IsVirtualMethod(const MIRFunction & func) const216 bool Klass::IsVirtualMethod(const MIRFunction &func) const
217 {
218     // May add other checking conditions in future
219     return (func.GetAttr(FUNCATTR_virtual) && !func.GetAttr(FUNCATTR_private) && !func.GetAttr(FUNCATTR_abstract));
220 }
221 
CountVirtMethBottomUp()222 void Klass::CountVirtMethBottomUp()
223 {
224     MapleVector<MIRFunction *> *vec;
225     GStrIdx strIdx;
226     for (Klass *subKlass : subKlasses) {
227         CHECK_FATAL(subKlass != nullptr, "nullptr check failed");
228         for (const auto &pair : subKlass->strIdx2CandidateMap) {
229             strIdx = pair.first;
230             if (strIdx2CandidateMap.find(strIdx) == strIdx2CandidateMap.end()) {
231                 continue;
232             }
233             vec = strIdx2CandidateMap[strIdx];
234             MapleVector<MIRFunction *> *subv = pair.second;
235             if (!vec->empty() && !subv->empty() && vec->at(0) == subv->at(0)) {
236                 // If this class and subclass share the same default implementation,
237                 // then we have to avoid duplicated counting.
238                 vec->insert(vec->end(), subv->begin() + 1, subv->end());
239             } else {
240                 vec->insert(vec->end(), subv->begin(), subv->end());
241             }
242         }
243     }
244 }
245 
HasMethod(const std::string & funcname) const246 const MIRFunction *Klass::HasMethod(const std::string &funcname) const
247 {
248     for (auto *method : methods) {
249         if (method->GetBaseFuncNameWithType().find(funcname) != std::string::npos) {
250             return method;
251         }
252     }
253     return nullptr;
254 }
255 
GetKlassFromStrIdx(GStrIdx strIdx) const256 Klass *KlassHierarchy::GetKlassFromStrIdx(GStrIdx strIdx) const
257 {
258     auto it = strIdx2KlassMap.find(strIdx);
259     return ((it != strIdx2KlassMap.end()) ? (it->second) : nullptr);
260 }
261 
GetKlassFromTyIdx(TyIdx tyIdx) const262 Klass *KlassHierarchy::GetKlassFromTyIdx(TyIdx tyIdx) const
263 {
264     MIRType *type = GlobalTables::GetTypeTable().GetTypeFromTyIdx(tyIdx);
265     return (type != nullptr ? GetKlassFromStrIdx(type->GetNameStrIdx()) : nullptr);
266 }
267 
GetKlassFromFunc(const MIRFunction * func) const268 Klass *KlassHierarchy::GetKlassFromFunc(const MIRFunction *func) const
269 {
270     return (func != nullptr ? GetKlassFromStrIdx(func->GetBaseClassNameStrIdx()) : nullptr);
271 }
272 
GetKlassFromName(const std::string & name) const273 Klass *KlassHierarchy::GetKlassFromName(const std::string &name) const
274 {
275     return GetKlassFromStrIdx(GlobalTables::GetStrTable().GetStrIdxFromName(name));
276 }
277 
GetKlassFromLiteral(const std::string & name) const278 Klass *KlassHierarchy::GetKlassFromLiteral(const std::string &name) const
279 {
280     return GetKlassFromStrIdx(GlobalTables::GetStrTable().GetStrIdxFromName(name));
281 }
282 
283 // check if super is a superclass of base
284 // 1/0/-1: true/false/unknown
IsSuperKlass(TyIdx superTyIdx,TyIdx baseTyIdx) const285 int KlassHierarchy::IsSuperKlass(TyIdx superTyIdx, TyIdx baseTyIdx) const
286 {
287     if (superTyIdx == 0u || baseTyIdx == 0u) {
288         return -1;
289     }
290     if (superTyIdx == baseTyIdx) {
291         return 1;
292     }
293     Klass *super = GetKlassFromTyIdx(superTyIdx);
294     Klass *base = GetKlassFromTyIdx(baseTyIdx);
295     if (super == nullptr || base == nullptr) {
296         return -1;
297     }
298     while (base != nullptr) {
299         if (base == super) {
300             return 1;
301         }
302         base = base->GetSuperKlass();
303     }
304     return 0;
305 }
306 
IsSuperKlass(const Klass * super,const Klass * base) const307 bool KlassHierarchy::IsSuperKlass(const Klass *super, const Klass *base) const
308 {
309     if (super == nullptr || base == nullptr || base->IsInterface()) {
310         return false;
311     }
312     while (base != nullptr) {
313         if (base == super) {
314             return true;
315         }
316         base = base->GetSuperKlass();
317     }
318     return false;
319 }
320 
321 // Interface
IsSuperKlassForInterface(const Klass * super,const Klass * base) const322 bool KlassHierarchy::IsSuperKlassForInterface(const Klass *super, const Klass *base) const
323 {
324     if (super == nullptr || base == nullptr) {
325         return false;
326     }
327     if (!super->IsInterface() || !base->IsInterface()) {
328         return false;
329     }
330     std::vector<const Klass *> tmpVector;
331     tmpVector.push_back(base);
332     for (size_t idx = 0; idx < tmpVector.size(); ++idx) {
333         if (tmpVector[idx] == super) {
334             return true;
335         }
336         for (const Klass *superKlass : tmpVector[idx]->GetSuperKlasses()) {
337             tmpVector.push_back(superKlass);
338         }
339     }
340     return false;
341 }
342 
IsInterfaceImplemented(Klass * interface,const Klass * base) const343 bool KlassHierarchy::IsInterfaceImplemented(Klass *interface, const Klass *base) const
344 {
345     if (interface == nullptr || base == nullptr) {
346         return false;
347     }
348     if (!interface->IsInterface() || !base->IsClass()) {
349         return false;
350     }
351     // All the implemented interfaces and their parent interfaces
352     // are directly stored in this set, so no need to look up super
353     return (base->GetImplInterfaces().find(interface) != base->GetImplInterfaces().end());
354 }
355 
GetFieldIDOffsetBetweenClasses(const Klass & super,const Klass & base) const356 int KlassHierarchy::GetFieldIDOffsetBetweenClasses(const Klass &super, const Klass &base) const
357 {
358     int offset = 0;
359     const Klass *superPtr = &super;
360     const Klass *basePtr = &base;
361     while (basePtr != superPtr) {
362         basePtr = basePtr->GetSuperKlass();
363         CHECK_FATAL(basePtr != nullptr, "null ptr check");
364         offset++;
365     }
366     return offset;
367 }
368 
UpdateFieldID(TyIdx baseTypeIdx,TyIdx targetTypeIdx,FieldID & fldID) const369 bool KlassHierarchy::UpdateFieldID(TyIdx baseTypeIdx, TyIdx targetTypeIdx, FieldID &fldID) const
370 {
371     MIRType *baseType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(baseTypeIdx);
372     MIRType *targetType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(targetTypeIdx);
373     if (baseType->GetKind() == kTypePointer && targetType->GetKind() == kTypePointer) {
374         baseType = static_cast<const MIRPtrType *>(baseType)->GetPointedType();
375         targetType = static_cast<const MIRPtrType *>(targetType)->GetPointedType();
376     }
377     if (baseType->GetKind() != kTypeClass || targetType->GetKind() != kTypeClass) {
378         return false;
379     }
380     Klass *baseKlass = GetKlassFromTyIdx(baseType->GetTypeIndex());
381     DEBUG_ASSERT(baseKlass != nullptr, "null ptr check");
382     Klass *targetKlass = GetKlassFromTyIdx(targetType->GetTypeIndex());
383     DEBUG_ASSERT(targetKlass != nullptr, "null ptr check");
384     if (IsSuperKlass(baseKlass, targetKlass)) {
385         fldID += GetFieldIDOffsetBetweenClasses(*baseKlass, *targetKlass);
386         return true;
387     } else if (IsSuperKlass(targetKlass, baseKlass)) {
388         fldID -= GetFieldIDOffsetBetweenClasses(*targetKlass, *baseKlass);
389         return true;
390     }
391     return false;
392 }
393 
NeedClinitCheckRecursively(const Klass & kl) const394 bool KlassHierarchy::NeedClinitCheckRecursively(const Klass &kl) const
395 {
396     if (kl.HasFlag(kClassRuntimeVerify)) {
397         return true;
398     }
399     const Klass *klass = &kl;
400     if (klass->IsClass()) {
401         while (klass != nullptr) {
402             if (klass->GetClinit()) {
403                 return true;
404             }
405             klass = klass->GetSuperKlass();
406         }
407         for (Klass *implInterface : kl.GetImplInterfaces()) {
408             if (implInterface->GetClinit()) {
409                 for (auto &func : implInterface->GetMethods()) {
410                     if (!func->GetAttr(FUNCATTR_abstract) && !func->GetAttr(FUNCATTR_static)) {
411                         return true;
412                     }
413                 }
414             }
415         }
416         return false;
417     }
418     if (klass->IsInterface()) {
419         return klass->GetClinit();
420     }
421     return true;
422 }
423 
424 // Get lowest common ancestor for two classes
GetLCA(Klass * klass1,Klass * klass2) const425 Klass *KlassHierarchy::GetLCA(Klass *klass1, Klass *klass2) const
426 {
427     std::vector<Klass *> v1, v2;
428     while (klass1 != nullptr) {
429         v1.push_back(klass1);
430         klass1 = klass1->GetSuperKlass();
431     }
432     while (klass2 != nullptr) {
433         v2.push_back(klass2);
434         klass2 = klass2->GetSuperKlass();
435     }
436     Klass *result = nullptr;
437     size_t size1 = v1.size();
438     size_t size2 = v2.size();
439     size_t min = (size1 < size2) ? size1 : size2;
440     for (size_t i = 1; i <= min; ++i) {
441         CHECK_FATAL(size1 > 0, "must not be zero");
442         if (v1[size1 - i] != v2[size2 - i]) {
443             break;
444         }
445         result = v1[size1 - i];
446     }
447     return result;
448 }
449 
GetLCA(TyIdx ty1,TyIdx ty2) const450 TyIdx KlassHierarchy::GetLCA(TyIdx ty1, TyIdx ty2) const
451 {
452     Klass *result = GetLCA(GetKlassFromTyIdx(ty1), GetKlassFromTyIdx(ty2));
453     return (result != nullptr ? result->GetTypeIdx() : TyIdx(0));
454 }
455 
GetLCA(GStrIdx str1,GStrIdx str2) const456 GStrIdx KlassHierarchy::GetLCA(GStrIdx str1, GStrIdx str2) const
457 {
458     Klass *result = GetLCA(GetKlassFromStrIdx(str1), GetKlassFromStrIdx(str2));
459     return (result != nullptr ? result->GetKlassNameStrIdx() : GStrIdx(0));
460 }
461 
GetLCA(const std::string & name1,const std::string & name2) const462 const std::string &KlassHierarchy::GetLCA(const std::string &name1, const std::string &name2) const
463 {
464     Klass *result = GetLCA(GetKlassFromName(name1), GetKlassFromName(name2));
465     return (result != nullptr ? result->GetKlassName() : GlobalTables::GetStrTable().GetStringFromStrIdx(GStrIdx(0)));
466 }
467 
AddKlasses()468 void KlassHierarchy::AddKlasses()
469 {
470     for (MIRType *type : GlobalTables::GetTypeTable().GetTypeTable()) {
471 #if DEBUG
472         if (type != nullptr) {
473             MIRTypeKind kd = type->GetKind();
474             if (kd == kTypeStructIncomplete || kd == kTypeClassIncomplete || kd == kTypeInterfaceIncomplete)
475                 LogInfo::MapleLogger() << "Warning: KlassHierarchy::AddKlasses "
476                                        << GlobalTables::GetStrTable().GetStringFromStrIdx(type->GetNameStrIdx())
477                                        << " INCOMPLETE \n";
478         }
479 #endif
480         if (Options::deferredVisit2 && type && (type->IsIncomplete())) {
481             GStrIdx stridx = type->GetNameStrIdx();
482             std::string strName = GlobalTables::GetStrTable().GetStringFromStrIdx(stridx);
483 #if DEBUG
484             LogInfo::MapleLogger() << "Waring: " << strName << " INCOMPLETE \n";
485 #endif
486             if (strName == namemangler::kClassMetadataTypeName) {
487                 continue;
488             }
489         } else if (type == nullptr || (type->GetKind() != kTypeClass && type->GetKind() != kTypeInterface)) {
490             continue;
491         }
492         GStrIdx strIdx = type->GetNameStrIdx();
493         Klass *klass = GetKlassFromStrIdx(strIdx);
494         if (klass != nullptr) {
495             continue;
496         }
497         auto *stype = static_cast<MIRStructType *>(type);
498         klass = GetMempool()->New<Klass>(stype, &alloc);
499         strIdx2KlassMap[strIdx] = klass;
500     }
501 }
502 
ExceptionFlagProp(Klass & klass)503 void KlassHierarchy::ExceptionFlagProp(Klass &klass)
504 {
505     klass.SetExceptionKlass();
506     for (Klass *subClass : klass.GetSubKlasses()) {
507         DEBUG_ASSERT(subClass != nullptr, "null ptr check!");
508         subClass->SetExceptionKlass();
509         ExceptionFlagProp(*subClass);
510     }
511 }
512 
CollectImplInterfaces(const Klass & klass,std::set<Klass * > & implInterfaceSet)513 static void CollectImplInterfaces(const Klass &klass, std::set<Klass *> &implInterfaceSet)
514 {
515     for (Klass *superKlass : klass.GetSuperKlasses()) {
516         if (implInterfaceSet.find(superKlass) == implInterfaceSet.end()) {
517             DEBUG_ASSERT(superKlass != nullptr, "null ptr check!");
518             if (superKlass->IsInterface()) {
519                 implInterfaceSet.insert(superKlass);
520             }
521             CollectImplInterfaces(*superKlass, implInterfaceSet);
522         }
523     }
524     for (Klass *interfaceKlass : klass.GetImplInterfaces()) {
525         if (implInterfaceSet.find(interfaceKlass) == implInterfaceSet.end()) {
526             implInterfaceSet.insert(interfaceKlass);
527             DEBUG_ASSERT(interfaceKlass != nullptr, "null ptr check!");
528             CollectImplInterfaces(*interfaceKlass, implInterfaceSet);
529         }
530     }
531 }
532 
UpdateImplementedInterfaces()533 void KlassHierarchy::UpdateImplementedInterfaces()
534 {
535     for (const auto &pair : strIdx2KlassMap) {
536         Klass *klass = pair.second;
537         DEBUG_ASSERT(klass != nullptr, "null ptr check");
538         if (!klass->IsInterface()) {
539             std::set<Klass *> implInterfaceSet;
540             CollectImplInterfaces(*klass, implInterfaceSet);
541             for (auto interface : implInterfaceSet) {
542                 // Add missing parent interface to class link
543                 interface->AddImplKlass(klass);
544                 klass->AddImplInterface(interface);
545             }
546         }
547     }
548 }
549 
GetParentKlasses(const Klass & klass,std::vector<Klass * > & parentKlasses) const550 void KlassHierarchy::GetParentKlasses(const Klass &klass, std::vector<Klass *> &parentKlasses) const
551 {
552     for (Klass *superKlass : klass.GetSuperKlasses()) {
553         parentKlasses.push_back(superKlass);
554     }
555     if (!klass.IsInterface()) {
556         for (Klass *iklass : klass.GetImplInterfaces()) {
557             parentKlasses.push_back(iklass);
558         }
559     }
560 }
561 
GetChildKlasses(const Klass & klass,std::vector<Klass * > & childKlasses) const562 void KlassHierarchy::GetChildKlasses(const Klass &klass, std::vector<Klass *> &childKlasses) const
563 {
564     for (Klass *subKlass : klass.GetSubKlasses()) {
565         childKlasses.push_back(subKlass);
566     }
567     if (klass.IsInterface()) {
568         for (Klass *implKlass : klass.GetImplKlasses()) {
569             childKlasses.push_back(implKlass);
570         }
571     }
572 }
573 
TopologicalSortKlasses()574 void KlassHierarchy::TopologicalSortKlasses()
575 {
576     std::set<Klass *> inQueue;  // Local variable, no need to use MapleSet
577     for (const auto &pair : strIdx2KlassMap) {
578         Klass *klass = pair.second;
579         DEBUG_ASSERT(klass != nullptr, "klass can not be nullptr");
580         if (!klass->HasSuperKlass() && !klass->ImplementsKlass()) {
581             topoWorkList.push_back(klass);
582             inQueue.insert(klass);
583         }
584     }
585     // Top-down iterates all nodes
586     for (size_t i = 0; i < topoWorkList.size(); ++i) {
587         Klass *klass = topoWorkList[i];
588         std::vector<Klass *> childklasses;
589         DEBUG_ASSERT(klass != nullptr, "null ptr check!");
590         GetChildKlasses(*klass, childklasses);
591         for (Klass *childklass : childklasses) {
592             if (inQueue.find(childklass) == inQueue.end()) {
593                 // callee has not been visited
594                 bool parentKlassAllVisited = true;
595                 std::vector<Klass *> parentKlasses;
596                 DEBUG_ASSERT(childklass != nullptr, "null ptr check!");
597                 GetParentKlasses(*childklass, parentKlasses);
598                 // Check whether all callers of the current callee have been visited
599                 for (Klass *parentKlass : parentKlasses) {
600                     if (inQueue.find(parentKlass) == inQueue.end()) {
601                         parentKlassAllVisited = false;
602                         break;
603                     }
604                 }
605                 if (parentKlassAllVisited) {
606                     topoWorkList.push_back(childklass);
607                     inQueue.insert(childklass);
608                 }
609             }
610         }
611     }
612 }
613 
CountVirtualMethods() const614 void KlassHierarchy::CountVirtualMethods() const
615 {
616     // Bottom-up iterates all klass nodes
617     for (size_t i = topoWorkList.size(); i != 0; --i) {
618         topoWorkList[i - 1]->CountVirtMethBottomUp();
619     }
620 }
621 
AddClassFlag(const std::string & name,uint32 flag)622 Klass *KlassHierarchy::AddClassFlag(const std::string &name, uint32 flag)
623 {
624     Klass *refKlass = GetKlassFromLiteral(name);
625     if (refKlass != nullptr) {
626         refKlass->SetFlag(flag);
627     }
628     return refKlass;
629 }
630 
Dump() const631 void KlassHierarchy::Dump() const
632 {
633     for (Klass *klass : topoWorkList) {
634         klass->Dump();
635     }
636 }
637 
GetUniqueMethod(GStrIdx vfuncNameStridx) const638 GStrIdx KlassHierarchy::GetUniqueMethod(GStrIdx vfuncNameStridx) const
639 {
640     auto it = vfunc2RfuncMap.find(vfuncNameStridx);
641     return (it != vfunc2RfuncMap.end() ? it->second : GStrIdx(0));
642 }
643 
IsDevirtualListEmpty() const644 bool KlassHierarchy::IsDevirtualListEmpty() const
645 {
646     return vfunc2RfuncMap.empty();
647 }
648 
DumpDevirtualList(const std::string & outputFileName) const649 void KlassHierarchy::DumpDevirtualList(const std::string &outputFileName) const
650 {
651     std::unordered_map<std::string, std::string> funcMap;
652     for (Klass *klass : topoWorkList) {
653         for (MIRFunction *func : klass->GetMethods()) {
654             MIRFunction *uniqCand = klass->GetUniqueMethod(func->GetBaseFuncNameWithTypeStrIdx());
655             if (uniqCand != nullptr) {
656                 funcMap[func->GetName()] = uniqCand->GetName();
657             }
658         }
659     }
660     std::ofstream outputFile;
661     outputFile.open(outputFileName);
662     for (auto it : funcMap) {
663         outputFile << it.first << "\t" << it.second << "\n";
664     }
665     outputFile.close();
666 }
667 
ReadDevirtualList(const std::string & inputFileName)668 void KlassHierarchy::ReadDevirtualList(const std::string &inputFileName)
669 {
670     std::ifstream inputFile;
671     inputFile.open(inputFileName);
672     std::string vfuncName;
673     std::string rfuncName;
674     while (inputFile >> vfuncName >> rfuncName) {
675         vfunc2RfuncMap[GlobalTables::GetStrTable().GetOrCreateStrIdxFromName(vfuncName)] =
676             GlobalTables::GetStrTable().GetOrCreateStrIdxFromName(rfuncName);
677     }
678     inputFile.close();
679 }
680 
BuildHierarchy()681 void KlassHierarchy::BuildHierarchy()
682 {
683     // Scan class list and generate Klass without method information
684     AddKlasses();
685     // In the case of "class C implements B; interface B extends A;",
686     // we need to add a link between C and A.
687     UpdateImplementedInterfaces();
688     TopologicalSortKlasses();
689     // Use --dump-devirtual-list=... to dump a closed-wolrd analysis result into a file
690     if (!Options::dumpDevirtualList.empty()) {
691         DumpDevirtualList(Options::dumpDevirtualList);
692     }
693     // Use --read-devirtual-list=... to read in a closed-world analysis result
694     if (!Options::readDevirtualList.empty()) {
695         ReadDevirtualList(Options::readDevirtualList);
696     }
697 }
698 
KlassHierarchy(MIRModule * mirmodule,MemPool * memPool)699 KlassHierarchy::KlassHierarchy(MIRModule *mirmodule, MemPool *memPool)
700     : AnalysisResult(memPool),
701       alloc(memPool),
702       mirModule(mirmodule),
703       strIdx2KlassMap(std::less<GStrIdx>(), alloc.Adapter()),
704       vfunc2RfuncMap(std::less<GStrIdx>(), alloc.Adapter()),
705       topoWorkList(alloc.Adapter())
706 {
707 }
708 }  // namespace maple
709