• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- GenericUniformityImpl.h -----------------------*- C++ -*------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This template implementation resides in a separate file so that it
10 // does not get injected into every .cpp file that includes the
11 // generic header.
12 //
13 // DO NOT INCLUDE THIS FILE WHEN MERELY USING UNIFORMITYINFO.
14 //
15 // This file should only be included by files that implement a
// specialization of the relevant templates. Currently these are:
17 // - UniformityAnalysis.cpp
18 //
19 // Note: The DEBUG_TYPE macro should be defined before using this
20 // file so that any use of LLVM_DEBUG is associated with the
21 // including file rather than this file.
22 //
23 //===----------------------------------------------------------------------===//
24 ///
25 /// \file
26 /// \brief Implementation of uniformity analysis.
27 ///
28 /// The algorithm is a fixed point iteration that starts with the assumption
29 /// that all control flow and all values are uniform. Starting from sources of
30 /// divergence (whose discovery must be implemented by a CFG- or even
31 /// target-specific derived class), divergence of values is propagated from
32 /// definition to uses in a straight-forward way. The main complexity lies in
33 /// the propagation of the impact of divergent control flow on the divergence of
34 /// values (sync dependencies).
35 ///
36 /// NOTE: In general, no interface exists for a transform to update
37 /// (Machine)UniformityInfo. Additionally, (Machine)CycleAnalysis is a
38 /// transitive dependence, but it also does not provide an interface for
39 /// updating itself. Given that, transforms should not preserve uniformity in
40 /// their getAnalysisUsage() callback.
41 ///
42 //===----------------------------------------------------------------------===//
43 
44 #ifndef LLVM_ADT_GENERICUNIFORMITYIMPL_H
45 #define LLVM_ADT_GENERICUNIFORMITYIMPL_H
46 
47 #include "llvm/ADT/GenericUniformityInfo.h"
48 
49 #include "llvm/ADT/STLExtras.h"
50 #include "llvm/ADT/SmallPtrSet.h"
51 #include "llvm/ADT/SparseBitVector.h"
52 #include "llvm/ADT/StringExtras.h"
53 #include "llvm/Support/raw_ostream.h"
54 
55 #include <set>
56 
57 #define DEBUG_TYPE "uniformity"
58 
59 namespace llvm {
60 
61 /// Construct a specially modified post-order traversal of cycles.
62 ///
/// The ModifiedPO is constructed using a virtually modified CFG as follows:
///
/// 1. The successors of pre-entry nodes (predecessors of a cycle
///    entry that are outside the cycle) are replaced by the
///    successors of the successors of the header.
/// 2. Successors of the cycle header are replaced by the exit blocks
///    of the cycle.
70 ///
71 /// Effectively, we produce a depth-first numbering with the following
72 /// properties:
73 ///
74 /// 1. Nodes after a cycle are numbered earlier than the cycle header.
75 /// 2. The header is numbered earlier than the nodes in the cycle.
76 /// 3. The numbering of the nodes within the cycle forms an interval
77 ///    starting with the header.
78 ///
79 /// Effectively, the virtual modification arranges the nodes in a
80 /// cycle as a DAG with the header as the sole leaf, and successors of
81 /// the header as the roots. A reverse traversal of this numbering has
82 /// the following invariant on the unmodified original CFG:
83 ///
84 ///    Each node is visited after all its predecessors, except if that
85 ///    predecessor is the cycle header.
86 ///
template <typename ContextT> class ModifiedPostOrder {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;
  using const_iterator = typename std::vector<BlockT *>::const_iterator;

  ModifiedPostOrder(const ContextT &C) : Context(C) {}

  bool empty() const { return m_order.empty(); }
  size_t size() const { return m_order.size(); }

  void clear() { m_order.clear(); }

  /// Compute the modified post-order numbering for the CFG described
  /// by \p CI (defined out-of-line by the instantiating translation unit).
  void compute(const CycleInfoT &CI);

  /// Whether \p BB has already been assigned a position in the order.
  unsigned count(BlockT *BB) const { return POIndex.count(BB); }
  const BlockT *operator[](size_t idx) const { return m_order[idx]; }

  /// Append \p BB as the next block in the order, remembering whether it
  /// is the header of a reducible cycle.
  void appendBlock(const BlockT &BB, bool isReducibleCycleHeader = false) {
    POIndex[&BB] = m_order.size();
    m_order.push_back(&BB);
    LLVM_DEBUG(dbgs() << "ModifiedPO(" << POIndex[&BB]
                      << "): " << Context.print(&BB) << "\n");
    if (isReducibleCycleHeader)
      ReducibleCycleHeaders.insert(&BB);
  }

  /// Position of \p BB in the order; \p BB must already be numbered.
  unsigned getIndex(const BlockT *BB) const {
    assert(POIndex.count(BB));
    return POIndex.lookup(BB);
  }

  bool isReducibleCycleHeader(const BlockT *BB) const {
    return ReducibleCycleHeaders.contains(BB);
  }

private:
  // Blocks in modified post-order; POIndex is the inverse mapping.
  SmallVector<const BlockT *> m_order;
  DenseMap<const BlockT *, unsigned> POIndex;
  SmallPtrSet<const BlockT *, 32> ReducibleCycleHeaders;
  const ContextT &Context;

  void computeCyclePO(const CycleInfoT &CI, const CycleT *Cycle,
                      SmallPtrSetImpl<const BlockT *> &Finalized);

  void computeStackPO(SmallVectorImpl<const BlockT *> &Stack,
                      const CycleInfoT &CI, const CycleT *Cycle,
                      SmallPtrSetImpl<const BlockT *> &Finalized);
};
139 
140 template <typename> class DivergencePropagator;
141 
142 /// \class GenericSyncDependenceAnalysis
143 ///
144 /// \brief Locate join blocks for disjoint paths starting at a divergent branch.
145 ///
146 /// An analysis per divergent branch that returns the set of basic
147 /// blocks whose phi nodes become divergent due to divergent control.
148 /// These are the blocks that are reachable by two disjoint paths from
149 /// the branch, or cycle exits reachable along a path that is disjoint
150 /// from a path to the cycle latch.
151 
// -- Implementation notes (kept outside the doxygen comment above) --
153 //
154 // Originally implemented in SyncDependenceAnalysis.cpp for DivergenceAnalysis.
155 //
156 // The SyncDependenceAnalysis is used in the UniformityAnalysis to model
157 // control-induced divergence in phi nodes.
158 //
159 // -- Reference --
160 // The algorithm is an extension of Section 5 of
161 //
162 //   An abstract interpretation for SPMD divergence
163 //       on reducible control flow graphs.
164 //   Julian Rosemann, Simon Moll and Sebastian Hack
165 //   POPL '21
166 //
167 //
168 // -- Sync dependence --
169 // Sync dependence characterizes the control flow aspect of the
170 // propagation of branch divergence. For example,
171 //
172 //   %cond = icmp slt i32 %tid, 10
173 //   br i1 %cond, label %then, label %else
174 // then:
175 //   br label %merge
176 // else:
177 //   br label %merge
178 // merge:
179 //   %a = phi i32 [ 0, %then ], [ 1, %else ]
180 //
181 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
182 // because %tid is not on its use-def chains, %a is sync dependent on %tid
183 // because the branch "br i1 %cond" depends on %tid and affects which value %a
184 // is assigned to.
185 //
186 //
187 // -- Reduction to SSA construction --
188 // There are two disjoint paths from A to X, if a certain variant of SSA
189 // construction places a phi node in X under the following set-up scheme.
190 //
191 // This variant of SSA construction ignores incoming undef values.
192 // That is paths from the entry without a definition do not result in
193 // phi nodes.
194 //
195 //       entry
196 //     /      \
197 //    A        \
198 //  /   \       Y
199 // B     C     /
200 //  \   /  \  /
201 //    D     E
202 //     \   /
203 //       F
204 //
205 // Assume that A contains a divergent branch. We are interested
206 // in the set of all blocks where each block is reachable from A
207 // via two disjoint paths. This would be the set {D, F} in this
208 // case.
209 // To generally reduce this query to SSA construction we introduce
210 // a virtual variable x and assign to x different values in each
211 // successor block of A.
212 //
213 //           entry
214 //         /      \
215 //        A        \
216 //      /   \       Y
217 // x = 0   x = 1   /
218 //      \  /   \  /
219 //        D     E
220 //         \   /
221 //           F
222 //
223 // Our flavor of SSA construction for x will construct the following
224 //
225 //            entry
226 //          /      \
227 //         A        \
228 //       /   \       Y
229 // x0 = 0   x1 = 1  /
230 //       \   /   \ /
231 //     x2 = phi   E
232 //         \     /
233 //         x3 = phi
234 //
235 // The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
237 //
238 // -- Remarks --
239 // * In case of cycle exits we need to check for temporal divergence.
240 //   To this end, we check whether the definition of x differs between the
241 //   cycle exit and the cycle header (_after_ SSA construction).
242 //
243 // * In the presence of irreducible control flow, the fixed point is
244 //   reached only after multiple iterations. This is because labels
245 //   reaching the header of a cycle must be repropagated through the
246 //   cycle. This is true even in a reducible cycle, since the labels
247 //   may have been produced by a nested irreducible cycle.
248 //
249 // * Note that SyncDependenceAnalysis is not concerned with the points
//   of convergence in an irreducible cycle. Its only purpose is to
251 //   identify join blocks. The "diverged entry" criterion is
252 //   separately applied on join blocks to determine if an entire
253 //   irreducible cycle is assumed to be divergent.
254 //
255 // * Relevant related work:
256 //     A simple algorithm for global data flow analysis problems.
257 //     Matthew S. Hecht and Jeffrey D. Ullman.
258 //     SIAM Journal on Computing, 4(4):519–532, December 1975.
259 //
template <typename ContextT> class GenericSyncDependenceAnalysis {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using InstructionT = typename ContextT::InstructionT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ConstBlockSet = SmallPtrSet<const BlockT *, 4>;
  using ModifiedPO = ModifiedPostOrder<ContextT>;

  // Reaching-definition labels for the virtual SSA construction, where X
  // is the block with the divergent branch:
  // * if BlockLabels[B] == C then C is the dominating definition at
  //   block B
  // * if BlockLabels[B] == nullptr then we haven't seen B yet
  // * if BlockLabels[B] == B then:
  //   - B is a join point of disjoint paths from X, or,
  //   - B is an immediate successor of X (initial value), or,
  //   - B is X
  using BlockLabelMap = DenseMap<const BlockT *, const BlockT *>;

  /// Information discovered by the sync dependence analysis for each
  /// divergent branch.
  struct DivergenceDescriptor {
    // Join points of diverged paths.
    ConstBlockSet JoinDivBlocks;
    // Divergent cycle exits.
    ConstBlockSet CycleDivBlocks;
    // Labels assigned to blocks on diverged paths.
    BlockLabelMap BlockLabels;
  };

  using DivergencePropagatorT = DivergencePropagator<ContextT>;

  GenericSyncDependenceAnalysis(const ContextT &Context,
                                const DominatorTreeT &DT, const CycleInfoT &CI);

  /// \brief Computes divergent join points and cycle exits caused by branch
  /// divergence in \p Term.
  ///
  /// This returns a pair of sets:
  /// * The set of blocks which are reachable by disjoint paths from
  ///   \p Term.
  /// * The set also contains cycle exits if there are two disjoint paths:
  ///   one from \p Term to the cycle exit and another from \p Term to
  ///   the cycle header.
  const DivergenceDescriptor &getJoinBlocks(const BlockT *DivTermBlock);

private:
  // Shared result for branches that cannot diverge (see getJoinBlocks).
  static DivergenceDescriptor EmptyDivergenceDesc;

  // Cycle-aware modified post-order, computed once in the constructor and
  // reused for every queried branch.
  ModifiedPO CyclePO;

  const DominatorTreeT &DT;
  const CycleInfoT &CI;

  // Per-branch memoization of computed descriptors.
  DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
      CachedControlDivDescs;
};
321 
322 /// \brief Analysis that identifies uniform values in a data-parallel
323 /// execution.
324 ///
325 /// This analysis propagates divergence in a data-parallel context
326 /// from sources of divergence to all users. It can be instantiated
327 /// for an IR that provides a suitable SSAContext.
template <typename ContextT> class GenericUniformityAnalysisImpl {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using ConstValueRefT = typename ContextT::ConstValueRefT;
  using UseT = typename ContextT::UseT;
  using InstructionT = typename ContextT::InstructionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI,
                                const TargetTransformInfo *TTI)
      : Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI),
        TTI(TTI), DT(DT), SDA(Context, DT, CI) {}

  void initialize();

  const FunctionT &getFunction() const { return F; }

  /// \brief Mark \p Instr as a value that is always uniform.
  void addUniformOverride(const InstructionT &Instr);

  /// \brief Examine \p I for divergent outputs and add to the worklist.
  void markDivergent(const InstructionT &I);

  /// \brief Mark \p DivVal as a divergent value.
  /// \returns Whether the tracked divergence state of \p DivVal changed.
  bool markDivergent(ConstValueRefT DivVal);

  /// \brief Mark outputs of \p Instr as divergent.
  /// \returns Whether the tracked divergence state of any output has changed.
  bool markDefsDivergent(const InstructionT &Instr);

  /// \brief Propagate divergence to all instructions in the region.
  /// Divergence is seeded by calls to \p markDivergent.
  void compute();

  /// \brief Whether any value was marked or analyzed to be divergent.
  bool hasDivergence() const { return !DivergentValues.empty(); }

  /// \brief Whether \p Instr will always return a uniform value regardless
  /// of its operands.
  bool isAlwaysUniform(const InstructionT &Instr) const;

  bool hasDivergentDefs(const InstructionT &I) const;

  // Terminators have no defs; their divergence is tracked per enclosing
  // block instead of per value.
  bool isDivergent(const InstructionT &I) const {
    if (I.isTerminator()) {
      return DivergentTermBlocks.contains(I.getParent());
    }
    return hasDivergentDefs(I);
  };

  /// \brief Whether \p V is divergent at its definition.
  bool isDivergent(ConstValueRefT V) const { return DivergentValues.count(V); }

  bool isDivergentUse(const UseT &U) const;

  bool hasDivergentTerminator(const BlockT &B) const {
    return DivergentTermBlocks.contains(&B);
  }

  void print(raw_ostream &out) const;

