• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "transforms/passes/prune_deopt.h"
17 #include "transforms/transform_utils.h"
18 
19 #include <llvm/IR/IRBuilder.h>
20 #include <llvm/IR/InlineAsm.h>
21 #include <llvm/IR/MDBuilder.h>
22 #include <llvm/IR/Verifier.h>
23 #include <llvm/Pass.h>
24 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
25 
26 #define DEBUG_TYPE "prune-deopt"
27 
28 // Basic classes
29 using llvm::ArrayRef;
30 using llvm::BasicBlock;
31 using llvm::Function;
32 using llvm::FunctionAnalysisManager;
33 using llvm::OperandBundleDef;
34 using llvm::OperandBundleUse;
35 using llvm::Use;
36 // Instructions
37 using llvm::CallInst;
38 using llvm::ConstantInt;
39 using llvm::Instruction;
40 
41 namespace ark::llvmbackend::passes {
42 
run(Function & function,FunctionAnalysisManager &)43 llvm::PreservedAnalyses PruneDeopt::run(Function &function, FunctionAnalysisManager & /*analysisManager*/)
44 {
45     LLVM_DEBUG(llvm::dbgs() << "Pruning Deopts for: " << function.getName() << "\n");
46     bool changed = false;
47     for (auto &block : function) {
48         for (auto iter = block.begin(); iter != block.end();) {
49             auto &inst = *iter;
50             iter++;
51             auto call = llvm::dyn_cast<CallInst>(&inst);
52             if (call == nullptr) {
53                 continue;
54             }
55             auto bundle = call->getOperandBundle(llvm::LLVMContext::OB_deopt);
56             if (bundle == llvm::None) {
57                 continue;
58             }
59             auto noReturn = IsNoReturn(bundle->Inputs);
60             auto updated = GetUpdatedCallInst(call, bundle.getValue());
61             changed = true;
62             if (noReturn) {
63                 MakeUnreachableAfter(&block, updated);
64                 break;
65             }
66         }
67     }
68     return changed ? llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all();
69 }
70 
GetUpdatedCallInst(CallInst * call,const OperandBundleUse & bundle)71 CallInst *PruneDeopt::GetUpdatedCallInst(CallInst *call, const OperandBundleUse &bundle)
72 {
73     CallInst *updated;
74     if (!IsCaughtDeoptimization(bundle.Inputs) && !call->hasFnAttr("may-deoptimize")) {
75         LLVM_DEBUG(llvm::dbgs() << "Pruning deopt for: " << *call << "\n");
76         ASSERT(call->getNumOperandBundles() == 1);
77         OperandBundleDef emptyBundle {"deopt", llvm::None};
78         updated = CallInst::Create(call, emptyBundle);
79         auto iinfo = GetInlineInfo(bundle.Inputs);
80         if (!iinfo.empty()) {
81             updated->addFnAttr(llvm::Attribute::get(updated->getContext(), "inline-info", iinfo));
82         }
83     } else {
84         LLVM_DEBUG(llvm::dbgs() << "Encoding deopt bundle: " << *call << "\n");
85         ASSERT(call->getNumOperandBundles() == 1);
86         OperandBundleDef encodedBundle {"deopt", EncodeDeoptBundle(call, bundle)};
87         updated = CallInst::Create(call, {encodedBundle});
88     }
89     ReplaceInstWithInst(call, updated);
90     LLVM_DEBUG(llvm::dbgs() << "Replaced with: " << *updated << "\n");
91     return updated;
92 }
93 
IsCaughtDeoptimization(ArrayRef<Use> inputs) const94 bool PruneDeopt::IsCaughtDeoptimization(ArrayRef<Use> inputs) const
95 {
96     constexpr auto CAUGHT_FLAG_IDX = 3;
97     for (uint32_t i = 0; i < inputs.size(); ++i) {
98         if (llvm::isa<Function>(inputs[i])) {
99             ASSERT((i + CAUGHT_FLAG_IDX) < inputs.size());
100             uint32_t tryFlag = llvm::cast<ConstantInt>(inputs[i + CAUGHT_FLAG_IDX])->getZExtValue();
101             if ((tryFlag & 1U) > 0) {
102                 return true;
103             }
104         }
105     }
106     return false;
107 }
108 
IsNoReturn(ArrayRef<Use> inputs) const109 bool PruneDeopt::IsNoReturn(ArrayRef<Use> inputs) const
110 {
111     constexpr auto CAUGHT_FLAG_IDX = 3;
112     for (uint32_t i = 0; i < inputs.size(); ++i) {
113         if (llvm::isa<Function>(inputs[i])) {
114             ASSERT((i + CAUGHT_FLAG_IDX) < inputs.size());
115             uint32_t tryFlag = llvm::cast<ConstantInt>(inputs[i + CAUGHT_FLAG_IDX])->getZExtValue();
116             if ((tryFlag & 2U) > 0) {
117                 return true;
118             }
119         }
120     }
121     return false;
122 }
123 
EncodeDeoptBundle(CallInst * call,const OperandBundleUse & bundle) const124 PruneDeopt::EncodedDeoptBundle PruneDeopt::EncodeDeoptBundle(CallInst *call, const OperandBundleUse &bundle) const
125 {
126     EncodedDeoptBundle encoded;
127     // Reserve place for function counter
128     encoded.push_back(nullptr);
129     // Reserve space for function indexes
130     for (const auto &ops : bundle.Inputs) {
131         if (llvm::isa<Function>(ops)) {
132             encoded.push_back(nullptr);
133         }
134     }
135     bool mayBeDeoptIf = call->hasFnAttr("may-deoptimize");
136     llvm::IRBuilder<> builder(call);
137     // Set amount of functions
138     encoded[0] = builder.getInt32(encoded.size() - 1);
139     auto functionIndex = 1;
140     constexpr auto REGMAP_FLAG = 1U;
141     constexpr auto REGMAP_FLAG_IDX = 3;
142     size_t offs = REGMAP_FLAG_IDX;
143     for (const auto &ops : bundle.Inputs) {
144         if (llvm::isa<Function>(ops)) {
145             // Record position of the next function
146             encoded[functionIndex++] = builder.getInt32(encoded.size());
147             offs = 0;
148         } else {
149             offs++;
150             if (offs == REGMAP_FLAG_IDX && mayBeDeoptIf) {
151                 encoded.push_back(builder.getInt32(llvm::cast<ConstantInt>(ops)->getZExtValue() | REGMAP_FLAG));
152             } else {
153                 encoded.push_back(ops);
154             }
155         }
156     }
157     return encoded;
158 }
159 
GetInlineInfo(ArrayRef<Use> inputs) const160 std::string PruneDeopt::GetInlineInfo(ArrayRef<Use> inputs) const
161 {
162     constexpr auto METHOD_ID_IDX = 1;
163     std::string inlineInfo;
164     for (uint32_t i = 0; i < inputs.size(); i++) {
165         if (llvm::isa<Function>(inputs[i])) {
166             ASSERT((i + METHOD_ID_IDX) < inputs.size());
167             if (!inlineInfo.empty()) {
168                 inlineInfo.append(",");
169             }
170 
171             auto methodId = llvm::cast<ConstantInt>(inputs[i + METHOD_ID_IDX])->getZExtValue();
172             inlineInfo.append(std::to_string(methodId));
173         }
174     }
175     return inlineInfo;
176 }
177 
MakeUnreachableAfter(BasicBlock * block,Instruction * after) const178 void PruneDeopt::MakeUnreachableAfter(BasicBlock *block, Instruction *after) const
179 {
180     // Remove the BLOCK from phi instructions
181     for (auto succ : successors(block)) {
182         for (auto phii = succ->phis().begin(), end = succ->phis().end(); phii != end;) {
183             auto &phi = *phii++;
184             auto idx = phi.getBasicBlockIndex(block);
185             if (idx != -1) {
186                 phi.removeIncomingValue(idx);
187             }
188         }
189     }
190     auto maybeCall = llvm::dyn_cast<llvm::CallInst>(after);
191     auto deoptimize = maybeCall != nullptr && maybeCall->getIntrinsicID() == llvm::Intrinsic::experimental_deoptimize;
192 
193     // Remove all instructions after AFTER
194     for (auto iter = block->rbegin(); (&(*iter) != after);) {
195         auto &toRemove = *iter++;
196         Instruction *inst = &toRemove;
197         // Do not remove ret instruction after deoptimize call
198         if (deoptimize && toRemove.isTerminator()) {
199             continue;
200         }
201         inst->replaceAllUsesWith(llvm::UndefValue::get(inst->getType()));
202         inst->eraseFromParent();
203     }
204     if (!deoptimize) {
205         // Create unreachable after AFTER
206         llvm::IRBuilder<> builder(block);
207         builder.CreateUnreachable();
208     }
209 }
210 }  // namespace ark::llvmbackend::passes
211