• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "transforms/passes/prune_deopt.h"
17 #include "transforms/transform_utils.h"
18 
19 #include <llvm/IR/IRBuilder.h>
20 #include <llvm/IR/InlineAsm.h>
21 #include <llvm/IR/MDBuilder.h>
22 #include <llvm/IR/Verifier.h>
23 #include <llvm/Pass.h>
24 #include <llvm/Transforms/Utils/BasicBlockUtils.h>
25 
26 #define DEBUG_TYPE "prune-deopt"
27 
28 // Basic classes
29 using llvm::ArrayRef;
30 using llvm::BasicBlock;
31 using llvm::Function;
32 using llvm::FunctionAnalysisManager;
33 using llvm::OperandBundleDef;
34 using llvm::OperandBundleUse;
35 using llvm::Use;
36 // Instructions
37 using llvm::CallInst;
38 using llvm::ConstantInt;
39 using llvm::Instruction;
40 
41 namespace ark::llvmbackend::passes {
42 
run(Function & function,FunctionAnalysisManager &)43 llvm::PreservedAnalyses PruneDeopt::run(Function &function, FunctionAnalysisManager & /*analysisManager*/)
44 {
45     LLVM_DEBUG(llvm::dbgs() << "Pruning Deopts for: " << function.getName() << "\n");
46     bool changed = false;
47     for (auto &block : function) {
48         for (auto iter = block.begin(); iter != block.end();) {
49             auto &inst = *iter;
50             iter++;
51             auto call = llvm::dyn_cast<CallInst>(&inst);
52             if (call == nullptr) {
53                 continue;
54             }
55             auto bundle = call->getOperandBundle(llvm::LLVMContext::OB_deopt);
56             if (bundle == llvm::None) {
57                 continue;
58             }
59             auto noReturn = IsNoReturn(bundle->Inputs);
60             auto updated = GetUpdatedCallInst(call, bundle.getValue());
61             changed = true;
62             if (noReturn) {
63                 MakeUnreachableAfter(&block, updated);
64                 break;
65             }
66         }
67     }
68     return changed ? llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all();
69 }
70 
GetUpdatedCallInst(CallInst * call,const OperandBundleUse & bundle)71 CallInst *PruneDeopt::GetUpdatedCallInst(CallInst *call, const OperandBundleUse &bundle)
72 {
73     CallInst *updated;
74     if (!IsCaughtDeoptimization(bundle.Inputs) && !call->hasFnAttr("may-deoptimize")) {
75         LLVM_DEBUG(llvm::dbgs() << "Pruning deopt for: " << *call << "\n");
76         ASSERT(call->getNumOperandBundles() == 1);
77         updated = CallInst::Create(call, llvm::None);
78         auto iinfo = GetInlineInfo(bundle.Inputs);
79         if (!iinfo.empty()) {
80             updated->addFnAttr(llvm::Attribute::get(updated->getContext(), "inline-info", iinfo));
81         }
82     } else {
83         LLVM_DEBUG(llvm::dbgs() << "Encoding deopt bundle: " << *call << "\n");
84         ASSERT(call->getNumOperandBundles() == 1);
85         OperandBundleDef encodedBundle {"deopt", EncodeDeoptBundle(call, bundle)};
86         updated = CallInst::Create(call, {encodedBundle});
87     }
88     ReplaceInstWithInst(call, updated);
89     LLVM_DEBUG(llvm::dbgs() << "Replaced with: " << *updated << "\n");
90     return updated;
91 }
92 
IsCaughtDeoptimization(ArrayRef<Use> inputs) const93 bool PruneDeopt::IsCaughtDeoptimization(ArrayRef<Use> inputs) const
94 {
95     constexpr auto CAUGHT_FLAG_IDX = 3;
96     for (uint32_t i = 0; i < inputs.size(); ++i) {
97         if (llvm::isa<Function>(inputs[i])) {
98             ASSERT((i + CAUGHT_FLAG_IDX) < inputs.size());
99             uint32_t tryFlag = llvm::cast<ConstantInt>(inputs[i + CAUGHT_FLAG_IDX])->getZExtValue();
100             if ((tryFlag & 1U) > 0) {
101                 return true;
102             }
103         }
104     }
105     return false;
106 }
107 
IsNoReturn(ArrayRef<Use> inputs) const108 bool PruneDeopt::IsNoReturn(ArrayRef<Use> inputs) const
109 {
110     constexpr auto CAUGHT_FLAG_IDX = 3;
111     for (uint32_t i = 0; i < inputs.size(); ++i) {
112         if (llvm::isa<Function>(inputs[i])) {
113             ASSERT((i + CAUGHT_FLAG_IDX) < inputs.size());
114             uint32_t tryFlag = llvm::cast<ConstantInt>(inputs[i + CAUGHT_FLAG_IDX])->getZExtValue();
115             if ((tryFlag & 2U) > 0) {
116                 return true;
117             }
118         }
119     }
120     return false;
121 }
122 
EncodeDeoptBundle(CallInst * call,const OperandBundleUse & bundle) const123 PruneDeopt::EncodedDeoptBundle PruneDeopt::EncodeDeoptBundle(CallInst *call, const OperandBundleUse &bundle) const
124 {
125     EncodedDeoptBundle encoded;
126     // Reserve place for function counter
127     encoded.push_back(nullptr);
128     // Reserve space for function indexes
129     for (const auto &ops : bundle.Inputs) {
130         if (llvm::isa<Function>(ops)) {
131             encoded.push_back(nullptr);
132         }
133     }
134     bool mayBeDeoptIf = call->hasFnAttr("may-deoptimize");
135     llvm::IRBuilder<> builder(call);
136     // Set amount of functions
137     encoded[0] = builder.getInt32(encoded.size() - 1);
138     auto functionIndex = 1;
139     constexpr auto REGMAP_FLAG = 1U;
140     constexpr auto REGMAP_FLAG_IDX = 3;
141     size_t offs = REGMAP_FLAG_IDX;
142     for (const auto &ops : bundle.Inputs) {
143         if (llvm::isa<Function>(ops)) {
144             // Record position of the next function
145             encoded[functionIndex++] = builder.getInt32(encoded.size());
146             offs = 0;
147         } else {
148             offs++;
149             if (offs == REGMAP_FLAG_IDX && mayBeDeoptIf) {
150                 encoded.push_back(builder.getInt32(llvm::cast<ConstantInt>(ops)->getZExtValue() | REGMAP_FLAG));
151             } else {
152                 encoded.push_back(ops);
153             }
154         }
155     }
156     return encoded;
157 }
158 
GetInlineInfo(ArrayRef<Use> inputs) const159 std::string PruneDeopt::GetInlineInfo(ArrayRef<Use> inputs) const
160 {
161     constexpr auto METHOD_ID_IDX = 1;
162     std::string inlineInfo;
163     for (uint32_t i = 0; i < inputs.size(); i++) {
164         if (llvm::isa<Function>(inputs[i])) {
165             ASSERT((i + METHOD_ID_IDX) < inputs.size());
166             if (!inlineInfo.empty()) {
167                 inlineInfo.append(",");
168             }
169 
170             auto methodId = llvm::cast<ConstantInt>(inputs[i + METHOD_ID_IDX])->getZExtValue();
171             inlineInfo.append(std::to_string(methodId));
172         }
173     }
174     return inlineInfo;
175 }
176 
MakeUnreachableAfter(BasicBlock * block,Instruction * after) const177 void PruneDeopt::MakeUnreachableAfter(BasicBlock *block, Instruction *after) const
178 {
179     // Remove the BLOCK from phi instructions
180     for (auto succ : successors(block)) {
181         for (auto phii = succ->phis().begin(), end = succ->phis().end(); phii != end;) {
182             auto &phi = *phii++;
183             auto idx = phi.getBasicBlockIndex(block);
184             if (idx != -1) {
185                 phi.removeIncomingValue(idx);
186             }
187         }
188     }
189     // Remove all instructions after AFTER
190     for (auto iter = block->rbegin(); (&(*iter) != after);) {
191         auto &toRemove = *iter++;
192         Instruction *inst = &toRemove;
193         inst->replaceAllUsesWith(llvm::UndefValue::get(inst->getType()));
194         inst->eraseFromParent();
195     }
196     // Create unreachable after AFTER
197     llvm::IRBuilder<> builder(block);
198     builder.CreateUnreachable();
199 }
200 }  // namespace ark::llvmbackend::passes
201