protected:
  /// \brief Value/block pair representing a single phi input.
  struct PhiInput {
    ConstValueRefT value;
    BlockT *predBlock;

    PhiInput(ConstValueRefT value, BlockT *predBlock)
        : value(value), predBlock(predBlock) {}
  };

  const ContextT &Context;
  const FunctionT &F;
  const CycleInfoT &CI;
  const TargetTransformInfo *TTI = nullptr;

  // Detected/marked divergent values.
  std::set<ConstValueRefT> DivergentValues;
  SmallPtrSet<const BlockT *, 32> DivergentTermBlocks;

  // Internal worklist for divergence propagation.
  std::vector<const InstructionT *> Worklist;

  /// \brief Mark \p Term as divergent and push all Instructions that become
  /// divergent as a result on the worklist.
  void analyzeControlDivergence(const InstructionT &Term);

private:
  const DominatorTreeT &DT;

  // Recognized cycles with divergent exits.
  SmallPtrSet<const CycleT *, 16> DivergentExitCycles;

  // Cycles assumed to be divergent.
  //
  // We don't use a set here because every insertion needs an explicit
  // traversal of all existing members.
  SmallVector<const CycleT *> AssumedDivergent;

  // The SDA links divergent branches to divergent control-flow joins.
  SyncDependenceAnalysisT SDA;

  // Set of known-uniform values.
  SmallPtrSet<const InstructionT *, 32> UniformOverrides;

  /// \brief Mark all nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushAllDefs(const BlockT &JoinBlock);

  /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
  /// the worklist.
  void taintAndPushPhiNodes(const BlockT &JoinBlock);

  /// \brief Identify all Instructions that become divergent because \p DivExit
  /// is a divergent cycle exit of \p DivCycle. Mark those instructions as
  /// divergent and push them on the worklist.
  void propagateCycleExitDivergence(const BlockT &DivExit,
                                    const CycleT &DivCycle);

  /// Mark as divergent all external uses of values defined in \p DefCycle.
  void analyzeCycleExitDivergence(const CycleT &DefCycle);

  /// \brief Mark as divergent all uses of \p I that are outside \p DefCycle.
  void propagateTemporalDivergence(const InstructionT &I,
                                   const CycleT &DefCycle);

  /// \brief Push all users of \p I (in the region) to the worklist.
  void pushUsers(const InstructionT &I);
  void pushUsers(ConstValueRefT V);

  bool usesValueFromCycle(const InstructionT &I, const CycleT &DefCycle) const;

  /// \brief Whether \p Def is divergent when read in \p ObservingBlock.
  bool isTemporalDivergent(const BlockT &ObservingBlock,
                           const InstructionT &Def) const;
};
475 
// Out-of-line deleter for the analysis implementation. Defining the delete
// here keeps it out of the generic header, so clients that merely use
// UniformityInfo do not need the complete implementation type (see the
// note in the file header).
template <typename ImplT>
void GenericUniformityAnalysisImplDeleter<ImplT>::operator()(ImplT *Impl) {
  delete Impl;
}
480 
/// Compute divergence starting with a divergent branch.
template <typename ContextT> class DivergencePropagator {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ModifiedPO = ModifiedPostOrder<ContextT>;
  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  const ModifiedPO &CyclePOT;
  const DominatorTreeT &DT;
  const CycleInfoT &CI;
  const BlockT &DivTermBlock;
  const ContextT &Context;

  // Track blocks that receive a new label. Every time we relabel a
  // cycle header, we make another pass over the modified post-order in
  // order to propagate the header label. The bit vector also allows
  // us to skip labels that have not changed.
  SparseBitVector<> FreshLabels;

  // Divergent join and cycle exit descriptor.
  std::unique_ptr<DivergenceDescriptorT> DivDesc;
  BlockLabelMapT &BlockLabels;

  DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
                       const CycleInfoT &CI, const BlockT &DivTermBlock)
      : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
        Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
        BlockLabels(DivDesc->BlockLabels) {}

  /// Debugging aid: dump the current block labels in reverse
  /// modified post-order.
  void printDefs(raw_ostream &Out) {
    Out << "Propagator::BlockLabels {\n";
    for (int BlockIdx = (int)CyclePOT.size() - 1; BlockIdx >= 0; --BlockIdx) {
      const auto *Block = CyclePOT[BlockIdx];
      const auto *Label = BlockLabels[Block];
      Out << Context.print(Block) << "(" << BlockIdx << ") : ";
      if (!Label) {
        Out << "<null>\n";
      } else {
        Out << Context.print(Label) << "\n";
      }
    }
    Out << "}\n";
  }

  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
  // causes a divergent join.
  bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
    const auto *OldLabel = BlockLabels[&SuccBlock];

    LLVM_DEBUG(dbgs() << "labeling " << Context.print(&SuccBlock) << ":\n"
                      << "\tpushed label: " << Context.print(&PushedLabel)
                      << "\n"
                      << "\told label: " << Context.print(OldLabel) << "\n");

    // Early exit if there is no change in the label.
    if (OldLabel == &PushedLabel)
      return false;

    if (OldLabel != &SuccBlock) {
      auto SuccIdx = CyclePOT.getIndex(&SuccBlock);
      // Assigning a new label, mark this in FreshLabels.
      LLVM_DEBUG(dbgs() << "\tfresh label: " << SuccIdx << "\n");
      FreshLabels.set(SuccIdx);
    }

    // This is not a join if the succ was previously unlabeled.
    if (!OldLabel) {
      LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&PushedLabel)
                        << "\n");
      BlockLabels[&SuccBlock] = &PushedLabel;
      return false;
    }

    // This is a new join. Label the join block as itself, and not as
    // the pushed label.
    LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&SuccBlock) << "\n");
    BlockLabels[&SuccBlock] = &SuccBlock;

    return true;
  }

  // Visiting a virtual cycle exit edge from the cycle header --> temporal
  // divergence on join.
  bool visitCycleExitEdge(const BlockT &ExitBlock, const BlockT &Label) {
    if (!computeJoin(ExitBlock, Label))
      return false;

    // Identified a divergent cycle exit.
    DivDesc->CycleDivBlocks.insert(&ExitBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(&ExitBlock)
                      << "\n");
    return true;
  }

  // Process \p SuccBlock with reaching definition \p Label.
  bool visitEdge(const BlockT &SuccBlock, const BlockT &Label) {
    if (!computeJoin(SuccBlock, Label))
      return false;

    // Divergent, disjoint paths join.
    DivDesc->JoinDivBlocks.insert(&SuccBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent join: " << Context.print(&SuccBlock)
                      << "\n");
    return true;
  }

  /// Run the label propagation to a fixed point and return the filled-in
  /// divergence descriptor (join blocks and divergent cycle exits).
  std::unique_ptr<DivergenceDescriptorT> computeJoinPoints() {
    assert(DivDesc);

    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
                      << Context.print(&DivTermBlock) << "\n");

    // Early stopping criterion: blocks numbered below the floor cannot
    // receive a new label anymore.
    int FloorIdx = CyclePOT.size() - 1;
    const BlockT *FloorLabel = nullptr;
    int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);

    // Bootstrap with branch targets: each successor starts as its own label.
    auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
    for (const auto *SuccBlock : successors(&DivTermBlock)) {
      if (DivTermCycle && !DivTermCycle->contains(SuccBlock)) {
        // If DivTerm exits the cycle immediately, computeJoin() might
        // not reach SuccBlock with a different label. We need to
        // check for this exit now.
        DivDesc->CycleDivBlocks.insert(SuccBlock);
        LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
                          << Context.print(SuccBlock) << "\n");
      }
      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
      visitEdge(*SuccBlock, *SuccBlock);
      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
    }

    // Propagate labels in reverse modified post-order until no block at or
    // above the floor holds a fresh label.
    while (true) {
      auto BlockIdx = FreshLabels.find_last();
      if (BlockIdx == -1 || BlockIdx < FloorIdx)
        break;

      LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));

      FreshLabels.reset(BlockIdx);
      if (BlockIdx == DivTermIdx) {
        LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
        continue;
      }

      const auto *Block = CyclePOT[BlockIdx];
      LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
                        << BlockIdx << "\n");

      const auto *Label = BlockLabels[Block];
      assert(Label);

      bool CausedJoin = false;
      int LoweredFloorIdx = FloorIdx;

      // If the current block is the header of a reducible cycle that
      // contains the divergent branch, then the label should be
      // propagated to the cycle exits. Such a header is the "last
      // possible join" of any disjoint paths within this cycle. This
      // prevents detection of spurious joins at the entries of any
      // irreducible child cycles.
      //
      // This conclusion about the header is true for any choice of DFS:
      //
      //   If some DFS has a reducible cycle C with header H, then for
      //   any other DFS, H is the header of a cycle C' that is a
      //   superset of C. For a divergent branch inside the subgraph
      //   C, any join node inside C is either H, or some node
      //   encountered without passing through H.
      //
      auto getReducibleParent = [&](const BlockT *Block) -> const CycleT * {
        if (!CyclePOT.isReducibleCycleHeader(Block))
          return nullptr;
        const auto *BlockCycle = CI.getCycle(Block);
        if (BlockCycle->contains(&DivTermBlock))
          return BlockCycle;
        return nullptr;
      };

      if (const auto *BlockCycle = getReducibleParent(Block)) {
        SmallVector<BlockT *, 4> BlockCycleExits;
        BlockCycle->getExitBlocks(BlockCycleExits);
        for (auto *BlockCycleExit : BlockCycleExits) {
          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
        }
      } else {
        for (const auto *SuccBlock : successors(Block)) {
          CausedJoin |= visitEdge(*SuccBlock, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
        }
      }

      // Floor update
      if (CausedJoin) {
        // 1. Different labels pushed to successors
        FloorIdx = LoweredFloorIdx;
      } else if (FloorLabel != Label) {
        // 2. No join caused BUT we pushed a label that is different than the
        // last pushed label
        FloorIdx = LoweredFloorIdx;
        FloorLabel = Label;
      }
    }

    LLVM_DEBUG(dbgs() << "Final labeling:\n"; printDefs(dbgs()));

    // Check every cycle containing DivTermBlock for exit divergence.
    // A cycle has exit divergence if the label of an exit block does
    // not match the label of its header.
    for (const auto *Cycle = CI.getCycle(&DivTermBlock); Cycle;
         Cycle = Cycle->getParentCycle()) {
      if (Cycle->isReducible()) {
        // The exit divergence of a reducible cycle is recorded while
        // propagating labels.
        continue;
      }
      SmallVector<BlockT *> Exits;
      Cycle->getExitBlocks(Exits);
      auto *Header = Cycle->getHeader();
      auto *HeaderLabel = BlockLabels[Header];
      for (const auto *Exit : Exits) {
        if (BlockLabels[Exit] != HeaderLabel) {
          // Identified a divergent cycle exit.
          DivDesc->CycleDivBlocks.insert(Exit);
          LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(Exit)
                            << "\n");
        }
      }
    }

    return std::move(DivDesc);
  }
};
728 
// Definition of the shared empty descriptor returned by getJoinBlocks()
// for branches that cannot diverge; a static data member of a class
// template must be defined out-of-line like this.
template <typename ContextT>
typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
    llvm::GenericSyncDependenceAnalysis<ContextT>::EmptyDivergenceDesc;
