1 /*
2 * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "transforms/passes/prune_deopt.h"
17 #include "transforms/transform_utils.h"
18
19 #include <llvm/IR/IRBuilder.h>
20 #include <llvm/IR/InlineAsm.h>
21 #include <llvm/IR/MDBuilder.h>
22 #include <llvm/IR/Verifier.h>
23 #include <llvm/Pass.h>
24 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
25
26 #define DEBUG_TYPE "prune-deopt"
27
28 // Basic classes
29 using llvm::ArrayRef;
30 using llvm::BasicBlock;
31 using llvm::Function;
32 using llvm::FunctionAnalysisManager;
33 using llvm::OperandBundleDef;
34 using llvm::OperandBundleUse;
35 using llvm::Use;
36 // Instructions
37 using llvm::CallInst;
38 using llvm::ConstantInt;
39 using llvm::Instruction;
40
41 namespace ark::llvmbackend::passes {
42
run(Function & function,FunctionAnalysisManager &)43 llvm::PreservedAnalyses PruneDeopt::run(Function &function, FunctionAnalysisManager & /*analysisManager*/)
44 {
45 LLVM_DEBUG(llvm::dbgs() << "Pruning Deopts for: " << function.getName() << "\n");
46 bool changed = false;
47 for (auto &block : function) {
48 for (auto iter = block.begin(); iter != block.end();) {
49 auto &inst = *iter;
50 iter++;
51 auto call = llvm::dyn_cast<CallInst>(&inst);
52 if (call == nullptr) {
53 continue;
54 }
55 auto bundle = call->getOperandBundle(llvm::LLVMContext::OB_deopt);
56 if (bundle == llvm::None) {
57 continue;
58 }
59 auto noReturn = IsNoReturn(bundle->Inputs);
60 auto updated = GetUpdatedCallInst(call, bundle.getValue());
61 changed = true;
62 if (noReturn) {
63 MakeUnreachableAfter(&block, updated);
64 break;
65 }
66 }
67 }
68 return changed ? llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all();
69 }
70
GetUpdatedCallInst(CallInst * call,const OperandBundleUse & bundle)71 CallInst *PruneDeopt::GetUpdatedCallInst(CallInst *call, const OperandBundleUse &bundle)
72 {
73 CallInst *updated;
74 if (!IsCaughtDeoptimization(bundle.Inputs) && !call->hasFnAttr("may-deoptimize")) {
75 LLVM_DEBUG(llvm::dbgs() << "Pruning deopt for: " << *call << "\n");
76 ASSERT(call->getNumOperandBundles() == 1);
77 updated = CallInst::Create(call, llvm::None);
78 auto iinfo = GetInlineInfo(bundle.Inputs);
79 if (!iinfo.empty()) {
80 updated->addFnAttr(llvm::Attribute::get(updated->getContext(), "inline-info", iinfo));
81 }
82 } else {
83 LLVM_DEBUG(llvm::dbgs() << "Encoding deopt bundle: " << *call << "\n");
84 ASSERT(call->getNumOperandBundles() == 1);
85 OperandBundleDef encodedBundle {"deopt", EncodeDeoptBundle(call, bundle)};
86 updated = CallInst::Create(call, {encodedBundle});
87 }
88 ReplaceInstWithInst(call, updated);
89 LLVM_DEBUG(llvm::dbgs() << "Replaced with: " << *updated << "\n");
90 return updated;
91 }
92
IsCaughtDeoptimization(ArrayRef<Use> inputs) const93 bool PruneDeopt::IsCaughtDeoptimization(ArrayRef<Use> inputs) const
94 {
95 constexpr auto CAUGHT_FLAG_IDX = 3;
96 for (uint32_t i = 0; i < inputs.size(); ++i) {
97 if (llvm::isa<Function>(inputs[i])) {
98 ASSERT((i + CAUGHT_FLAG_IDX) < inputs.size());
99 uint32_t tryFlag = llvm::cast<ConstantInt>(inputs[i + CAUGHT_FLAG_IDX])->getZExtValue();
100 if ((tryFlag & 1U) > 0) {
101 return true;
102 }
103 }
104 }
105 return false;
106 }
107
IsNoReturn(ArrayRef<Use> inputs) const108 bool PruneDeopt::IsNoReturn(ArrayRef<Use> inputs) const
109 {
110 constexpr auto CAUGHT_FLAG_IDX = 3;
111 for (uint32_t i = 0; i < inputs.size(); ++i) {
112 if (llvm::isa<Function>(inputs[i])) {
113 ASSERT((i + CAUGHT_FLAG_IDX) < inputs.size());
114 uint32_t tryFlag = llvm::cast<ConstantInt>(inputs[i + CAUGHT_FLAG_IDX])->getZExtValue();
115 if ((tryFlag & 2U) > 0) {
116 return true;
117 }
118 }
119 }
120 return false;
121 }
122
EncodeDeoptBundle(CallInst * call,const OperandBundleUse & bundle) const123 PruneDeopt::EncodedDeoptBundle PruneDeopt::EncodeDeoptBundle(CallInst *call, const OperandBundleUse &bundle) const
124 {
125 EncodedDeoptBundle encoded;
126 // Reserve place for function counter
127 encoded.push_back(nullptr);
128 // Reserve space for function indexes
129 for (const auto &ops : bundle.Inputs) {
130 if (llvm::isa<Function>(ops)) {
131 encoded.push_back(nullptr);
132 }
133 }
134 bool mayBeDeoptIf = call->hasFnAttr("may-deoptimize");
135 llvm::IRBuilder<> builder(call);
136 // Set amount of functions
137 encoded[0] = builder.getInt32(encoded.size() - 1);
138 auto functionIndex = 1;
139 constexpr auto REGMAP_FLAG = 1U;
140 constexpr auto REGMAP_FLAG_IDX = 3;
141 size_t offs = REGMAP_FLAG_IDX;
142 for (const auto &ops : bundle.Inputs) {
143 if (llvm::isa<Function>(ops)) {
144 // Record position of the next function
145 encoded[functionIndex++] = builder.getInt32(encoded.size());
146 offs = 0;
147 } else {
148 offs++;
149 if (offs == REGMAP_FLAG_IDX && mayBeDeoptIf) {
150 encoded.push_back(builder.getInt32(llvm::cast<ConstantInt>(ops)->getZExtValue() | REGMAP_FLAG));
151 } else {
152 encoded.push_back(ops);
153 }
154 }
155 }
156 return encoded;
157 }
158
GetInlineInfo(ArrayRef<Use> inputs) const159 std::string PruneDeopt::GetInlineInfo(ArrayRef<Use> inputs) const
160 {
161 constexpr auto METHOD_ID_IDX = 1;
162 std::string inlineInfo;
163 for (uint32_t i = 0; i < inputs.size(); i++) {
164 if (llvm::isa<Function>(inputs[i])) {
165 ASSERT((i + METHOD_ID_IDX) < inputs.size());
166 if (!inlineInfo.empty()) {
167 inlineInfo.append(",");
168 }
169
170 auto methodId = llvm::cast<ConstantInt>(inputs[i + METHOD_ID_IDX])->getZExtValue();
171 inlineInfo.append(std::to_string(methodId));
172 }
173 }
174 return inlineInfo;
175 }
176
MakeUnreachableAfter(BasicBlock * block,Instruction * after) const177 void PruneDeopt::MakeUnreachableAfter(BasicBlock *block, Instruction *after) const
178 {
179 // Remove the BLOCK from phi instructions
180 for (auto succ : successors(block)) {
181 for (auto phii = succ->phis().begin(), end = succ->phis().end(); phii != end;) {
182 auto &phi = *phii++;
183 auto idx = phi.getBasicBlockIndex(block);
184 if (idx != -1) {
185 phi.removeIncomingValue(idx);
186 }
187 }
188 }
189 // Remove all instructions after AFTER
190 for (auto iter = block->rbegin(); (&(*iter) != after);) {
191 auto &toRemove = *iter++;
192 Instruction *inst = &toRemove;
193 inst->replaceAllUsesWith(llvm::UndefValue::get(inst->getType()));
194 inst->eraseFromParent();
195 }
196 // Create unreachable after AFTER
197 llvm::IRBuilder<> builder(block);
198 builder.CreateUnreachable();
199 }
200 } // namespace ark::llvmbackend::passes
201