1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "class_hierarchy.h"
17 #include <iostream>
18 #include <fstream>
19 #include "option.h"
20
// Class Hierarchy Analysis
// This is a foundational compilation phase. It builds the class hierarchy
// for this module and for all modules it depends on. Many later phases,
// such as call graph construction and SSA, rely on its results.
// The main procedure is as follows.
// A. Based on the information read from mplts, collect all classes
//    declared in the modules and create a Klass for each of them.
// B. Fill in class method info. Connect superclass<->subclass and
//    interface->implementation edges.
// C. Tag all Throwable classes and their subclasses.
// D. In the case of "class C implements B; interface B extends A;",
//    we need to add a link between C and A, so we recursively traverse
//    each Klass and collect all interfaces it implements.
// E. Topologically sort the klasses.
// F. Following the topological order, for each virtual method in a class,
//    collect all of its potential implementations. If there is exactly
//    one potential implementation, every virtual call to that method
//    can easily be devirtualized.
40 namespace maple {
41 bool KlassHierarchy::traceFlag = false;
42
// Reports whether className names a system-preloaded class.
// This implementation treats no class as preloaded, so the answer is always false.
bool IsSystemPreloadedClass(const std::string &className)
{
    static_cast<void>(className);  // the name is intentionally not consulted
    return false;
}
47
Klass(MIRStructType * type,MapleAllocator * alc)48 Klass::Klass(MIRStructType *type, MapleAllocator *alc)
49 : structType(type),
50 alloc(alc),
51 superKlasses(alloc->Adapter()),
52 subKlasses(alloc->Adapter()),
53 implKlasses(alloc->Adapter()),
54 implInterfaces(alloc->Adapter()),
55 methods(alloc->Adapter()),
56 strIdx2Method(alloc->Adapter()),
57 strIdx2CandidateMap(alloc->Adapter())
58 {
59 DEBUG_ASSERT(type != nullptr, "type is nullptr in Klass::Klass!");
60 DEBUG_ASSERT(type->GetKind() == kTypeClass || type->GetKind() == kTypeInterface || type->IsIncomplete(),
61 "runtime check error");
62 }
63
DumpKlassMethods() const64 void Klass::DumpKlassMethods() const
65 {
66 if (methods.empty()) {
67 return;
68 }
69 LogInfo::MapleLogger() << " class member methods:\n";
70 for (MIRFunction *method : methods) {
71 LogInfo::MapleLogger() << " \t" << method->GetName() << " , ";
72 method->GetFuncSymbol()->GetAttrs().DumpAttributes();
73 LogInfo::MapleLogger() << "\n";
74 }
75 for (const auto &m2cPair : strIdx2CandidateMap) {
76 LogInfo::MapleLogger() << " \t" << GlobalTables::GetStrTable().GetStringFromStrIdx(m2cPair.first)
77 << " , # of target:" << m2cPair.second->size() << "\n";
78 }
79 }
80
DumpKlassImplKlasses() const81 void Klass::DumpKlassImplKlasses() const
82 {
83 if (implKlasses.empty()) {
84 return;
85 }
86 LogInfo::MapleLogger() << " implemented by:\n";
87 for (Klass *implKlass : implKlasses) {
88 LogInfo::MapleLogger() << " \t@implbyclass_idx " << implKlass->structType->GetTypeIndex() << "\n";
89 }
90 }
91
DumpKlassImplInterfaces() const92 void Klass::DumpKlassImplInterfaces() const
93 {
94 if (implInterfaces.empty()) {
95 return;
96 }
97 LogInfo::MapleLogger() << " implements:\n";
98 for (Klass *interface : implInterfaces) {
99 LogInfo::MapleLogger() << " \t@implinterface_idx " << interface->structType->GetTypeIndex() << "\n";
100 }
101 }
102
DumpKlassSuperKlasses() const103 void Klass::DumpKlassSuperKlasses() const
104 {
105 if (superKlasses.empty()) {
106 return;
107 }
108 LogInfo::MapleLogger() << " superclasses:\n";
109 for (Klass *superKlass : superKlasses) {
110 LogInfo::MapleLogger() << " \t@superclass_idx " << superKlass->structType->GetTypeIndex() << "\n";
111 }
112 }
113
DumpKlassSubKlasses() const114 void Klass::DumpKlassSubKlasses() const
115 {
116 if (subKlasses.empty()) {
117 return;
118 }
119 LogInfo::MapleLogger() << " subclasses:\n";
120 for (Klass *subKlass : subKlasses) {
121 LogInfo::MapleLogger() << " \t@subclass_idx " << subKlass->structType->GetTypeIndex() << "\n";
122 }
123 }
124
Dump() const125 void Klass::Dump() const
126 {
127 // Dump detailed class info
128 LogInfo::MapleLogger() << "class \" " << GetKlassName() << " \" @class_id " << structType->GetTypeIndex() << "\n";
129 DumpKlassSuperKlasses();
130 DumpKlassSubKlasses();
131 DumpKlassImplInterfaces();
132 DumpKlassImplKlasses();
133 DumpKlassMethods();
134 }
135
GetClosestMethod(GStrIdx funcName) const136 MIRFunction *Klass::GetClosestMethod(GStrIdx funcName) const
137 {
138 MapleVector<MIRFunction *> *cands = GetCandidates(funcName);
139 if (cands != nullptr && !cands->empty()) {
140 return cands->at(0);
141 }
142 return nullptr;
143 }
144
DelMethod(const MIRFunction & func)145 void Klass::DelMethod(const MIRFunction &func)
146 {
147 if (GetMethod(func.GetBaseFuncNameWithTypeStrIdx()) == &func) {
148 strIdx2Method.erase(func.GetBaseFuncNameWithTypeStrIdx());
149 }
150 for (auto it = methods.begin(); it != methods.end(); ++it) {
151 if (*it == &func) {
152 methods.erase(it--);
153 return;
154 }
155 }
156 }
157
158 // This for class only, which only has 0 or 1 super class
GetSuperKlass() const159 Klass *Klass::GetSuperKlass() const
160 {
161 switch (superKlasses.size()) {
162 case 0:
163 return nullptr;
164 case 1:
165 return *superKlasses.begin();
166 default:
167 LogInfo::MapleLogger() << GetKlassName() << "\n";
168 CHECK_FATAL(false, "GetSuperKlass expects less than one super class");
169 }
170 }
171
IsKlassMethod(const MIRFunction * func) const172 bool Klass::IsKlassMethod(const MIRFunction *func) const
173 {
174 if (func == nullptr) {
175 return false;
176 }
177 for (MIRFunction *method : methods) {
178 if (method == func) {
179 return true;
180 }
181 }
182 return false;
183 }
184
ImplementsKlass() const185 bool Klass::ImplementsKlass() const
186 {
187 if (IsInterface() || IsInterfaceIncomplete()) {
188 return false;
189 }
190 MIRClassType *classType = GetMIRClassType();
191 DEBUG_ASSERT(classType != nullptr, "null ptr check");
192 return (!classType->GetInterfaceImplemented().empty());
193 }
194
GetCandidates(GStrIdx mnameNoklassStrIdx) const195 MapleVector<MIRFunction *> *Klass::GetCandidates(GStrIdx mnameNoklassStrIdx) const
196 {
197 auto it = strIdx2CandidateMap.find(mnameNoklassStrIdx);
198 return ((it != strIdx2CandidateMap.end()) ? (it->second) : nullptr);
199 }
200
GetUniqueMethod(GStrIdx mnameNoklassStrIdx) const201 MIRFunction *Klass::GetUniqueMethod(GStrIdx mnameNoklassStrIdx) const
202 {
203 if (structType->IsIncomplete()) {
204 return nullptr;
205 }
206 auto it = strIdx2CandidateMap.find(mnameNoklassStrIdx);
207 if (it != strIdx2CandidateMap.end()) {
208 MapleVector<MIRFunction *> *fv = it->second;
209 if (fv != nullptr && fv->size() == 1) {
210 return fv->at(0);
211 }
212 }
213 return nullptr;
214 }
215
IsVirtualMethod(const MIRFunction & func) const216 bool Klass::IsVirtualMethod(const MIRFunction &func) const
217 {
218 // May add other checking conditions in future
219 return (func.GetAttr(FUNCATTR_virtual) && !func.GetAttr(FUNCATTR_private) && !func.GetAttr(FUNCATTR_abstract));
220 }
221
CountVirtMethBottomUp()222 void Klass::CountVirtMethBottomUp()
223 {
224 MapleVector<MIRFunction *> *vec;
225 GStrIdx strIdx;
226 for (Klass *subKlass : subKlasses) {
227 CHECK_FATAL(subKlass != nullptr, "nullptr check failed");
228 for (const auto &pair : subKlass->strIdx2CandidateMap) {
229 strIdx = pair.first;
230 if (strIdx2CandidateMap.find(strIdx) == strIdx2CandidateMap.end()) {
231 continue;
232 }
233 vec = strIdx2CandidateMap[strIdx];
234 MapleVector<MIRFunction *> *subv = pair.second;
235 if (!vec->empty() && !subv->empty() && vec->at(0) == subv->at(0)) {
236 // If this class and subclass share the same default implementation,
237 // then we have to avoid duplicated counting.
238 vec->insert(vec->end(), subv->begin() + 1, subv->end());
239 } else {
240 vec->insert(vec->end(), subv->begin(), subv->end());
241 }
242 }
243 }
244 }
245
HasMethod(const std::string & funcname) const246 const MIRFunction *Klass::HasMethod(const std::string &funcname) const
247 {
248 for (auto *method : methods) {
249 if (method->GetBaseFuncNameWithType().find(funcname) != std::string::npos) {
250 return method;
251 }
252 }
253 return nullptr;
254 }
255
GetKlassFromStrIdx(GStrIdx strIdx) const256 Klass *KlassHierarchy::GetKlassFromStrIdx(GStrIdx strIdx) const
257 {
258 auto it = strIdx2KlassMap.find(strIdx);
259 return ((it != strIdx2KlassMap.end()) ? (it->second) : nullptr);
260 }
261
GetKlassFromTyIdx(TyIdx tyIdx) const262 Klass *KlassHierarchy::GetKlassFromTyIdx(TyIdx tyIdx) const
263 {
264 MIRType *type = GlobalTables::GetTypeTable().GetTypeFromTyIdx(tyIdx);
265 return (type != nullptr ? GetKlassFromStrIdx(type->GetNameStrIdx()) : nullptr);
266 }
267
GetKlassFromFunc(const MIRFunction * func) const268 Klass *KlassHierarchy::GetKlassFromFunc(const MIRFunction *func) const
269 {
270 return (func != nullptr ? GetKlassFromStrIdx(func->GetBaseClassNameStrIdx()) : nullptr);
271 }
272
GetKlassFromName(const std::string & name) const273 Klass *KlassHierarchy::GetKlassFromName(const std::string &name) const
274 {
275 return GetKlassFromStrIdx(GlobalTables::GetStrTable().GetStrIdxFromName(name));
276 }
277
GetKlassFromLiteral(const std::string & name) const278 Klass *KlassHierarchy::GetKlassFromLiteral(const std::string &name) const
279 {
280 return GetKlassFromStrIdx(GlobalTables::GetStrTable().GetStrIdxFromName(name));
281 }
282
283 // check if super is a superclass of base
284 // 1/0/-1: true/false/unknown
IsSuperKlass(TyIdx superTyIdx,TyIdx baseTyIdx) const285 int KlassHierarchy::IsSuperKlass(TyIdx superTyIdx, TyIdx baseTyIdx) const
286 {
287 if (superTyIdx == 0u || baseTyIdx == 0u) {
288 return -1;
289 }
290 if (superTyIdx == baseTyIdx) {
291 return 1;
292 }
293 Klass *super = GetKlassFromTyIdx(superTyIdx);
294 Klass *base = GetKlassFromTyIdx(baseTyIdx);
295 if (super == nullptr || base == nullptr) {
296 return -1;
297 }
298 while (base != nullptr) {
299 if (base == super) {
300 return 1;
301 }
302 base = base->GetSuperKlass();
303 }
304 return 0;
305 }
306
IsSuperKlass(const Klass * super,const Klass * base) const307 bool KlassHierarchy::IsSuperKlass(const Klass *super, const Klass *base) const
308 {
309 if (super == nullptr || base == nullptr || base->IsInterface()) {
310 return false;
311 }
312 while (base != nullptr) {
313 if (base == super) {
314 return true;
315 }
316 base = base->GetSuperKlass();
317 }
318 return false;
319 }
320
321 // Interface
IsSuperKlassForInterface(const Klass * super,const Klass * base) const322 bool KlassHierarchy::IsSuperKlassForInterface(const Klass *super, const Klass *base) const
323 {
324 if (super == nullptr || base == nullptr) {
325 return false;
326 }
327 if (!super->IsInterface() || !base->IsInterface()) {
328 return false;
329 }
330 std::vector<const Klass *> tmpVector;
331 tmpVector.push_back(base);
332 for (size_t idx = 0; idx < tmpVector.size(); ++idx) {
333 if (tmpVector[idx] == super) {
334 return true;
335 }
336 for (const Klass *superKlass : tmpVector[idx]->GetSuperKlasses()) {
337 tmpVector.push_back(superKlass);
338 }
339 }
340 return false;
341 }
342
IsInterfaceImplemented(Klass * interface,const Klass * base) const343 bool KlassHierarchy::IsInterfaceImplemented(Klass *interface, const Klass *base) const
344 {
345 if (interface == nullptr || base == nullptr) {
346 return false;
347 }
348 if (!interface->IsInterface() || !base->IsClass()) {
349 return false;
350 }
351 // All the implemented interfaces and their parent interfaces
352 // are directly stored in this set, so no need to look up super
353 return (base->GetImplInterfaces().find(interface) != base->GetImplInterfaces().end());
354 }
355
GetFieldIDOffsetBetweenClasses(const Klass & super,const Klass & base) const356 int KlassHierarchy::GetFieldIDOffsetBetweenClasses(const Klass &super, const Klass &base) const
357 {
358 int offset = 0;
359 const Klass *superPtr = &super;
360 const Klass *basePtr = &base;
361 while (basePtr != superPtr) {
362 basePtr = basePtr->GetSuperKlass();
363 CHECK_FATAL(basePtr != nullptr, "null ptr check");
364 offset++;
365 }
366 return offset;
367 }
368
UpdateFieldID(TyIdx baseTypeIdx,TyIdx targetTypeIdx,FieldID & fldID) const369 bool KlassHierarchy::UpdateFieldID(TyIdx baseTypeIdx, TyIdx targetTypeIdx, FieldID &fldID) const
370 {
371 MIRType *baseType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(baseTypeIdx);
372 MIRType *targetType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(targetTypeIdx);
373 if (baseType->GetKind() == kTypePointer && targetType->GetKind() == kTypePointer) {
374 baseType = static_cast<const MIRPtrType *>(baseType)->GetPointedType();
375 targetType = static_cast<const MIRPtrType *>(targetType)->GetPointedType();
376 }
377 if (baseType->GetKind() != kTypeClass || targetType->GetKind() != kTypeClass) {
378 return false;
379 }
380 Klass *baseKlass = GetKlassFromTyIdx(baseType->GetTypeIndex());
381 DEBUG_ASSERT(baseKlass != nullptr, "null ptr check");
382 Klass *targetKlass = GetKlassFromTyIdx(targetType->GetTypeIndex());
383 DEBUG_ASSERT(targetKlass != nullptr, "null ptr check");
384 if (IsSuperKlass(baseKlass, targetKlass)) {
385 fldID += GetFieldIDOffsetBetweenClasses(*baseKlass, *targetKlass);
386 return true;
387 } else if (IsSuperKlass(targetKlass, baseKlass)) {
388 fldID -= GetFieldIDOffsetBetweenClasses(*targetKlass, *baseKlass);
389 return true;
390 }
391 return false;
392 }
393
NeedClinitCheckRecursively(const Klass & kl) const394 bool KlassHierarchy::NeedClinitCheckRecursively(const Klass &kl) const
395 {
396 if (kl.HasFlag(kClassRuntimeVerify)) {
397 return true;
398 }
399 const Klass *klass = &kl;
400 if (klass->IsClass()) {
401 while (klass != nullptr) {
402 if (klass->GetClinit()) {
403 return true;
404 }
405 klass = klass->GetSuperKlass();
406 }
407 for (Klass *implInterface : kl.GetImplInterfaces()) {
408 if (implInterface->GetClinit()) {
409 for (auto &func : implInterface->GetMethods()) {
410 if (!func->GetAttr(FUNCATTR_abstract) && !func->GetAttr(FUNCATTR_static)) {
411 return true;
412 }
413 }
414 }
415 }
416 return false;
417 }
418 if (klass->IsInterface()) {
419 return klass->GetClinit();
420 }
421 return true;
422 }
423
424 // Get lowest common ancestor for two classes
GetLCA(Klass * klass1,Klass * klass2) const425 Klass *KlassHierarchy::GetLCA(Klass *klass1, Klass *klass2) const
426 {
427 std::vector<Klass *> v1, v2;
428 while (klass1 != nullptr) {
429 v1.push_back(klass1);
430 klass1 = klass1->GetSuperKlass();
431 }
432 while (klass2 != nullptr) {
433 v2.push_back(klass2);
434 klass2 = klass2->GetSuperKlass();
435 }
436 Klass *result = nullptr;
437 size_t size1 = v1.size();
438 size_t size2 = v2.size();
439 size_t min = (size1 < size2) ? size1 : size2;
440 for (size_t i = 1; i <= min; ++i) {
441 CHECK_FATAL(size1 > 0, "must not be zero");
442 if (v1[size1 - i] != v2[size2 - i]) {
443 break;
444 }
445 result = v1[size1 - i];
446 }
447 return result;
448 }
449
GetLCA(TyIdx ty1,TyIdx ty2) const450 TyIdx KlassHierarchy::GetLCA(TyIdx ty1, TyIdx ty2) const
451 {
452 Klass *result = GetLCA(GetKlassFromTyIdx(ty1), GetKlassFromTyIdx(ty2));
453 return (result != nullptr ? result->GetTypeIdx() : TyIdx(0));
454 }
455
GetLCA(GStrIdx str1,GStrIdx str2) const456 GStrIdx KlassHierarchy::GetLCA(GStrIdx str1, GStrIdx str2) const
457 {
458 Klass *result = GetLCA(GetKlassFromStrIdx(str1), GetKlassFromStrIdx(str2));
459 return (result != nullptr ? result->GetKlassNameStrIdx() : GStrIdx(0));
460 }
461
GetLCA(const std::string & name1,const std::string & name2) const462 const std::string &KlassHierarchy::GetLCA(const std::string &name1, const std::string &name2) const
463 {
464 Klass *result = GetLCA(GetKlassFromName(name1), GetKlassFromName(name2));
465 return (result != nullptr ? result->GetKlassName() : GlobalTables::GetStrTable().GetStringFromStrIdx(GStrIdx(0)));
466 }
467
AddKlasses()468 void KlassHierarchy::AddKlasses()
469 {
470 for (MIRType *type : GlobalTables::GetTypeTable().GetTypeTable()) {
471 #if DEBUG
472 if (type != nullptr) {
473 MIRTypeKind kd = type->GetKind();
474 if (kd == kTypeStructIncomplete || kd == kTypeClassIncomplete || kd == kTypeInterfaceIncomplete)
475 LogInfo::MapleLogger() << "Warning: KlassHierarchy::AddKlasses "
476 << GlobalTables::GetStrTable().GetStringFromStrIdx(type->GetNameStrIdx())
477 << " INCOMPLETE \n";
478 }
479 #endif
480 if (Options::deferredVisit2 && type && (type->IsIncomplete())) {
481 GStrIdx stridx = type->GetNameStrIdx();
482 std::string strName = GlobalTables::GetStrTable().GetStringFromStrIdx(stridx);
483 #if DEBUG
484 LogInfo::MapleLogger() << "Waring: " << strName << " INCOMPLETE \n";
485 #endif
486 if (strName == namemangler::kClassMetadataTypeName) {
487 continue;
488 }
489 } else if (type == nullptr || (type->GetKind() != kTypeClass && type->GetKind() != kTypeInterface)) {
490 continue;
491 }
492 GStrIdx strIdx = type->GetNameStrIdx();
493 Klass *klass = GetKlassFromStrIdx(strIdx);
494 if (klass != nullptr) {
495 continue;
496 }
497 auto *stype = static_cast<MIRStructType *>(type);
498 klass = GetMempool()->New<Klass>(stype, &alloc);
499 strIdx2KlassMap[strIdx] = klass;
500 }
501 }
502
ExceptionFlagProp(Klass & klass)503 void KlassHierarchy::ExceptionFlagProp(Klass &klass)
504 {
505 klass.SetExceptionKlass();
506 for (Klass *subClass : klass.GetSubKlasses()) {
507 DEBUG_ASSERT(subClass != nullptr, "null ptr check!");
508 subClass->SetExceptionKlass();
509 ExceptionFlagProp(*subClass);
510 }
511 }
512
CollectImplInterfaces(const Klass & klass,std::set<Klass * > & implInterfaceSet)513 static void CollectImplInterfaces(const Klass &klass, std::set<Klass *> &implInterfaceSet)
514 {
515 for (Klass *superKlass : klass.GetSuperKlasses()) {
516 if (implInterfaceSet.find(superKlass) == implInterfaceSet.end()) {
517 DEBUG_ASSERT(superKlass != nullptr, "null ptr check!");
518 if (superKlass->IsInterface()) {
519 implInterfaceSet.insert(superKlass);
520 }
521 CollectImplInterfaces(*superKlass, implInterfaceSet);
522 }
523 }
524 for (Klass *interfaceKlass : klass.GetImplInterfaces()) {
525 if (implInterfaceSet.find(interfaceKlass) == implInterfaceSet.end()) {
526 implInterfaceSet.insert(interfaceKlass);
527 DEBUG_ASSERT(interfaceKlass != nullptr, "null ptr check!");
528 CollectImplInterfaces(*interfaceKlass, implInterfaceSet);
529 }
530 }
531 }
532
UpdateImplementedInterfaces()533 void KlassHierarchy::UpdateImplementedInterfaces()
534 {
535 for (const auto &pair : strIdx2KlassMap) {
536 Klass *klass = pair.second;
537 DEBUG_ASSERT(klass != nullptr, "null ptr check");
538 if (!klass->IsInterface()) {
539 std::set<Klass *> implInterfaceSet;
540 CollectImplInterfaces(*klass, implInterfaceSet);
541 for (auto interface : implInterfaceSet) {
542 // Add missing parent interface to class link
543 interface->AddImplKlass(klass);
544 klass->AddImplInterface(interface);
545 }
546 }
547 }
548 }
549
GetParentKlasses(const Klass & klass,std::vector<Klass * > & parentKlasses) const550 void KlassHierarchy::GetParentKlasses(const Klass &klass, std::vector<Klass *> &parentKlasses) const
551 {
552 for (Klass *superKlass : klass.GetSuperKlasses()) {
553 parentKlasses.push_back(superKlass);
554 }
555 if (!klass.IsInterface()) {
556 for (Klass *iklass : klass.GetImplInterfaces()) {
557 parentKlasses.push_back(iklass);
558 }
559 }
560 }
561
GetChildKlasses(const Klass & klass,std::vector<Klass * > & childKlasses) const562 void KlassHierarchy::GetChildKlasses(const Klass &klass, std::vector<Klass *> &childKlasses) const
563 {
564 for (Klass *subKlass : klass.GetSubKlasses()) {
565 childKlasses.push_back(subKlass);
566 }
567 if (klass.IsInterface()) {
568 for (Klass *implKlass : klass.GetImplKlasses()) {
569 childKlasses.push_back(implKlass);
570 }
571 }
572 }
573
TopologicalSortKlasses()574 void KlassHierarchy::TopologicalSortKlasses()
575 {
576 std::set<Klass *> inQueue; // Local variable, no need to use MapleSet
577 for (const auto &pair : strIdx2KlassMap) {
578 Klass *klass = pair.second;
579 DEBUG_ASSERT(klass != nullptr, "klass can not be nullptr");
580 if (!klass->HasSuperKlass() && !klass->ImplementsKlass()) {
581 topoWorkList.push_back(klass);
582 inQueue.insert(klass);
583 }
584 }
585 // Top-down iterates all nodes
586 for (size_t i = 0; i < topoWorkList.size(); ++i) {
587 Klass *klass = topoWorkList[i];
588 std::vector<Klass *> childklasses;
589 DEBUG_ASSERT(klass != nullptr, "null ptr check!");
590 GetChildKlasses(*klass, childklasses);
591 for (Klass *childklass : childklasses) {
592 if (inQueue.find(childklass) == inQueue.end()) {
593 // callee has not been visited
594 bool parentKlassAllVisited = true;
595 std::vector<Klass *> parentKlasses;
596 DEBUG_ASSERT(childklass != nullptr, "null ptr check!");
597 GetParentKlasses(*childklass, parentKlasses);
598 // Check whether all callers of the current callee have been visited
599 for (Klass *parentKlass : parentKlasses) {
600 if (inQueue.find(parentKlass) == inQueue.end()) {
601 parentKlassAllVisited = false;
602 break;
603 }
604 }
605 if (parentKlassAllVisited) {
606 topoWorkList.push_back(childklass);
607 inQueue.insert(childklass);
608 }
609 }
610 }
611 }
612 }
613
CountVirtualMethods() const614 void KlassHierarchy::CountVirtualMethods() const
615 {
616 // Bottom-up iterates all klass nodes
617 for (size_t i = topoWorkList.size(); i != 0; --i) {
618 topoWorkList[i - 1]->CountVirtMethBottomUp();
619 }
620 }
621
AddClassFlag(const std::string & name,uint32 flag)622 Klass *KlassHierarchy::AddClassFlag(const std::string &name, uint32 flag)
623 {
624 Klass *refKlass = GetKlassFromLiteral(name);
625 if (refKlass != nullptr) {
626 refKlass->SetFlag(flag);
627 }
628 return refKlass;
629 }
630
Dump() const631 void KlassHierarchy::Dump() const
632 {
633 for (Klass *klass : topoWorkList) {
634 klass->Dump();
635 }
636 }
637
GetUniqueMethod(GStrIdx vfuncNameStridx) const638 GStrIdx KlassHierarchy::GetUniqueMethod(GStrIdx vfuncNameStridx) const
639 {
640 auto it = vfunc2RfuncMap.find(vfuncNameStridx);
641 return (it != vfunc2RfuncMap.end() ? it->second : GStrIdx(0));
642 }
643
IsDevirtualListEmpty() const644 bool KlassHierarchy::IsDevirtualListEmpty() const
645 {
646 return vfunc2RfuncMap.empty();
647 }
648
DumpDevirtualList(const std::string & outputFileName) const649 void KlassHierarchy::DumpDevirtualList(const std::string &outputFileName) const
650 {
651 std::unordered_map<std::string, std::string> funcMap;
652 for (Klass *klass : topoWorkList) {
653 for (MIRFunction *func : klass->GetMethods()) {
654 MIRFunction *uniqCand = klass->GetUniqueMethod(func->GetBaseFuncNameWithTypeStrIdx());
655 if (uniqCand != nullptr) {
656 funcMap[func->GetName()] = uniqCand->GetName();
657 }
658 }
659 }
660 std::ofstream outputFile;
661 outputFile.open(outputFileName);
662 for (auto it : funcMap) {
663 outputFile << it.first << "\t" << it.second << "\n";
664 }
665 outputFile.close();
666 }
667
ReadDevirtualList(const std::string & inputFileName)668 void KlassHierarchy::ReadDevirtualList(const std::string &inputFileName)
669 {
670 std::ifstream inputFile;
671 inputFile.open(inputFileName);
672 std::string vfuncName;
673 std::string rfuncName;
674 while (inputFile >> vfuncName >> rfuncName) {
675 vfunc2RfuncMap[GlobalTables::GetStrTable().GetOrCreateStrIdxFromName(vfuncName)] =
676 GlobalTables::GetStrTable().GetOrCreateStrIdxFromName(rfuncName);
677 }
678 inputFile.close();
679 }
680
BuildHierarchy()681 void KlassHierarchy::BuildHierarchy()
682 {
683 // Scan class list and generate Klass without method information
684 AddKlasses();
685 // In the case of "class C implements B; interface B extends A;",
686 // we need to add a link between C and A.
687 UpdateImplementedInterfaces();
688 TopologicalSortKlasses();
689 // Use --dump-devirtual-list=... to dump a closed-wolrd analysis result into a file
690 if (!Options::dumpDevirtualList.empty()) {
691 DumpDevirtualList(Options::dumpDevirtualList);
692 }
693 // Use --read-devirtual-list=... to read in a closed-world analysis result
694 if (!Options::readDevirtualList.empty()) {
695 ReadDevirtualList(Options::readDevirtualList);
696 }
697 }
698
KlassHierarchy(MIRModule * mirmodule,MemPool * memPool)699 KlassHierarchy::KlassHierarchy(MIRModule *mirmodule, MemPool *memPool)
700 : AnalysisResult(memPool),
701 alloc(memPool),
702 mirModule(mirmodule),
703 strIdx2KlassMap(std::less<GStrIdx>(), alloc.Adapter()),
704 vfunc2RfuncMap(std::less<GStrIdx>(), alloc.Adapter()),
705 topoWorkList(alloc.Adapter())
706 {
707 }
708 } // namespace maple
709