732 
// Construct the analysis and eagerly compute the cycle-aware modified
// post-order; it is reused by every getJoinBlocks() query.
template <typename ContextT>
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
    const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
    : CyclePO(Context), DT(DT), CI(CI) {
  CyclePO.compute(CI);
}
739 
template <typename ContextT>
auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
    const BlockT *DivTermBlock) -> const DivergenceDescriptor & {
  // Trivial case: a terminator with at most one successor cannot create
  // disjoint paths.
  if (succ_size(DivTermBlock) <= 1) {
    return EmptyDivergenceDesc;
  }

  // Already available in cache?
  auto ItCached = CachedControlDivDescs.find(DivTermBlock);
  if (ItCached != CachedControlDivDescs.end())
    return *ItCached->second;

  // Compute all join points for this branch.
  DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
  auto DivDesc = Propagator.computeJoinPoints();

  // Debug-only pretty printer for a block set; unused in release builds.
  auto printBlockSet = [&](ConstBlockSet &Blocks) {
    return Printable([&](raw_ostream &Out) {
      Out << "[";
      ListSeparator LS;
      for (const auto *BB : Blocks) {
        Out << LS << CI.getSSAContext().print(BB);
      }
      Out << "]\n";
    });
  };

  LLVM_DEBUG(
      dbgs() << "\nResult (" << CI.getSSAContext().print(DivTermBlock)
             << "):\n  JoinDivBlocks: " << printBlockSet(DivDesc->JoinDivBlocks)
             << "  CycleDivBlocks: " << printBlockSet(DivDesc->CycleDivBlocks)
             << "\n");
  (void)printBlockSet;

  // Cache the result; the insertion must succeed since we missed above.
  auto ItInserted =
      CachedControlDivDescs.try_emplace(DivTermBlock, std::move(DivDesc));
  assert(ItInserted.second);
  return *ItInserted.first->second;
}
780 
781 template <typename ContextT>
markDivergent(const InstructionT & I)782 void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
783     const InstructionT &I) {
784   if (isAlwaysUniform(I))
785     return;
786   bool Marked = false;
787   if (I.isTerminator()) {
788     Marked = DivergentTermBlocks.insert(I.getParent()).second;
789     if (Marked) {
790       LLVM_DEBUG(dbgs() << "marked divergent term block: "
791                         << Context.print(I.getParent()) << "\n");
792     }
793   } else {
794     Marked = markDefsDivergent(I);
795   }
796 
797   if (Marked)
798     Worklist.push_back(&I);
799 }
800 
801 template <typename ContextT>
markDivergent(ConstValueRefT Val)802 bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
803     ConstValueRefT Val) {
804   if (DivergentValues.insert(Val).second) {
805     LLVM_DEBUG(dbgs() << "marked divergent: " << Context.print(Val) << "\n");
806     return true;
807   }
808   return false;
809 }
810 
// Declare \p Instr to be always uniform; markDivergent() consults this set
// (via isAlwaysUniform) and refuses to taint overridden instructions.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
    const InstructionT &Instr) {
  UniformOverrides.insert(&Instr);
}
816 
817 // Mark as divergent all external uses of values defined in \p DefCycle.
818 //
819 // A value V defined by a block B inside \p DefCycle may be used outside the
820 // cycle only if the use is a PHI in some exit block, or B dominates some exit
821 // block. Thus, we check uses as follows:
822 //
823 // - Check all PHIs in all exit blocks for inputs defined inside \p DefCycle.
824 // - For every block B inside \p DefCycle that dominates at least one exit
825 //   block, check all uses outside \p DefCycle.
826 //
827 // FIXME: This function does not distinguish between divergent and uniform
828 // exits. For each divergent exit, only the values that are live at that exit
829 // need to be propagated as divergent at their use outside the cycle.
830 template <typename ContextT>
analyzeCycleExitDivergence(const CycleT & DefCycle)831 void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
832     const CycleT &DefCycle) {
833   SmallVector<BlockT *> Exits;
834   DefCycle.getExitBlocks(Exits);
835   for (auto *Exit : Exits) {
836     for (auto &Phi : Exit->phis()) {
837       if (usesValueFromCycle(Phi, DefCycle)) {
838         markDivergent(Phi);
839       }
840     }
841   }
842 
843   for (auto *BB : DefCycle.blocks()) {
844     if (!llvm::any_of(Exits,
845                      [&](BlockT *Exit) { return DT.dominates(BB, Exit); }))
846       continue;
847     for (auto &II : *BB) {
848       propagateTemporalDivergence(II, DefCycle);
849     }
850   }
851 }
852 
// Propagate the effect of a divergent branch inside \p InnerDivCycle that is
// exited at \p DivExit: find the outer-most cycle left by that exit and mark
// its externally-used values divergent.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::propagateCycleExitDivergence(
    const BlockT &DivExit, const CycleT &InnerDivCycle) {
  LLVM_DEBUG(dbgs() << "\tpropCycleExitDiv " << Context.print(&DivExit)
                    << "\n");
  auto *DivCycle = &InnerDivCycle;
  auto *OuterDivCycle = DivCycle;
  auto *ExitLevelCycle = CI.getCycle(&DivExit);
  // Depth 0 means DivExit lies outside any cycle.
  const unsigned CycleExitDepth =
      ExitLevelCycle ? ExitLevelCycle->getDepth() : 0;

  // Find outer-most cycle that does not contain \p DivExit
  // (walk parents until the depth no longer exceeds the exit's depth).
  while (DivCycle && DivCycle->getDepth() > CycleExitDepth) {
    LLVM_DEBUG(dbgs() << "  Found exiting cycle: "
                      << Context.print(DivCycle->getHeader()) << "\n");
    OuterDivCycle = DivCycle;
    DivCycle = DivCycle->getParentCycle();
  }
  LLVM_DEBUG(dbgs() << "\tOuter-most exiting cycle: "
                    << Context.print(OuterDivCycle->getHeader()) << "\n");

  // Already processed this cycle's exits; nothing new to do.
  if (!DivergentExitCycles.insert(OuterDivCycle).second)
    return;

  // Exit divergence does not matter if the cycle itself is assumed to
  // be divergent.
  for (const auto *C : AssumedDivergent) {
    if (C->contains(OuterDivCycle))
      return;
  }

  analyzeCycleExitDivergence(*OuterDivCycle);
}
886 
887 template <typename ContextT>
taintAndPushAllDefs(const BlockT & BB)888 void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs(
889     const BlockT &BB) {
890   LLVM_DEBUG(dbgs() << "taintAndPushAllDefs " << Context.print(&BB) << "\n");
891   for (const auto &I : instrs(BB)) {
892     // Terminators do not produce values; they are divergent only if
893     // the condition is divergent. That is handled when the divergent
894     // condition is placed in the worklist.
895     if (I.isTerminator())
896       break;
897 
898     markDivergent(I);
899   }
900 }
901 
902 /// Mark divergent phi nodes in a join block
903 template <typename ContextT>
taintAndPushPhiNodes(const BlockT & JoinBlock)904 void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes(
905     const BlockT &JoinBlock) {
906   LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << Context.print(&JoinBlock)
907                     << "\n");
908   for (const auto &Phi : JoinBlock.phis()) {
909     // FIXME: The non-undef value is not constant per se; it just happens to be
910     // uniform and may not dominate this PHI. So assuming that the same value
911     // reaches along all incoming edges may itself be undefined behaviour. This
912     // particular interpretation of the undef value was added to
913     // DivergenceAnalysis in the following review:
914     //
915     // https://reviews.llvm.org/D19013
916     if (ContextT::isConstantOrUndefValuePhi(Phi))
917       continue;
918     markDivergent(Phi);
919   }
920 }
921 
922 /// Add \p Candidate to \p Cycles if it is not already contained in \p Cycles.
923 ///
924 /// \return true iff \p Candidate was added to \p Cycles.
925 template <typename CycleT>
insertIfNotContained(SmallVector<CycleT * > & Cycles,CycleT * Candidate)926 static bool insertIfNotContained(SmallVector<CycleT *> &Cycles,
927                                  CycleT *Candidate) {
928   if (llvm::any_of(Cycles,
929                    [Candidate](CycleT *C) { return C->contains(Candidate); }))
930     return false;
931   Cycles.push_back(Candidate);
932   return true;
933 }
934 
/// Return the outermost cycle made divergent by branch outside it.
///
/// If two paths that diverged outside an irreducible cycle join
/// inside that cycle, then that whole cycle is assumed to be
/// divergent. This does not apply if the cycle is reducible.
///
/// \p Cycle must contain \p JoinBlock; returns nullptr when the criterion
/// does not apply (branch inside the cycle, or the cycle is reducible).
template <typename CycleT, typename BlockT>
static const CycleT *getExtDivCycle(const CycleT *Cycle,
                                    const BlockT *DivTermBlock,
                                    const BlockT *JoinBlock) {
  assert(Cycle);
  assert(Cycle->contains(JoinBlock));

  // The branch is internal to the cycle; the "external branch" criterion
  // does not apply.
  if (Cycle->contains(DivTermBlock))
    return nullptr;

  // Expand to the largest enclosing cycle that still excludes DivTermBlock.
  const auto *OriginalCycle = Cycle;
  const auto *Parent = Cycle->getParentCycle();
  while (Parent && !Parent->contains(DivTermBlock)) {
    Cycle = Parent;
    Parent = Cycle->getParentCycle();
  }

  // If the original cycle is not the outermost cycle, then the outermost cycle
  // is irreducible. If the outermost cycle were reducible, then external
  // diverged paths would not reach the original inner cycle.
  (void)OriginalCycle;
  assert(Cycle == OriginalCycle || !Cycle->isReducible());

  // Reducible cycles can only be entered through their header, so the paths
  // must have joined there already; no whole-cycle divergence.
  if (Cycle->isReducible()) {
    assert(Cycle->getHeader() == JoinBlock);
    return nullptr;
  }

  LLVM_DEBUG(dbgs() << "cycle made divergent by external branch\n");
  return Cycle;
}
971 
/// Return the outermost cycle made divergent by branch inside it.
///
/// This checks the "diverged entry" criterion defined in the
/// docs/ConvergenceAnalysis.html.
///
/// Returns nullptr when the criterion does not apply: the branch properly
/// dominates the join, there is no common irreducible cycle, or the common
/// cycle's header properly dominates the join.
template <typename ContextT, typename CycleT, typename BlockT,
          typename DominatorTreeT>
