1 /*
2 * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "transforms/passes/prune_deopt.h"
17 #include "transforms/transform_utils.h"
18
19 #include <llvm/IR/IRBuilder.h>
20 #include <llvm/IR/InlineAsm.h>
21 #include <llvm/IR/MDBuilder.h>
22 #include <llvm/IR/Verifier.h>
23 #include <llvm/Pass.h>
24 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
25
26 #define DEBUG_TYPE "prune-deopt"
27
28 // Basic classes
29 using llvm::ArrayRef;
30 using llvm::BasicBlock;
31 using llvm::Function;
32 using llvm::FunctionAnalysisManager;
33 using llvm::OperandBundleDef;
34 using llvm::OperandBundleUse;
35 using llvm::Use;
36 // Instructions
37 using llvm::CallInst;
38 using llvm::ConstantInt;
39 using llvm::Instruction;
40
41 namespace ark::llvmbackend::passes {
42
run(Function & function,FunctionAnalysisManager &)43 llvm::PreservedAnalyses PruneDeopt::run(Function &function, FunctionAnalysisManager & /*analysisManager*/)
44 {
45 LLVM_DEBUG(llvm::dbgs() << "Pruning Deopts for: " << function.getName() << "\n");
46 bool changed = false;
47 for (auto &block : function) {
48 for (auto iter = block.begin(); iter != block.end();) {
49 auto &inst = *iter;
50 iter++;
51 auto call = llvm::dyn_cast<CallInst>(&inst);
52 if (call == nullptr) {
53 continue;
54 }
55 auto bundle = call->getOperandBundle(llvm::LLVMContext::OB_deopt);
56 if (bundle == llvm::None) {
57 continue;
58 }
59 auto noReturn = IsNoReturn(bundle->Inputs);
60 auto updated = GetUpdatedCallInst(call, bundle.getValue());
61 changed = true;
62 if (noReturn) {
63 MakeUnreachableAfter(&block, updated);
64 break;
65 }
66 }
67 }
68 return changed ? llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all();
69 }
70
GetUpdatedCallInst(CallInst * call,const OperandBundleUse & bundle)71 CallInst *PruneDeopt::GetUpdatedCallInst(CallInst *call, const OperandBundleUse &bundle)
72 {
73 CallInst *updated;
74 if (!IsCaughtDeoptimization(bundle.Inputs) && !call->hasFnAttr("may-deoptimize")) {
75 LLVM_DEBUG(llvm::dbgs() << "Pruning deopt for: " << *call << "\n");
76 ASSERT(call->getNumOperandBundles() == 1);
77 OperandBundleDef emptyBundle {"deopt", llvm::None};
78 updated = CallInst::Create(call, emptyBundle);
79 auto iinfo = GetInlineInfo(bundle.Inputs);
80 if (!iinfo.empty()) {
81 updated->addFnAttr(llvm::Attribute::get(updated->getContext(), "inline-info", iinfo));
82 }
83 } else {
84 LLVM_DEBUG(llvm::dbgs() << "Encoding deopt bundle: " << *call << "\n");
85 ASSERT(call->getNumOperandBundles() == 1);
86 OperandBundleDef encodedBundle {"deopt", EncodeDeoptBundle(call, bundle)};
87 updated = CallInst::Create(call, {encodedBundle});
88 }
89 ReplaceInstWithInst(call, updated);
90 LLVM_DEBUG(llvm::dbgs() << "Replaced with: " << *updated << "\n");
91 return updated;
92 }
93
IsCaughtDeoptimization(ArrayRef<Use> inputs) const94 bool PruneDeopt::IsCaughtDeoptimization(ArrayRef<Use> inputs) const
95 {
96 constexpr auto CAUGHT_FLAG_IDX = 3;
97 for (uint32_t i = 0; i < inputs.size(); ++i) {
98 if (llvm::isa<Function>(inputs[i])) {
99 ASSERT((i + CAUGHT_FLAG_IDX) < inputs.size());
100 uint32_t tryFlag = llvm::cast<ConstantInt>(inputs[i + CAUGHT_FLAG_IDX])->getZExtValue();
101 if ((tryFlag & 1U) > 0) {
102 return true;
103 }
104 }
105 }
106 return false;
107 }
108
IsNoReturn(ArrayRef<Use> inputs) const109 bool PruneDeopt::IsNoReturn(ArrayRef<Use> inputs) const
110 {
111 constexpr auto CAUGHT_FLAG_IDX = 3;
112 for (uint32_t i = 0; i < inputs.size(); ++i) {
113 if (llvm::isa<Function>(inputs[i])) {
114 ASSERT((i + CAUGHT_FLAG_IDX) < inputs.size());
115 uint32_t tryFlag = llvm::cast<ConstantInt>(inputs[i + CAUGHT_FLAG_IDX])->getZExtValue();
116 if ((tryFlag & 2U) > 0) {
117 return true;
118 }
119 }
120 }
121 return false;
122 }
123
EncodeDeoptBundle(CallInst * call,const OperandBundleUse & bundle) const124 PruneDeopt::EncodedDeoptBundle PruneDeopt::EncodeDeoptBundle(CallInst *call, const OperandBundleUse &bundle) const
125 {
126 EncodedDeoptBundle encoded;
127 // Reserve place for function counter
128 encoded.push_back(nullptr);
129 // Reserve space for function indexes
130 for (const auto &ops : bundle.Inputs) {
131 if (llvm::isa<Function>(ops)) {
132 encoded.push_back(nullptr);
133 }
134 }
135 bool mayBeDeoptIf = call->hasFnAttr("may-deoptimize");
136 llvm::IRBuilder<> builder(call);
137 // Set amount of functions
138 encoded[0] = builder.getInt32(encoded.size() - 1);
139 auto functionIndex = 1;
140 constexpr auto REGMAP_FLAG = 1U;
141 constexpr auto REGMAP_FLAG_IDX = 3;
142 size_t offs = REGMAP_FLAG_IDX;
143 for (const auto &ops : bundle.Inputs) {
144 if (llvm::isa<Function>(ops)) {
145 // Record position of the next function
146 encoded[functionIndex++] = builder.getInt32(encoded.size());
147 offs = 0;
148 } else {
149 offs++;
150 if (offs == REGMAP_FLAG_IDX && mayBeDeoptIf) {
151 encoded.push_back(builder.getInt32(llvm::cast<ConstantInt>(ops)->getZExtValue() | REGMAP_FLAG));
152 } else {
153 encoded.push_back(ops);
154 }
155 }
156 }
157 return encoded;
158 }
159
GetInlineInfo(ArrayRef<Use> inputs) const160 std::string PruneDeopt::GetInlineInfo(ArrayRef<Use> inputs) const
161 {
162 constexpr auto METHOD_ID_IDX = 1;
163 std::string inlineInfo;
164 for (uint32_t i = 0; i < inputs.size(); i++) {
165 if (llvm::isa<Function>(inputs[i])) {
166 ASSERT((i + METHOD_ID_IDX) < inputs.size());
167 if (!inlineInfo.empty()) {
168 inlineInfo.append(",");
169 }
170
171 auto methodId = llvm::cast<ConstantInt>(inputs[i + METHOD_ID_IDX])->getZExtValue();
172 inlineInfo.append(std::to_string(methodId));
173 }
174 }
175 return inlineInfo;
176 }
177
MakeUnreachableAfter(BasicBlock * block,Instruction * after) const178 void PruneDeopt::MakeUnreachableAfter(BasicBlock *block, Instruction *after) const
179 {
180 // Remove the BLOCK from phi instructions
181 for (auto succ : successors(block)) {
182 for (auto phii = succ->phis().begin(), end = succ->phis().end(); phii != end;) {
183 auto &phi = *phii++;
184 auto idx = phi.getBasicBlockIndex(block);
185 if (idx != -1) {
186 phi.removeIncomingValue(idx);
187 }
188 }
189 }
190 auto maybeCall = llvm::dyn_cast<llvm::CallInst>(after);
191 auto deoptimize = maybeCall != nullptr && maybeCall->getIntrinsicID() == llvm::Intrinsic::experimental_deoptimize;
192
193 // Remove all instructions after AFTER
194 for (auto iter = block->rbegin(); (&(*iter) != after);) {
195 auto &toRemove = *iter++;
196 Instruction *inst = &toRemove;
197 // Do not remove ret instruction after deoptimize call
198 if (deoptimize && toRemove.isTerminator()) {
199 continue;
200 }
201 inst->replaceAllUsesWith(llvm::UndefValue::get(inst->getType()));
202 inst->eraseFromParent();
203 }
204 if (!deoptimize) {
205 // Create unreachable after AFTER
206 llvm::IRBuilder<> builder(block);
207 builder.CreateUnreachable();
208 }
209 }
210 } // namespace ark::llvmbackend::passes
211