//===---------------- BPFAdjustOpt.cpp - Adjust Optimization --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Adjust optimization to make the code more kernel verifier friendly.
//
//===----------------------------------------------------------------------===//
12 
13 #include "BPF.h"
14 #include "BPFCORE.h"
15 #include "BPFTargetMachine.h"
16 #include "llvm/IR/Instruction.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/IR/Type.h"
20 #include "llvm/IR/User.h"
21 #include "llvm/IR/Value.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
24 
25 #define DEBUG_TYPE "bpf-adjust-opt"
26 
27 using namespace llvm;
28 
29 static cl::opt<bool>
30     DisableBPFserializeICMP("bpf-disable-serialize-icmp", cl::Hidden,
31                             cl::desc("BPF: Disable Serializing ICMP insns."),
32                             cl::init(false));
33 
34 static cl::opt<bool> DisableBPFavoidSpeculation(
35     "bpf-disable-avoid-speculation", cl::Hidden,
36     cl::desc("BPF: Disable Avoiding Speculative Code Motion."),
37     cl::init(false));
38 
39 namespace {
40 
41 class BPFAdjustOpt final : public ModulePass {
42 public:
43   static char ID;
44 
BPFAdjustOpt()45   BPFAdjustOpt() : ModulePass(ID) {}
46   bool runOnModule(Module &M) override;
47 };
48 
49 class BPFAdjustOptImpl {
50   struct PassThroughInfo {
51     Instruction *Input;
52     Instruction *UsedInst;
53     uint32_t OpIdx;
PassThroughInfo__anon0b8fc79a0111::BPFAdjustOptImpl::PassThroughInfo54     PassThroughInfo(Instruction *I, Instruction *U, uint32_t Idx)
55         : Input(I), UsedInst(U), OpIdx(Idx) {}
56   };
57 
58 public:
BPFAdjustOptImpl(Module * M)59   BPFAdjustOptImpl(Module *M) : M(M) {}
60 
61   bool run();
62 
63 private:
64   Module *M;
65   SmallVector<PassThroughInfo, 16> PassThroughs;
66 
67   void adjustBasicBlock(BasicBlock &BB);
68   bool serializeICMPCrossBB(BasicBlock &BB);
69   void adjustInst(Instruction &I);
70   bool serializeICMPInBB(Instruction &I);
71   bool avoidSpeculation(Instruction &I);
72   bool insertPassThrough();
73 };
74 
75 } // End anonymous namespace
76 
77 char BPFAdjustOpt::ID = 0;
78 INITIALIZE_PASS(BPFAdjustOpt, "bpf-adjust-opt", "BPF Adjust Optimization",
79                 false, false)
80 
createBPFAdjustOpt()81 ModulePass *llvm::createBPFAdjustOpt() { return new BPFAdjustOpt(); }
82 
runOnModule(Module & M)83 bool BPFAdjustOpt::runOnModule(Module &M) { return BPFAdjustOptImpl(&M).run(); }
84 
run()85 bool BPFAdjustOptImpl::run() {
86   for (Function &F : *M)
87     for (auto &BB : F) {
88       adjustBasicBlock(BB);
89       for (auto &I : BB)
90         adjustInst(I);
91     }
92 
93   return insertPassThrough();
94 }
95 
insertPassThrough()96 bool BPFAdjustOptImpl::insertPassThrough() {
97   for (auto &Info : PassThroughs) {
98     auto *CI = BPFCoreSharedInfo::insertPassThrough(
99         M, Info.UsedInst->getParent(), Info.Input, Info.UsedInst);
100     Info.UsedInst->setOperand(Info.OpIdx, CI);
101   }
102 
103   return !PassThroughs.empty();
104 }
105 
106 // To avoid combining conditionals in the same basic block by
107 // instrcombine optimization.
serializeICMPInBB(Instruction & I)108 bool BPFAdjustOptImpl::serializeICMPInBB(Instruction &I) {
109   // For:
110   //   comp1 = icmp <opcode> ...;
111   //   comp2 = icmp <opcode> ...;
112   //   ... or comp1 comp2 ...
113   // changed to:
114   //   comp1 = icmp <opcode> ...;
115   //   comp2 = icmp <opcode> ...;
116   //   new_comp1 = __builtin_bpf_passthrough(seq_num, comp1)
117   //   ... or new_comp1 comp2 ...
118   if (I.getOpcode() != Instruction::Or)
119     return false;
120   auto *Icmp1 = dyn_cast<ICmpInst>(I.getOperand(0));
121   if (!Icmp1)
122     return false;
123   auto *Icmp2 = dyn_cast<ICmpInst>(I.getOperand(1));
124   if (!Icmp2)
125     return false;
126 
127   Value *Icmp1Op0 = Icmp1->getOperand(0);
128   Value *Icmp2Op0 = Icmp2->getOperand(0);
129   if (Icmp1Op0 != Icmp2Op0)
130     return false;
131 
132   // Now we got two icmp instructions which feed into
133   // an "or" instruction.
134   PassThroughInfo Info(Icmp1, &I, 0);
135   PassThroughs.push_back(Info);
136   return true;
137 }
138 
139 // To avoid combining conditionals in the same basic block by
140 // instrcombine optimization.
serializeICMPCrossBB(BasicBlock & BB)141 bool BPFAdjustOptImpl::serializeICMPCrossBB(BasicBlock &BB) {
142   // For:
143   //   B1:
144   //     comp1 = icmp <opcode> ...;
145   //     if (comp1) goto B2 else B3;
146   //   B2:
147   //     comp2 = icmp <opcode> ...;
148   //     if (comp2) goto B4 else B5;
149   //   B4:
150   //     ...
151   // changed to:
152   //   B1:
153   //     comp1 = icmp <opcode> ...;
154   //     comp1 = __builtin_bpf_passthrough(seq_num, comp1);
155   //     if (comp1) goto B2 else B3;
156   //   B2:
157   //     comp2 = icmp <opcode> ...;
158   //     if (comp2) goto B4 else B5;
159   //   B4:
160   //     ...
161 
162   // Check basic predecessors, if two of them (say B1, B2) are using
163   // icmp instructions to generate conditions and one is the predesessor
164   // of another (e.g., B1 is the predecessor of B2). Add a passthrough
165   // barrier after icmp inst of block B1.
166   BasicBlock *B2 = BB.getSinglePredecessor();
167   if (!B2)
168     return false;
169 
170   BasicBlock *B1 = B2->getSinglePredecessor();
171   if (!B1)
172     return false;
173 
174   Instruction *TI = B2->getTerminator();
175   auto *BI = dyn_cast<BranchInst>(TI);
176   if (!BI || !BI->isConditional())
177     return false;
178   auto *Cond = dyn_cast<ICmpInst>(BI->getCondition());
179   if (!Cond || B2->getFirstNonPHI() != Cond)
180     return false;
181   Value *B2Op0 = Cond->getOperand(0);
182   auto Cond2Op = Cond->getPredicate();
183 
184   TI = B1->getTerminator();
185   BI = dyn_cast<BranchInst>(TI);
186   if (!BI || !BI->isConditional())
187     return false;
188   Cond = dyn_cast<ICmpInst>(BI->getCondition());
189   if (!Cond)
190     return false;
191   Value *B1Op0 = Cond->getOperand(0);
192   auto Cond1Op = Cond->getPredicate();
193 
194   if (B1Op0 != B2Op0)
195     return false;
196 
197   if (Cond1Op == ICmpInst::ICMP_SGT || Cond1Op == ICmpInst::ICMP_SGE) {
198     if (Cond2Op != ICmpInst::ICMP_SLT && Cond1Op != ICmpInst::ICMP_SLE)
199       return false;
200   } else if (Cond1Op == ICmpInst::ICMP_SLT || Cond1Op == ICmpInst::ICMP_SLE) {
201     if (Cond2Op != ICmpInst::ICMP_SGT && Cond1Op != ICmpInst::ICMP_SGE)
202       return false;
203   } else {
204     return false;
205   }
206 
207   PassThroughInfo Info(Cond, BI, 0);
208   PassThroughs.push_back(Info);
209 
210   return true;
211 }
212 
213 // To avoid speculative hoisting certain computations out of
214 // a basic block.
avoidSpeculation(Instruction & I)215 bool BPFAdjustOptImpl::avoidSpeculation(Instruction &I) {
216   if (auto *LdInst = dyn_cast<LoadInst>(&I)) {
217     if (auto *GV = dyn_cast<GlobalVariable>(LdInst->getOperand(0))) {
218       if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
219           GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
220         return false;
221     }
222   }
223 
224   if (!dyn_cast<LoadInst>(&I) && !dyn_cast<CallInst>(&I))
225     return false;
226 
227   // For:
228   //   B1:
229   //     var = ...
230   //     ...
231   //     /* icmp may not be in the same block as var = ... */
232   //     comp1 = icmp <opcode> var, <const>;
233   //     if (comp1) goto B2 else B3;
234   //   B2:
235   //     ... var ...
236   // change to:
237   //   B1:
238   //     var = ...
239   //     ...
240   //     /* icmp may not be in the same block as var = ... */
241   //     comp1 = icmp <opcode> var, <const>;
242   //     if (comp1) goto B2 else B3;
243   //   B2:
244   //     var = __builtin_bpf_passthrough(seq_num, var);
245   //     ... var ...
246   bool isCandidate = false;
247   SmallVector<PassThroughInfo, 4> Candidates;
248   for (User *U : I.users()) {
249     Instruction *Inst = dyn_cast<Instruction>(U);
250     if (!Inst)
251       continue;
252 
253     // May cover a little bit more than the
254     // above pattern.
255     if (auto *Icmp1 = dyn_cast<ICmpInst>(Inst)) {
256       Value *Icmp1Op1 = Icmp1->getOperand(1);
257       if (!isa<Constant>(Icmp1Op1))
258         return false;
259       isCandidate = true;
260       continue;
261     }
262 
263     // Ignore the use in the same basic block as the definition.
264     if (Inst->getParent() == I.getParent())
265       continue;
266 
267     // use in a different basic block, If there is a call or
268     // load/store insn before this instruction in this basic
269     // block. Most likely it cannot be hoisted out. Skip it.
270     for (auto &I2 : *Inst->getParent()) {
271       if (dyn_cast<CallInst>(&I2))
272         return false;
273       if (dyn_cast<LoadInst>(&I2) || dyn_cast<StoreInst>(&I2))
274         return false;
275       if (&I2 == Inst)
276         break;
277     }
278 
279     // It should be used in a GEP or a simple arithmetic like
280     // ZEXT/SEXT which is used for GEP.
281     if (Inst->getOpcode() == Instruction::ZExt ||
282         Inst->getOpcode() == Instruction::SExt) {
283       PassThroughInfo Info(&I, Inst, 0);
284       Candidates.push_back(Info);
285     } else if (auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
286       // traverse GEP inst to find Use operand index
287       unsigned i, e;
288       for (i = 1, e = GI->getNumOperands(); i != e; ++i) {
289         Value *V = GI->getOperand(i);
290         if (V == &I)
291           break;
292       }
293       if (i == e)
294         continue;
295 
296       PassThroughInfo Info(&I, GI, i);
297       Candidates.push_back(Info);
298     }
299   }
300 
301   if (!isCandidate || Candidates.empty())
302     return false;
303 
304   PassThroughs.insert(PassThroughs.end(), Candidates.begin(), Candidates.end());
305   return true;
306 }
307 
adjustBasicBlock(BasicBlock & BB)308 void BPFAdjustOptImpl::adjustBasicBlock(BasicBlock &BB) {
309   if (!DisableBPFserializeICMP && serializeICMPCrossBB(BB))
310     return;
311 }
312 
adjustInst(Instruction & I)313 void BPFAdjustOptImpl::adjustInst(Instruction &I) {
314   if (!DisableBPFserializeICMP && serializeICMPInBB(I))
315     return;
316   if (!DisableBPFavoidSpeculation && avoidSpeculation(I))
317     return;
318 }
319 
run(Module & M,ModuleAnalysisManager & AM)320 PreservedAnalyses BPFAdjustOptPass::run(Module &M, ModuleAnalysisManager &AM) {
321   return BPFAdjustOptImpl(&M).run() ? PreservedAnalyses::none()
322                                     : PreservedAnalyses::all();
323 }
324