static const CycleT *
getIntDivCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
               const BlockT *JoinBlock, const DominatorTreeT &DT,
               ContextT &Context) {
  LLVM_DEBUG(dbgs() << "examine join " << Context.print(JoinBlock)
                    << " for internal branch " << Context.print(DivTermBlock)
                    << "\n");
  // A dominating branch cannot create a diverged entry at the join.
  if (DT.properlyDominates(DivTermBlock, JoinBlock))
    return nullptr;

  // Find the smallest common cycle, if one exists.
  assert(Cycle && Cycle->contains(JoinBlock));
  while (Cycle && !Cycle->contains(DivTermBlock)) {
    Cycle = Cycle->getParentCycle();
  }
  // The criterion only applies to irreducible cycles.
  if (!Cycle || Cycle->isReducible())
    return nullptr;

  if (DT.properlyDominates(Cycle->getHeader(), JoinBlock))
    return nullptr;

  LLVM_DEBUG(dbgs() << "  header " << Context.print(Cycle->getHeader())
                    << " does not dominate join\n");

  // Keep expanding to enclosing cycles whose header also fails to properly
  // dominate the join; the largest such cycle is the answer.
  const auto *Parent = Cycle->getParentCycle();
  while (Parent && !DT.properlyDominates(Parent->getHeader(), JoinBlock)) {
    LLVM_DEBUG(dbgs() << "  header " << Context.print(Parent->getHeader())
                      << " does not dominate join\n");
    Cycle = Parent;
    Parent = Parent->getParentCycle();
  }

  LLVM_DEBUG(dbgs() << "  cycle made divergent by internal branch\n");
  return Cycle;
}
1013 
/// Return the outermost cycle made divergent by the branch at
/// \p DivTermBlock joining at \p JoinBlock, considering both the
/// external-branch and internal-branch (diverged entry) criteria; nullptr if
/// neither applies.
template <typename ContextT, typename CycleT, typename BlockT,
          typename DominatorTreeT>
