1 /**
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "transforms/builtins.h"
17 #include "transforms/passes/devirt.h"
18 #include "transforms/transform_utils.h"
19 #include "utils.h"
20 #include "llvm_ark_interface.h"
21 #include "llvm_compiler_options.h"
22
23 namespace ark::llvmbackend::passes {
24
Create(LLVMArkInterface * arkInterface,const ark::llvmbackend::LLVMCompilerOptions * options)25 Devirt Devirt::Create(LLVMArkInterface *arkInterface,
26 [[maybe_unused]] const ark::llvmbackend::LLVMCompilerOptions *options)
27 {
28 return Devirt(arkInterface);
29 }
30
Devirt(LLVMArkInterface * arkInterface)31 Devirt::Devirt(LLVMArkInterface *arkInterface) : arkInterface_ {arkInterface} {}
32
ShouldInsert(const ark::llvmbackend::LLVMCompilerOptions * options)33 bool Devirt::ShouldInsert(const ark::llvmbackend::LLVMCompilerOptions *options)
34 {
35 return options->doVirtualInline;
36 }
37
GetObjectClassId(llvm::CallInst * call)38 llvm::ConstantInt *GetObjectClassId(llvm::CallInst *call)
39 {
40 auto thisArg = llvm::dyn_cast<llvm::Instruction>(call->getArgOperand(1));
41 if (thisArg == nullptr) {
42 return nullptr;
43 }
44 auto allocate = llvm::dyn_cast<llvm::CallInst>(thisArg);
45 // NOTE: handle entrypoints
46 if (allocate == nullptr || allocate->arg_size() == 0) {
47 return nullptr;
48 }
49 auto loadAndInit = llvm::dyn_cast<llvm::CallInst>(allocate->getArgOperand(0));
50 // NOTE: handle entrypoints
51 if (loadAndInit == nullptr) {
52 return nullptr;
53 }
54 // NOTE: support more sophisticated cases (maybe with type propagation)
55 auto loadAndInitFunc = loadAndInit->getCalledFunction();
56 auto module = call->getModule();
57 if (loadAndInitFunc != ark::llvmbackend::builtins::LoadInitClass(module)) {
58 return nullptr;
59 }
60 auto *objectKlassId = llvm::dyn_cast<llvm::ConstantInt>(loadAndInit->getArgOperand(0));
61 if (objectKlassId == nullptr) {
62 return nullptr;
63 }
64 return objectKlassId;
65 }
66
run(llvm::Function & function,llvm::FunctionAnalysisManager &)67 llvm::PreservedAnalyses Devirt::run(llvm::Function &function, llvm::FunctionAnalysisManager & /*analysis_manager*/)
68 {
69 ASSERT(arkInterface_ != nullptr);
70 bool changed = false;
71 for (auto &block : function) {
72 for (auto &instruction : block) {
73 auto *call = llvm::dyn_cast<llvm::CallInst>(&instruction);
74 if (call == nullptr || call->getCalledFunction() == nullptr || call->arg_size() < 2U ||
75 call->getCalledFunction()->isIntrinsic()) {
76 continue;
77 }
78
79 if (arkInterface_->IsRememberedCall(call->getFunction(), call->getCalledFunction())) {
80 continue;
81 }
82
83 auto *objectKlassId = GetObjectClassId(call);
84 if (objectKlassId == nullptr) {
85 continue;
86 }
87
88 auto methodPtr = arkInterface_->ResolveVirtual(objectKlassId->getZExtValue(), call);
89 if (methodPtr == nullptr) {
90 continue;
91 }
92
93 auto methodName = arkInterface_->GetUniqMethodName(methodPtr);
94 auto func = function.getParent()->getFunction(methodName);
95 if (func == nullptr || func == call->getCalledFunction() || func->isDeclaration()) {
96 continue;
97 }
98
99 auto ftype = call->getFunctionType();
100 call->setCalledFunction(ftype, func);
101 if (arkInterface_->IsExternal(call) && ark::llvmbackend::utils::HasCallsWithDeopt(*func)) {
102 call->addAttributeAtIndex(llvm::AttributeList::FunctionIndex, llvm::Attribute::NoInline);
103 }
104
105 arkInterface_->PutVirtualFunction(methodPtr, func);
106 changed = true;
107 }
108 }
109 return changed ? llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all();
110 }
111
112 } // namespace ark::llvmbackend::passes
113