• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a general divergence analysis for loop vectorization
10 // and GPU programs. It determines which branches and values in a loop or GPU
11 // program are divergent. It can help branch optimizations such as jump
12 // threading and loop unswitching to make better decisions.
13 //
14 // GPU programs typically use the SIMD execution model, where multiple threads
15 // in the same execution group have to execute in lock-step. Therefore, if the
16 // code contains divergent branches (i.e., threads in a group do not agree on
17 // which path of the branch to take), the group of threads has to execute all
18 // the paths from that branch with different subsets of threads enabled until
19 // they re-converge.
20 //
21 // Due to this execution model, some optimizations such as jump
22 // threading and loop unswitching can interfere with thread re-convergence.
23 // Therefore, an analysis that computes which branches in a GPU program are
24 // divergent can help the compiler to selectively run these optimizations.
25 //
26 // This implementation is derived from the Vectorization Analysis of the
27 // Region Vectorizer (RV). That implementation in turn is based on the approach
28 // described in
29 //
30 //   Improving Performance of OpenCL on CPUs
31 //   Ralf Karrenberg and Sebastian Hack
32 //   CC '12
33 //
34 // This DivergenceAnalysis implementation is generic in the sense that it does
35 // not itself identify original sources of divergence.
// Instead, specialized adapter classes (LoopDivergenceAnalysis for loops and
// GPUDivergenceAnalysis for GPU programs) identify the sources of divergence
// (e.g., special variables that hold the thread ID or the iteration variable).
39 //
40 // The generic implementation propagates divergence to variables that are data
41 // or sync dependent on a source of divergence.
42 //
43 // While data dependency is a well-known concept, the notion of sync dependency
44 // is worth more explanation. Sync dependence characterizes the control flow
45 // aspect of the propagation of branch divergence. For example,
46 //
47 //   %cond = icmp slt i32 %tid, 10
48 //   br i1 %cond, label %then, label %else
49 // then:
50 //   br label %merge
51 // else:
52 //   br label %merge
53 // merge:
54 //   %a = phi i32 [ 0, %then ], [ 1, %else ]
55 //
56 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
57 // because %tid is not on its use-def chains, %a is sync dependent on %tid
58 // because the branch "br i1 %cond" depends on %tid and affects which value %a
59 // is assigned to.
60 //
61 // The sync dependence detection (which branch induces divergence in which join
62 // points) is implemented in the SyncDependenceAnalysis.
63 //
64 // The current DivergenceAnalysis implementation has the following limitations:
65 // 1. intra-procedural. It conservatively considers the arguments of a
66 //    non-kernel-entry function and the return value of a function call as
67 //    divergent.
68 // 2. memory as black box. It conservatively considers values loaded from
69 //    generic or local address as divergent. This can be improved by leveraging
70 //    pointer analysis and/or by modelling non-escaping memory objects in SSA
71 //    as done in RV.
72 //
73 //===----------------------------------------------------------------------===//
74 
75 #include "llvm/Analysis/DivergenceAnalysis.h"
76 #include "llvm/Analysis/LoopInfo.h"
77 #include "llvm/Analysis/Passes.h"
78 #include "llvm/Analysis/PostDominators.h"
79 #include "llvm/Analysis/TargetTransformInfo.h"
80 #include "llvm/IR/Dominators.h"
81 #include "llvm/IR/InstIterator.h"
82 #include "llvm/IR/Instructions.h"
83 #include "llvm/IR/IntrinsicInst.h"
84 #include "llvm/IR/Value.h"
85 #include "llvm/Support/Debug.h"
86 #include "llvm/Support/raw_ostream.h"
87 #include <vector>
88 
89 using namespace llvm;
90 
91 #define DEBUG_TYPE "divergence-analysis"
92 
93 // class DivergenceAnalysis
DivergenceAnalysis(const Function & F,const Loop * RegionLoop,const DominatorTree & DT,const LoopInfo & LI,SyncDependenceAnalysis & SDA,bool IsLCSSAForm)94 DivergenceAnalysis::DivergenceAnalysis(
95     const Function &F, const Loop *RegionLoop, const DominatorTree &DT,
96     const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm)
97     : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA),
98       IsLCSSAForm(IsLCSSAForm) {}
99 
markDivergent(const Value & DivVal)100 void DivergenceAnalysis::markDivergent(const Value &DivVal) {
101   assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal));
102   assert(!isAlwaysUniform(DivVal) && "cannot be a divergent");
103   DivergentValues.insert(&DivVal);
104 }
105 
addUniformOverride(const Value & UniVal)106 void DivergenceAnalysis::addUniformOverride(const Value &UniVal) {
107   UniformOverrides.insert(&UniVal);
108 }
109 
updateTerminator(const Instruction & Term) const110 bool DivergenceAnalysis::updateTerminator(const Instruction &Term) const {
111   if (Term.getNumSuccessors() <= 1)
112     return false;
113   if (auto *BranchTerm = dyn_cast<BranchInst>(&Term)) {
114     assert(BranchTerm->isConditional());
115     return isDivergent(*BranchTerm->getCondition());
116   }
117   if (auto *SwitchTerm = dyn_cast<SwitchInst>(&Term)) {
118     return isDivergent(*SwitchTerm->getCondition());
119   }
120   if (isa<InvokeInst>(Term)) {
121     return false; // ignore abnormal executions through landingpad
122   }
123 
124   llvm_unreachable("unexpected terminator");
125 }
126 
updateNormalInstruction(const Instruction & I) const127 bool DivergenceAnalysis::updateNormalInstruction(const Instruction &I) const {
128   // TODO function calls with side effects, etc
129   for (const auto &Op : I.operands()) {
130     if (isDivergent(*Op))
131       return true;
132   }
133   return false;
134 }
135 
isTemporalDivergent(const BasicBlock & ObservingBlock,const Value & Val) const136 bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock,
137                                              const Value &Val) const {
138   const auto *Inst = dyn_cast<const Instruction>(&Val);
139   if (!Inst)
140     return false;
141   // check whether any divergent loop carrying Val terminates before control
142   // proceeds to ObservingBlock
143   for (const auto *Loop = LI.getLoopFor(Inst->getParent());
144        Loop != RegionLoop && !Loop->contains(&ObservingBlock);
145        Loop = Loop->getParentLoop()) {
146     if (DivergentLoops.find(Loop) != DivergentLoops.end())
147       return true;
148   }
149 
150   return false;
151 }
152 
updatePHINode(const PHINode & Phi) const153 bool DivergenceAnalysis::updatePHINode(const PHINode &Phi) const {
154   // joining divergent disjoint path in Phi parent block
155   if (!Phi.hasConstantOrUndefValue() && isJoinDivergent(*Phi.getParent())) {
156     return true;
157   }
158 
159   // An incoming value could be divergent by itself.
160   // Otherwise, an incoming value could be uniform within the loop
161   // that carries its definition but it may appear divergent
162   // from outside the loop. This happens when divergent loop exits
163   // drop definitions of that uniform value in different iterations.
164   //
165   // for (int i = 0; i < n; ++i) { // 'i' is uniform inside the loop
166   //   if (i % thread_id == 0) break;    // divergent loop exit
167   // }
168   // int divI = i;                 // divI is divergent
169   for (size_t i = 0; i < Phi.getNumIncomingValues(); ++i) {
170     const auto *InVal = Phi.getIncomingValue(i);
171     if (isDivergent(*Phi.getIncomingValue(i)) ||
172         isTemporalDivergent(*Phi.getParent(), *InVal)) {
173       return true;
174     }
175   }
176   return false;
177 }
178 
inRegion(const Instruction & I) const179 bool DivergenceAnalysis::inRegion(const Instruction &I) const {
180   return I.getParent() && inRegion(*I.getParent());
181 }
182 
inRegion(const BasicBlock & BB) const183 bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const {
184   return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB);
185 }
186 
187 // marks all users of loop-carried values of the loop headed by LoopHeader as
188 // divergent
taintLoopLiveOuts(const BasicBlock & LoopHeader)189 void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) {
190   auto *DivLoop = LI.getLoopFor(&LoopHeader);
191   assert(DivLoop && "loopHeader is not actually part of a loop");
192 
193   SmallVector<BasicBlock *, 8> TaintStack;
194   DivLoop->getExitBlocks(TaintStack);
195 
196   // Otherwise potential users of loop-carried values could be anywhere in the
197   // dominance region of DivLoop (including its fringes for phi nodes)
198   DenseSet<const BasicBlock *> Visited;
199   for (auto *Block : TaintStack) {
200     Visited.insert(Block);
201   }
202   Visited.insert(&LoopHeader);
203 
204   while (!TaintStack.empty()) {
205     auto *UserBlock = TaintStack.back();
206     TaintStack.pop_back();
207 
208     // don't spread divergence beyond the region
209     if (!inRegion(*UserBlock))
210       continue;
211 
212     assert(!DivLoop->contains(UserBlock) &&
213            "irreducible control flow detected");
214 
215     // phi nodes at the fringes of the dominance region
216     if (!DT.dominates(&LoopHeader, UserBlock)) {
217       // all PHI nodes of UserBlock become divergent
218       for (auto &Phi : UserBlock->phis()) {
219         Worklist.push_back(&Phi);
220       }
221       continue;
222     }
223 
224     // taint outside users of values carried by DivLoop
225     for (auto &I : *UserBlock) {
226       if (isAlwaysUniform(I))
227         continue;
228       if (isDivergent(I))
229         continue;
230 
231       for (auto &Op : I.operands()) {
232         auto *OpInst = dyn_cast<Instruction>(&Op);
233         if (!OpInst)
234           continue;
235         if (DivLoop->contains(OpInst->getParent())) {
236           markDivergent(I);
237           pushUsers(I);
238           break;
239         }
240       }
241     }
242 
243     // visit all blocks in the dominance region
244     for (auto *SuccBlock : successors(UserBlock)) {
245       if (!Visited.insert(SuccBlock).second) {
246         continue;
247       }
248       TaintStack.push_back(SuccBlock);
249     }
250   }
251 }
252 
pushPHINodes(const BasicBlock & Block)253 void DivergenceAnalysis::pushPHINodes(const BasicBlock &Block) {
254   for (const auto &Phi : Block.phis()) {
255     if (isDivergent(Phi))
256       continue;
257     Worklist.push_back(&Phi);
258   }
259 }
260 
pushUsers(const Value & V)261 void DivergenceAnalysis::pushUsers(const Value &V) {
262   for (const auto *User : V.users()) {
263     const auto *UserInst = dyn_cast<const Instruction>(User);
264     if (!UserInst)
265       continue;
266 
267     if (isDivergent(*UserInst))
268       continue;
269 
270     // only compute divergent inside loop
271     if (!inRegion(*UserInst))
272       continue;
273     Worklist.push_back(UserInst);
274   }
275 }
276 
propagateJoinDivergence(const BasicBlock & JoinBlock,const Loop * BranchLoop)277 bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock &JoinBlock,
278                                                  const Loop *BranchLoop) {
279   LLVM_DEBUG(dbgs() << "\tpropJoinDiv " << JoinBlock.getName() << "\n");
280 
281   // ignore divergence outside the region
282   if (!inRegion(JoinBlock)) {
283     return false;
284   }
285 
286   // push non-divergent phi nodes in JoinBlock to the worklist
287   pushPHINodes(JoinBlock);
288 
289   // JoinBlock is a divergent loop exit
290   if (BranchLoop && !BranchLoop->contains(&JoinBlock)) {
291     return true;
292   }
293 
294   // disjoint-paths divergent at JoinBlock
295   markBlockJoinDivergent(JoinBlock);
296   return false;
297 }
298 
propagateBranchDivergence(const Instruction & Term)299 void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) {
300   LLVM_DEBUG(dbgs() << "propBranchDiv " << Term.getParent()->getName() << "\n");
301 
302   markDivergent(Term);
303 
304   const auto *BranchLoop = LI.getLoopFor(Term.getParent());
305 
306   // whether there is a divergent loop exit from BranchLoop (if any)
307   bool IsBranchLoopDivergent = false;
308 
309   // iterate over all blocks reachable by disjoint from Term within the loop
310   // also iterates over loop exits that become divergent due to Term.
311   for (const auto *JoinBlock : SDA.join_blocks(Term)) {
312     IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop);
313   }
314 
315   // Branch loop is a divergent loop due to the divergent branch in Term
316   if (IsBranchLoopDivergent) {
317     assert(BranchLoop);
318     if (!DivergentLoops.insert(BranchLoop).second) {
319       return;
320     }
321     propagateLoopDivergence(*BranchLoop);
322   }
323 }
324 
propagateLoopDivergence(const Loop & ExitingLoop)325 void DivergenceAnalysis::propagateLoopDivergence(const Loop &ExitingLoop) {
326   LLVM_DEBUG(dbgs() << "propLoopDiv " << ExitingLoop.getName() << "\n");
327 
328   // don't propagate beyond region
329   if (!inRegion(*ExitingLoop.getHeader()))
330     return;
331 
332   const auto *BranchLoop = ExitingLoop.getParentLoop();
333 
334   // Uses of loop-carried values could occur anywhere
335   // within the dominance region of the definition. All loop-carried
336   // definitions are dominated by the loop header (reducible control).
337   // Thus all users have to be in the dominance region of the loop header,
338   // except PHI nodes that can also live at the fringes of the dom region
339   // (incoming defining value).
340   if (!IsLCSSAForm)
341     taintLoopLiveOuts(*ExitingLoop.getHeader());
342 
343   // whether there is a divergent loop exit from BranchLoop (if any)
344   bool IsBranchLoopDivergent = false;
345 
346   // iterate over all blocks reachable by disjoint paths from exits of
347   // ExitingLoop also iterates over loop exits (of BranchLoop) that in turn
348   // become divergent.
349   for (const auto *JoinBlock : SDA.join_blocks(ExitingLoop)) {
350     IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop);
351   }
352 
353   // Branch loop is a divergent due to divergent loop exit in ExitingLoop
354   if (IsBranchLoopDivergent) {
355     assert(BranchLoop);
356     if (!DivergentLoops.insert(BranchLoop).second) {
357       return;
358     }
359     propagateLoopDivergence(*BranchLoop);
360   }
361 }
362 
compute()363 void DivergenceAnalysis::compute() {
364   for (auto *DivVal : DivergentValues) {
365     pushUsers(*DivVal);
366   }
367 
368   // propagate divergence
369   while (!Worklist.empty()) {
370     const Instruction &I = *Worklist.back();
371     Worklist.pop_back();
372 
373     // maintain uniformity of overrides
374     if (isAlwaysUniform(I))
375       continue;
376 
377     bool WasDivergent = isDivergent(I);
378     if (WasDivergent)
379       continue;
380 
381     // propagate divergence caused by terminator
382     if (I.isTerminator()) {
383       if (updateTerminator(I)) {
384         // propagate control divergence to affected instructions
385         propagateBranchDivergence(I);
386         continue;
387       }
388     }
389 
390     // update divergence of I due to divergent operands
391     bool DivergentUpd = false;
392     const auto *Phi = dyn_cast<const PHINode>(&I);
393     if (Phi) {
394       DivergentUpd = updatePHINode(*Phi);
395     } else {
396       DivergentUpd = updateNormalInstruction(I);
397     }
398 
399     // propagate value divergence to users
400     if (DivergentUpd) {
401       markDivergent(I);
402       pushUsers(I);
403     }
404   }
405 }
406 
isAlwaysUniform(const Value & V) const407 bool DivergenceAnalysis::isAlwaysUniform(const Value &V) const {
408   return UniformOverrides.find(&V) != UniformOverrides.end();
409 }
410 
isDivergent(const Value & V) const411 bool DivergenceAnalysis::isDivergent(const Value &V) const {
412   return DivergentValues.find(&V) != DivergentValues.end();
413 }
414 
isDivergentUse(const Use & U) const415 bool DivergenceAnalysis::isDivergentUse(const Use &U) const {
416   Value &V = *U.get();
417   Instruction &I = *cast<Instruction>(U.getUser());
418   return isDivergent(V) || isTemporalDivergent(*I.getParent(), V);
419 }
420 
print(raw_ostream & OS,const Module *) const421 void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
422   if (DivergentValues.empty())
423     return;
424   // iterate instructions using instructions() to ensure a deterministic order.
425   for (auto &I : instructions(F)) {
426     if (isDivergent(I))
427       OS << "DIVERGENT:" << I << '\n';
428   }
429 }
430 
431 // class GPUDivergenceAnalysis
GPUDivergenceAnalysis(Function & F,const DominatorTree & DT,const PostDominatorTree & PDT,const LoopInfo & LI,const TargetTransformInfo & TTI)432 GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function &F,
433                                              const DominatorTree &DT,
434                                              const PostDominatorTree &PDT,
435                                              const LoopInfo &LI,
436                                              const TargetTransformInfo &TTI)
437     : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, false) {
438   for (auto &I : instructions(F)) {
439     if (TTI.isSourceOfDivergence(&I)) {
440       DA.markDivergent(I);
441     } else if (TTI.isAlwaysUniform(&I)) {
442       DA.addUniformOverride(I);
443     }
444   }
445   for (auto &Arg : F.args()) {
446     if (TTI.isSourceOfDivergence(&Arg)) {
447       DA.markDivergent(Arg);
448     }
449   }
450 
451   DA.compute();
452 }
453 
isDivergent(const Value & val) const454 bool GPUDivergenceAnalysis::isDivergent(const Value &val) const {
455   return DA.isDivergent(val);
456 }
457 
isDivergentUse(const Use & use) const458 bool GPUDivergenceAnalysis::isDivergentUse(const Use &use) const {
459   return DA.isDivergentUse(use);
460 }
461 
print(raw_ostream & OS,const Module * mod) const462 void GPUDivergenceAnalysis::print(raw_ostream &OS, const Module *mod) const {
463   OS << "Divergence of kernel " << DA.getFunction().getName() << " {\n";
464   DA.print(OS, mod);
465   OS << "}\n";
466 }
467