static const CycleT *
getOutermostDivergentCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
                           const BlockT *JoinBlock, const DominatorTreeT &DT,
                           ContextT &Context) {
  if (!Cycle)
    return nullptr;

  // First try to expand Cycle to the largest that contains JoinBlock
  // but not DivTermBlock.
  const auto *Ext = getExtDivCycle(Cycle, DivTermBlock, JoinBlock);

  // Continue expanding to the largest cycle that contains both.
  const auto *Int = getIntDivCycle(Cycle, DivTermBlock, JoinBlock, DT, Context);

  // The internal criterion covers the larger (or equal) cycle when both hold.
  return Int ? Int : Ext;
}
1034 
1035 template <typename ContextT>
isTemporalDivergent(const BlockT & ObservingBlock,const InstructionT & Def)1036 bool GenericUniformityAnalysisImpl<ContextT>::isTemporalDivergent(
1037     const BlockT &ObservingBlock, const InstructionT &Def) const {
1038   const BlockT *DefBlock = Def.getParent();
1039   for (const CycleT *Cycle = CI.getCycle(DefBlock);
1040        Cycle && !Cycle->contains(&ObservingBlock);
1041        Cycle = Cycle->getParentCycle()) {
1042     if (DivergentExitCycles.contains(Cycle)) {
1043       return true;
1044     }
1045   }
1046   return false;
1047 }
1048 
// Propagate the impact of the divergent terminator \p Term: taint PHIs at
// join blocks, assume whole cycles divergent where the join criteria demand
// it, and handle divergent cycle exits.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence(
    const InstructionT &Term) {
  const auto *DivTermBlock = Term.getParent();
  DivergentTermBlocks.insert(DivTermBlock);
  LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Context.print(DivTermBlock)
                    << "\n");

  // Don't propagate divergence from unreachable blocks.
  if (!DT.isReachableFromEntry(DivTermBlock))
    return;

  const auto &DivDesc = SDA.getJoinBlocks(DivTermBlock);
  SmallVector<const CycleT *> DivCycles;

  // Iterate over all blocks now reachable by a disjoint path join
  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
    const auto *Cycle = CI.getCycle(JoinBlock);
    LLVM_DEBUG(dbgs() << "visiting join block " << Context.print(JoinBlock)
                      << "\n");
    // If the join makes a whole cycle divergent, defer tainting until all
    // such cycles are collected (and deduplicated) below; otherwise only the
    // join block's PHIs are affected.
    if (const auto *Outermost = getOutermostDivergentCycle(
            Cycle, DivTermBlock, JoinBlock, DT, Context)) {
      LLVM_DEBUG(dbgs() << "found divergent cycle\n");
      DivCycles.push_back(Outermost);
      continue;
    }
    taintAndPushPhiNodes(*JoinBlock);
  }

  // Sort by order of decreasing depth. This allows later cycles to be skipped
  // because they are already contained in earlier ones.
  llvm::sort(DivCycles, [](const CycleT *A, const CycleT *B) {
    return A->getDepth() > B->getDepth();
  });

  // Cycles that are assumed divergent due to the diverged entry
  // criterion potentially contain temporal divergence depending on
  // the DFS chosen. Conservatively, all values produced in such a
  // cycle are assumed divergent. "Cycle invariant" values may be
  // assumed uniform, but that requires further analysis.
  for (auto *C : DivCycles) {
    if (!insertIfNotContained(AssumedDivergent, C))
      continue;
    LLVM_DEBUG(dbgs() << "process divergent cycle\n");
    for (const BlockT *BB : C->blocks()) {
      taintAndPushAllDefs(*BB);
    }
  }

  // Any cycle-divergent exit blocks imply the branch sits inside a cycle.
  const auto *BranchCycle = CI.getCycle(DivTermBlock);
  assert(DivDesc.CycleDivBlocks.empty() || BranchCycle);
  for (const auto *DivExitBlock : DivDesc.CycleDivBlocks) {
    propagateCycleExitDivergence(*DivExitBlock, *BranchCycle);
  }
}
1104 
// Run the fixed-point iteration: starting from the seeded divergent values,
// propagate divergence through uses and control flow until the worklist is
// exhausted.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::compute() {
  // Initialize worklist. Iterate over a copy because pushUsers may insert
  // into DivergentValues while we walk it.
  auto DivValuesCopy = DivergentValues;
  for (const auto DivVal : DivValuesCopy) {
    assert(isDivergent(DivVal) && "Worklist invariant violated!");
    pushUsers(DivVal);
  }

  // All values on the Worklist are divergent.
  // Their users may not have been updated yet.
  while (!Worklist.empty()) {
    const InstructionT *I = Worklist.back();
    Worklist.pop_back();

    LLVM_DEBUG(dbgs() << "worklist pop: " << Context.print(I) << "\n");

    // Divergent terminators propagate through sync dependencies rather than
    // def-use chains.
    if (I->isTerminator()) {
      analyzeControlDivergence(*I);
      continue;
    }

    // propagate value divergence to users
    assert(isDivergent(*I) && "Worklist invariant violated!");
    pushUsers(*I);
  }
}
1132 
// Whether \p Instr was registered via addUniformOverride() and is therefore
// exempt from divergence marking.
template <typename ContextT>
bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
    const InstructionT &Instr) const {
  return UniformOverrides.contains(&Instr);
}
1138 
1139 template <typename ContextT>
GenericUniformityInfo(const DominatorTreeT & DT,const CycleInfoT & CI,const TargetTransformInfo * TTI)1140 GenericUniformityInfo<ContextT>::GenericUniformityInfo(
1141     const DominatorTreeT &DT, const CycleInfoT &CI,
1142     const TargetTransformInfo *TTI) {
1143   DA.reset(new ImplT{DT, CI, TTI});
1144 }
1145 
// Print the full analysis result: divergent arguments, cycles assumed
// divergent, cycles with divergent exits, and a per-block listing of
// definitions and terminators tagged DIVERGENT.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
  bool haveDivergentArgs = false;

  // Control flow instructions may be divergent even if their inputs are
  // uniform. Thus, although exceedingly rare, it is possible to have a program
  // with no divergent values but with divergent control structures.
  if (DivergentValues.empty() && DivergentTermBlocks.empty() &&
      DivergentExitCycles.empty()) {
    OS << "ALL VALUES UNIFORM\n";
    return;
  }

  // Values without a defining block are arguments; divergent instruction
  // results are printed in the per-block listing below instead.
  for (const auto &entry : DivergentValues) {
    const BlockT *parent = Context.getDefBlock(entry);
    if (!parent) {
      if (!haveDivergentArgs) {
        OS << "DIVERGENT ARGUMENTS:\n";
        haveDivergentArgs = true;
      }
      OS << "  DIVERGENT: " << Context.print(entry) << '\n';
    }
  }

  if (!AssumedDivergent.empty()) {
    // FIXME: "ASSSUMED" is misspelled; left as-is here because downstream
    // FileCheck tests may match this exact string — fix both together.
    OS << "CYCLES ASSSUMED DIVERGENT:\n";
    for (const CycleT *cycle : AssumedDivergent) {
      OS << "  " << cycle->print(Context) << '\n';
    }
  }

  if (!DivergentExitCycles.empty()) {
    OS << "CYCLES WITH DIVERGENT EXIT:\n";
    for (const CycleT *cycle : DivergentExitCycles) {
      OS << "  " << cycle->print(Context) << '\n';
    }
  }

  for (auto &block : F) {
    OS << "\nBLOCK " << Context.print(&block) << '\n';

    OS << "DEFINITIONS\n";
    SmallVector<ConstValueRefT, 16> defs;
    Context.appendBlockDefs(defs, block);
    for (auto value : defs) {
      if (isDivergent(value))
        OS << "  DIVERGENT: ";
      else
        OS << "             ";
      OS << Context.print(value) << '\n';
    }

    OS << "TERMINATORS\n";
    SmallVector<const InstructionT *, 8> terms;
    Context.appendBlockTerms(terms, block);
    // Divergence is tracked per terminator *block*, so all terminators of a
    // divergent-terminator block are tagged alike.
    bool divergentTerminators = hasDivergentTerminator(block);
    for (auto *T : terms) {
      if (divergentTerminators)
        OS << "  DIVERGENT: ";
      else
        OS << "             ";
      OS << Context.print(T) << '\n';
    }

    OS << "END BLOCK\n";
  }
}
1213 
/// Whether any divergence was detected; forwards to the implementation.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergence() const {
  return DA->hasDivergence();
}
1218 
/// The function this analysis result describes; forwards to the
/// implementation.
template <typename ContextT>
const typename ContextT::FunctionT &
GenericUniformityInfo<ContextT>::getFunction() const {
  return DA->getFunction();
}
1224 
/// Whether \p V is divergent at its definition.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergent(ConstValueRefT V) const {
  return DA->isDivergent(V);
}
1230 
/// Whether instruction \p I produces a divergent result; forwards to the
/// implementation.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergent(const InstructionT *I) const {
  return DA->isDivergent(*I);
}
1235 
/// Whether use \p U observes a divergent value (including temporal
/// divergence); forwards to the implementation.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergentUse(const UseT &U) const {
  return DA->isDivergentUse(U);
}
1240 
/// Whether block \p B ends in a divergent terminator; forwards to the
/// implementation.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergentTerminator(const BlockT &B) {
  return DA->hasDivergentTerminator(B);
}
1245 
/// \brief T helper function for printing.
/// Forwards to the implementation's print().
template <typename ContextT>
void GenericUniformityInfo<ContextT>::print(raw_ostream &out) const {
  DA->print(out);
}
1251 
// Iterative post-order DFS over the region bounded by \p Cycle (or the whole
// function when \p Cycle is null). Nested cycles are treated as single nodes:
// once all exits of a nested cycle are finalized, the nested cycle's interior
// is ordered by a recursive call to computeCyclePO.
template <typename ContextT>
void llvm::ModifiedPostOrder<ContextT>::computeStackPO(
    SmallVectorImpl<const BlockT *> &Stack, const CycleInfoT &CI,
    const CycleT *Cycle, SmallPtrSetImpl<const BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeStackPO\n");
  while (!Stack.empty()) {
    // Peek; the node is only popped once its successors are handled.
    auto *NextBB = Stack.back();
    if (Finalized.count(NextBB)) {
      Stack.pop_back();
      continue;
    }
    LLVM_DEBUG(dbgs() << "  visiting " << CI.getSSAContext().print(NextBB)
                      << "\n");
    auto *NestedCycle = CI.getCycle(NextBB);
    if (Cycle != NestedCycle && (!Cycle || Cycle->contains(NestedCycle))) {
      LLVM_DEBUG(dbgs() << "  found a cycle\n");
      // Normalize to the outermost nested cycle directly inside Cycle.
      while (NestedCycle->getParentCycle() != Cycle)
        NestedCycle = NestedCycle->getParentCycle();

      // Post-order: the nested cycle is emitted only after all of its exits
      // (that lie inside the current region) have been finalized.
      SmallVector<BlockT *, 3> NestedExits;
      NestedCycle->getExitBlocks(NestedExits);
      bool PushedNodes = false;
      for (auto *NestedExitBB : NestedExits) {
        LLVM_DEBUG(dbgs() << "  examine exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
        if (Cycle && !Cycle->contains(NestedExitBB))
          continue;
        if (Finalized.count(NestedExitBB))
          continue;
        PushedNodes = true;
        Stack.push_back(NestedExitBB);
        LLVM_DEBUG(dbgs() << "  pushed exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
      }
      if (!PushedNodes) {
        // All loop exits finalized -> finish this node
        Stack.pop_back();
        computeCyclePO(CI, NestedCycle, Finalized);
      }
      continue;
    }

    LLVM_DEBUG(dbgs() << "  no nested cycle, going into DAG\n");
    // DAG-style
    bool PushedNodes = false;
    for (auto *SuccBB : successors(NextBB)) {
      LLVM_DEBUG(dbgs() << "  examine succ: "
                        << CI.getSSAContext().print(SuccBB) << "\n");
      // Stay within the region being ordered.
      if (Cycle && !Cycle->contains(SuccBB))
        continue;
      if (Finalized.count(SuccBB))
        continue;
      PushedNodes = true;
      Stack.push_back(SuccBB);
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(SuccBB)
                        << "\n");
    }
    if (!PushedNodes) {
      // Never push nodes twice
      LLVM_DEBUG(dbgs() << "  finishing node: "
                        << CI.getSSAContext().print(NextBB) << "\n");
      Stack.pop_back();
      Finalized.insert(NextBB);
      appendBlock(*NextBB);
    }
  }
  LLVM_DEBUG(dbgs() << "exited computeStackPO\n");
}
1320 
// Order the interior of \p Cycle: the header is emitted first (recording
// whether the cycle is reducible), then the remaining blocks are ordered by
// a nested computeStackPO restricted to the cycle.
template <typename ContextT>
void ModifiedPostOrder<ContextT>::computeCyclePO(
    const CycleInfoT &CI, const CycleT *Cycle,
    SmallPtrSetImpl<const BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeCyclePO\n");
  SmallVector<const BlockT *> Stack;
  auto *CycleHeader = Cycle->getHeader();

  LLVM_DEBUG(dbgs() << "  noted header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  assert(!Finalized.count(CycleHeader));
  Finalized.insert(CycleHeader);

  // Visit the header last
  LLVM_DEBUG(dbgs() << "  finishing header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  appendBlock(*CycleHeader, Cycle->isReducible());

  // Initialize with immediate successors
  for (auto *BB : successors(CycleHeader)) {
    LLVM_DEBUG(dbgs() << "  examine succ: " << CI.getSSAContext().print(BB)
                      << "\n");
    // Only successors inside the cycle (and not the header itself) seed the
    // interior traversal.
    if (!Cycle->contains(BB))
      continue;
    if (BB == CycleHeader)
      continue;
    if (!Finalized.count(BB)) {
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(BB)
                        << "\n");
      Stack.push_back(BB);
    }
  }

  // Compute PO inside region
  computeStackPO(Stack, CI, Cycle, Finalized);

  LLVM_DEBUG(dbgs() << "exited computeCyclePO\n");
}
1359 
1360 /// \brief Generically compute the modified post order.
1361 template <typename ContextT>
compute(const CycleInfoT & CI)1362 void llvm::ModifiedPostOrder<ContextT>::compute(const CycleInfoT &CI) {
1363   SmallPtrSet<const BlockT *, 32> Finalized;
1364   SmallVector<const BlockT *> Stack;
1365   auto *F = CI.getFunction();
1366   Stack.reserve(24); // FIXME made-up number
1367   Stack.push_back(&F->front());
1368   computeStackPO(Stack, CI, nullptr, Finalized);
1369 }
1370 
1371 } // namespace llvm
1372 
1373 #undef DEBUG_TYPE
1374 
1375 #endif // LLVM_ADT_GENERICUNIFORMITYIMPL_H
1376