• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // The divergence analysis determines which instructions and branches are
11 // divergent given a set of divergent source instructions.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
16 #define LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
17 
18 #include "llvm/ADT/DenseSet.h"
19 #include "llvm/Analysis/SyncDependenceAnalysis.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/Pass.h"
22 #include <vector>
23 
24 namespace llvm {
25 class Module;
26 class Value;
27 class Instruction;
28 class Loop;
29 class raw_ostream;
30 class TargetTransformInfo;
31 
32 /// \brief Generic divergence analysis for reducible CFGs.
33 ///
34 /// This analysis propagates divergence in a data-parallel context from sources
35 /// of divergence to all users. It requires reducible CFGs. All assignments
36 /// should be in SSA form.
37 class DivergenceAnalysis {
38 public:
39   /// \brief This instance will analyze the whole function \p F or the loop \p
40   /// RegionLoop.
41   ///
42   /// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop.
43   /// Otherwise the whole function is analyzed.
44   /// \param IsLCSSAForm whether the analysis may assume that the IR in the
45   /// region in in LCSSA form.
46   DivergenceAnalysis(const Function &F, const Loop *RegionLoop,
47                      const DominatorTree &DT, const LoopInfo &LI,
48                      SyncDependenceAnalysis &SDA, bool IsLCSSAForm);
49 
50   /// \brief The loop that defines the analyzed region (if any).
getRegionLoop()51   const Loop *getRegionLoop() const { return RegionLoop; }
getFunction()52   const Function &getFunction() const { return F; }
53 
54   /// \brief Whether \p BB is part of the region.
55   bool inRegion(const BasicBlock &BB) const;
56   /// \brief Whether \p I is part of the region.
57   bool inRegion(const Instruction &I) const;
58 
59   /// \brief Mark \p UniVal as a value that is always uniform.
60   void addUniformOverride(const Value &UniVal);
61 
62   /// \brief Mark \p DivVal as a value that is always divergent.
63   void markDivergent(const Value &DivVal);
64 
65   /// \brief Propagate divergence to all instructions in the region.
66   /// Divergence is seeded by calls to \p markDivergent.
67   void compute();
68 
69   /// \brief Whether any value was marked or analyzed to be divergent.
hasDetectedDivergence()70   bool hasDetectedDivergence() const { return !DivergentValues.empty(); }
71 
72   /// \brief Whether \p Val will always return a uniform value regardless of its
73   /// operands
74   bool isAlwaysUniform(const Value &Val) const;
75 
76   /// \brief Whether \p Val is divergent at its definition.
77   bool isDivergent(const Value &Val) const;
78 
79   /// \brief Whether \p U is divergent. Uses of a uniform value can be divergent.
80   bool isDivergentUse(const Use &U) const;
81 
82   void print(raw_ostream &OS, const Module *) const;
83 
84 private:
85   bool updateTerminator(const Instruction &Term) const;
86   bool updatePHINode(const PHINode &Phi) const;
87 
88   /// \brief Computes whether \p Inst is divergent based on the
89   /// divergence of its operands.
90   ///
91   /// \returns Whether \p Inst is divergent.
92   ///
93   /// This should only be called for non-phi, non-terminator instructions.
94   bool updateNormalInstruction(const Instruction &Inst) const;
95 
96   /// \brief Mark users of live-out users as divergent.
97   ///
98   /// \param LoopHeader the header of the divergent loop.
99   ///
100   /// Marks all users of live-out values of the loop headed by \p LoopHeader
101   /// as divergent and puts them on the worklist.
102   void taintLoopLiveOuts(const BasicBlock &LoopHeader);
103 
104   /// \brief Push all users of \p Val (in the region) to the worklist
105   void pushUsers(const Value &I);
106 
107   /// \brief Push all phi nodes in @block to the worklist
108   void pushPHINodes(const BasicBlock &Block);
109 
110   /// \brief Mark \p Block as join divergent
111   ///
112   /// A block is join divergent if two threads may reach it from different
113   /// incoming blocks at the same time.
markBlockJoinDivergent(const BasicBlock & Block)114   void markBlockJoinDivergent(const BasicBlock &Block) {
115     DivergentJoinBlocks.insert(&Block);
116   }
117 
118   /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
119   bool isTemporalDivergent(const BasicBlock &ObservingBlock,
120                            const Value &Val) const;
121 
122   /// \brief Whether \p Block is join divergent
123   ///
124   /// (see markBlockJoinDivergent).
isJoinDivergent(const BasicBlock & Block)125   bool isJoinDivergent(const BasicBlock &Block) const {
126     return DivergentJoinBlocks.find(&Block) != DivergentJoinBlocks.end();
127   }
128 
129   /// \brief Propagate control-induced divergence to users (phi nodes and
130   /// instructions).
131   //
132   // \param JoinBlock is a divergent loop exit or join point of two disjoint
133   // paths.
134   // \returns Whether \p JoinBlock is a divergent loop exit of \p TermLoop.
135   bool propagateJoinDivergence(const BasicBlock &JoinBlock,
136                                const Loop *TermLoop);
137 
138   /// \brief Propagate induced value divergence due to control divergence in \p
139   /// Term.
140   void propagateBranchDivergence(const Instruction &Term);
141 
142   /// \brief Propagate divergent caused by a divergent loop exit.
143   ///
144   /// \param ExitingLoop is a divergent loop.
145   void propagateLoopDivergence(const Loop &ExitingLoop);
146 
147 private:
148   const Function &F;
149   // If regionLoop != nullptr, analysis is only performed within \p RegionLoop.
150   // Otw, analyze the whole function
151   const Loop *RegionLoop;
152 
153   const DominatorTree &DT;
154   const LoopInfo &LI;
155 
156   // Recognized divergent loops
157   DenseSet<const Loop *> DivergentLoops;
158 
159   // The SDA links divergent branches to divergent control-flow joins.
160   SyncDependenceAnalysis &SDA;
161 
162   // Use simplified code path for LCSSA form.
163   bool IsLCSSAForm;
164 
165   // Set of known-uniform values.
166   DenseSet<const Value *> UniformOverrides;
167 
168   // Blocks with joining divergent control from different predecessors.
169   DenseSet<const BasicBlock *> DivergentJoinBlocks;
170 
171   // Detected/marked divergent values.
172   DenseSet<const Value *> DivergentValues;
173 
174   // Internal worklist for divergence propagation.
175   std::vector<const Instruction *> Worklist;
176 };
177 
178 /// \brief Divergence analysis frontend for GPU kernels.
179 class GPUDivergenceAnalysis {
180   SyncDependenceAnalysis SDA;
181   DivergenceAnalysis DA;
182 
183 public:
184   /// Runs the divergence analysis on @F, a GPU kernel
185   GPUDivergenceAnalysis(Function &F, const DominatorTree &DT,
186                         const PostDominatorTree &PDT, const LoopInfo &LI,
187                         const TargetTransformInfo &TTI);
188 
189   /// Whether any divergence was detected.
hasDivergence()190   bool hasDivergence() const { return DA.hasDetectedDivergence(); }
191 
192   /// The GPU kernel this analysis result is for
getFunction()193   const Function &getFunction() const { return DA.getFunction(); }
194 
195   /// Whether \p V is divergent at its definition.
196   bool isDivergent(const Value &V) const;
197 
198   /// Whether \p U is divergent. Uses of a uniform value can be divergent.
199   bool isDivergentUse(const Use &U) const;
200 
201   /// Whether \p V is uniform/non-divergent.
isUniform(const Value & V)202   bool isUniform(const Value &V) const { return !isDivergent(V); }
203 
204   /// Whether \p U is uniform/non-divergent. Uses of a uniform value can be
205   /// divergent.
isUniformUse(const Use & U)206   bool isUniformUse(const Use &U) const { return !isDivergentUse(U); }
207 
208   /// Print all divergent values in the kernel.
209   void print(raw_ostream &OS, const Module *) const;
210 };
211 
212 } // namespace llvm
213 
214 #endif // LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
215