1 //===- GenericUniformityImpl.h -----------------------*- C++ -*------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This template implementation resides in a separate file so that it
10 // does not get injected into every .cpp file that includes the
11 // generic header.
12 //
13 // DO NOT INCLUDE THIS FILE WHEN MERELY USING UNIFORMITYINFO.
14 //
15 // This file should only be included by files that implement a
// specialization of the relevant templates. Currently these are:
17 // - UniformityAnalysis.cpp
18 //
19 // Note: The DEBUG_TYPE macro should be defined before using this
20 // file so that any use of LLVM_DEBUG is associated with the
21 // including file rather than this file.
22 //
23 //===----------------------------------------------------------------------===//
24 ///
25 /// \file
26 /// \brief Implementation of uniformity analysis.
27 ///
28 /// The algorithm is a fixed point iteration that starts with the assumption
29 /// that all control flow and all values are uniform. Starting from sources of
30 /// divergence (whose discovery must be implemented by a CFG- or even
31 /// target-specific derived class), divergence of values is propagated from
32 /// definition to uses in a straight-forward way. The main complexity lies in
33 /// the propagation of the impact of divergent control flow on the divergence of
34 /// values (sync dependencies).
35 ///
36 //===----------------------------------------------------------------------===//
37
38 #ifndef LLVM_ADT_GENERICUNIFORMITYIMPL_H
39 #define LLVM_ADT_GENERICUNIFORMITYIMPL_H
40
41 #include "llvm/ADT/GenericUniformityInfo.h"
42
43 #include "llvm/ADT/SmallPtrSet.h"
44 #include "llvm/ADT/SparseBitVector.h"
45 #include "llvm/ADT/StringExtras.h"
46 #include "llvm/Support/raw_ostream.h"
47
48 #include <set>
49
50 #define DEBUG_TYPE "uniformity"
51
52 namespace llvm {
53
/// Range-based convenience wrapper around std::unique.
///
/// Collapses runs of consecutive equal elements in \p R in place and
/// returns the iterator one past the new logical end, exactly like
/// std::unique on the underlying iterator pair.
template <typename Range> auto unique(Range &&R) {
  auto First = adl_begin(R);
  auto Last = adl_end(R);
  return std::unique(First, Last);
}
57
58 /// Construct a specially modified post-order traversal of cycles.
59 ///
/// The ModifiedPO is constructed using a virtually modified CFG as follows:
61 ///
/// 1. The successors of pre-entry nodes (predecessors of a cycle
///    entry that are outside the cycle) are replaced by the
///    successors of the successors of the header.
65 /// 2. Successors of the cycle header are replaced by the exit blocks
66 /// of the cycle.
67 ///
68 /// Effectively, we produce a depth-first numbering with the following
69 /// properties:
70 ///
71 /// 1. Nodes after a cycle are numbered earlier than the cycle header.
72 /// 2. The header is numbered earlier than the nodes in the cycle.
73 /// 3. The numbering of the nodes within the cycle forms an interval
74 /// starting with the header.
75 ///
76 /// Effectively, the virtual modification arranges the nodes in a
77 /// cycle as a DAG with the header as the sole leaf, and successors of
78 /// the header as the roots. A reverse traversal of this numbering has
79 /// the following invariant on the unmodified original CFG:
80 ///
81 /// Each node is visited after all its predecessors, except if that
82 /// predecessor is the cycle header.
83 ///
template <typename ContextT> class ModifiedPostOrder {
public:
  using BlockT = typename ContextT::BlockT;
  using FunctionT = typename ContextT::FunctionT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;
  using const_iterator = typename std::vector<BlockT *>::const_iterator;

  ModifiedPostOrder(const ContextT &C) : Context(C) {}

  bool empty() const { return m_order.empty(); }
  size_t size() const { return m_order.size(); }

  void clear() { m_order.clear(); }
  // Populate the traversal for the whole function described by \p CI.
  // Defined by the specializing translation unit (see file header).
  void compute(const CycleInfoT &CI);

  // Returns 1 if \p BB has been assigned an index, 0 otherwise.
  unsigned count(BlockT *BB) const { return POIndex.count(BB); }
  const BlockT *operator[](size_t idx) const { return m_order[idx]; }

  /// Append \p BB as the next block in the traversal and record its
  /// index. \p isReducibleCycleHeader additionally tags the block so
  /// that isReducibleCycleHeader() can answer queries about it later.
  void appendBlock(const BlockT &BB, bool isReducibleCycleHeader = false) {
    POIndex[&BB] = m_order.size();
    m_order.push_back(&BB);
    LLVM_DEBUG(dbgs() << "ModifiedPO(" << POIndex[&BB]
                      << "): " << Context.print(&BB) << "\n");
    if (isReducibleCycleHeader)
      ReducibleCycleHeaders.insert(&BB);
  }

  /// Position of \p BB in the traversal; \p BB must have been appended.
  unsigned getIndex(const BlockT *BB) const {
    assert(POIndex.count(BB));
    return POIndex.lookup(BB);
  }

  /// Whether \p BB was appended as the header of a reducible cycle.
  bool isReducibleCycleHeader(const BlockT *BB) const {
    return ReducibleCycleHeaders.contains(BB);
  }

private:
  // Blocks in traversal order; parallel index map in POIndex.
  SmallVector<const BlockT *> m_order;
  DenseMap<const BlockT *, unsigned> POIndex;
  SmallPtrSet<const BlockT *, 32> ReducibleCycleHeaders;
  const ContextT &Context;

  void computeCyclePO(const CycleInfoT &CI, const CycleT *Cycle,
                      SmallPtrSetImpl<const BlockT *> &Finalized);

  void computeStackPO(SmallVectorImpl<const BlockT *> &Stack,
                      const CycleInfoT &CI, const CycleT *Cycle,
                      SmallPtrSetImpl<const BlockT *> &Finalized);
};
136
137 template <typename> class DivergencePropagator;
138
139 /// \class GenericSyncDependenceAnalysis
140 ///
141 /// \brief Locate join blocks for disjoint paths starting at a divergent branch.
142 ///
143 /// An analysis per divergent branch that returns the set of basic
144 /// blocks whose phi nodes become divergent due to divergent control.
145 /// These are the blocks that are reachable by two disjoint paths from
146 /// the branch, or cycle exits reachable along a path that is disjoint
147 /// from a path to the cycle latch.
148
// ---- Implementation notes (ordinary comments, not doxygen docs) ----
150 //
151 // Originally implemented in SyncDependenceAnalysis.cpp for DivergenceAnalysis.
152 //
153 // The SyncDependenceAnalysis is used in the UniformityAnalysis to model
154 // control-induced divergence in phi nodes.
155 //
156 // -- Reference --
157 // The algorithm is an extension of Section 5 of
158 //
159 // An abstract interpretation for SPMD divergence
160 // on reducible control flow graphs.
161 // Julian Rosemann, Simon Moll and Sebastian Hack
162 // POPL '21
163 //
164 //
165 // -- Sync dependence --
166 // Sync dependence characterizes the control flow aspect of the
167 // propagation of branch divergence. For example,
168 //
169 // %cond = icmp slt i32 %tid, 10
170 // br i1 %cond, label %then, label %else
171 // then:
172 // br label %merge
173 // else:
174 // br label %merge
175 // merge:
176 // %a = phi i32 [ 0, %then ], [ 1, %else ]
177 //
178 // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
179 // because %tid is not on its use-def chains, %a is sync dependent on %tid
180 // because the branch "br i1 %cond" depends on %tid and affects which value %a
181 // is assigned to.
182 //
183 //
184 // -- Reduction to SSA construction --
185 // There are two disjoint paths from A to X, if a certain variant of SSA
186 // construction places a phi node in X under the following set-up scheme.
187 //
188 // This variant of SSA construction ignores incoming undef values.
189 // That is paths from the entry without a definition do not result in
190 // phi nodes.
191 //
192 // entry
193 // / \
194 // A \
195 // / \ Y
196 // B C /
197 // \ / \ /
198 // D E
199 // \ /
200 // F
201 //
202 // Assume that A contains a divergent branch. We are interested
203 // in the set of all blocks where each block is reachable from A
204 // via two disjoint paths. This would be the set {D, F} in this
205 // case.
206 // To generally reduce this query to SSA construction we introduce
207 // a virtual variable x and assign to x different values in each
208 // successor block of A.
209 //
210 // entry
211 // / \
212 // A \
213 // / \ Y
214 // x = 0 x = 1 /
215 // \ / \ /
216 // D E
217 // \ /
218 // F
219 //
220 // Our flavor of SSA construction for x will construct the following
221 //
222 // entry
223 // / \
224 // A \
225 // / \ Y
226 // x0 = 0 x1 = 1 /
227 // \ / \ /
228 // x2 = phi E
229 // \ /
230 // x3 = phi
231 //
232 // The blocks D and F contain phi nodes and are thus each reachable
// by two disjoint paths from A.
234 //
235 // -- Remarks --
236 // * In case of cycle exits we need to check for temporal divergence.
237 // To this end, we check whether the definition of x differs between the
238 // cycle exit and the cycle header (_after_ SSA construction).
239 //
240 // * In the presence of irreducible control flow, the fixed point is
241 // reached only after multiple iterations. This is because labels
242 // reaching the header of a cycle must be repropagated through the
243 // cycle. This is true even in a reducible cycle, since the labels
244 // may have been produced by a nested irreducible cycle.
245 //
246 // * Note that SyncDependenceAnalysis is not concerned with the points
// of convergence in an irreducible cycle. Its only purpose is to
248 // identify join blocks. The "diverged entry" criterion is
249 // separately applied on join blocks to determine if an entire
250 // irreducible cycle is assumed to be divergent.
251 //
252 // * Relevant related work:
253 // A simple algorithm for global data flow analysis problems.
254 // Matthew S. Hecht and Jeffrey D. Ullman.
255 // SIAM Journal on Computing, 4(4):519–532, December 1975.
256 //
template <typename ContextT> class GenericSyncDependenceAnalysis {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;
  using InstructionT = typename ContextT::InstructionT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ConstBlockSet = SmallPtrSet<const BlockT *, 4>;
  using ModifiedPO = ModifiedPostOrder<ContextT>;

  // Reaching-definition labels used by the SSA-construction reduction
  // (see the file-level commentary above), for a query rooted at X:
  // * if BlockLabels[B] == C then C is the dominating definition at
  //   block B
  // * if BlockLabels[B] == nullptr then we haven't seen B yet
  // * if BlockLabels[B] == B then:
  //   - B is a join point of disjoint paths from X, or,
  //   - B is an immediate successor of X (initial value), or,
  //   - B is X
  using BlockLabelMap = DenseMap<const BlockT *, const BlockT *>;

  /// Information discovered by the sync dependence analysis for each
  /// divergent branch.
  struct DivergenceDescriptor {
    // Join points of diverged paths.
    ConstBlockSet JoinDivBlocks;
    // Divergent cycle exits.
    ConstBlockSet CycleDivBlocks;
    // Labels assigned to blocks on diverged paths.
    BlockLabelMap BlockLabels;
  };

  using DivergencePropagatorT = DivergencePropagator<ContextT>;

  GenericSyncDependenceAnalysis(const ContextT &Context,
                                const DominatorTreeT &DT, const CycleInfoT &CI);

  /// \brief Computes divergent join points and cycle exits caused by branch
  /// divergence in \p Term.
  ///
  /// This returns a pair of sets:
  /// * The set of blocks which are reachable by disjoint paths from
  ///   \p Term.
  /// * The set also contains cycle exits if there are two disjoint paths:
  ///   one from \p Term to the cycle exit and another from \p Term to
  ///   the cycle header.
  const DivergenceDescriptor &getJoinBlocks(const BlockT *DivTermBlock);

private:
  // Shared empty result returned for branches with at most one
  // successor (see getJoinBlocks); one instance per ContextT.
  static DivergenceDescriptor EmptyDivergenceDesc;

  // Modified post-order of the function, computed once in the
  // constructor and reused by every query.
  ModifiedPO CyclePO;

  const DominatorTreeT &DT;
  const CycleInfoT &CI;

  // Memoized per-branch results of getJoinBlocks().
  DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
      CachedControlDivDescs;
};
318
319 /// \brief Analysis that identifies uniform values in a data-parallel
320 /// execution.
321 ///
322 /// This analysis propagates divergence in a data-parallel context
323 /// from sources of divergence to all users. It can be instantiated
324 /// for an IR that provides a suitable SSAContext.
325 template <typename ContextT> class GenericUniformityAnalysisImpl {
326 public:
327 using BlockT = typename ContextT::BlockT;
328 using FunctionT = typename ContextT::FunctionT;
329 using ValueRefT = typename ContextT::ValueRefT;
330 using ConstValueRefT = typename ContextT::ConstValueRefT;
331 using UseT = typename ContextT::UseT;
332 using InstructionT = typename ContextT::InstructionT;
333 using DominatorTreeT = typename ContextT::DominatorTreeT;
334
335 using CycleInfoT = GenericCycleInfo<ContextT>;
336 using CycleT = typename CycleInfoT::CycleT;
337
338 using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
339 using DivergenceDescriptorT =
340 typename SyncDependenceAnalysisT::DivergenceDescriptor;
341 using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;
342
GenericUniformityAnalysisImpl(const DominatorTreeT & DT,const CycleInfoT & CI,const TargetTransformInfo * TTI)343 GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI,
344 const TargetTransformInfo *TTI)
345 : Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI),
346 TTI(TTI), DT(DT), SDA(Context, DT, CI) {}
347
348 void initialize();
349
getFunction()350 const FunctionT &getFunction() const { return F; }
351
352 /// \brief Mark \p UniVal as a value that is always uniform.
353 void addUniformOverride(const InstructionT &Instr);
354
355 /// \brief Examine \p I for divergent outputs and add to the worklist.
356 void markDivergent(const InstructionT &I);
357
358 /// \brief Mark \p DivVal as a divergent value.
359 /// \returns Whether the tracked divergence state of \p DivVal changed.
360 bool markDivergent(ConstValueRefT DivVal);
361
362 /// \brief Mark outputs of \p Instr as divergent.
363 /// \returns Whether the tracked divergence state of any output has changed.
364 bool markDefsDivergent(const InstructionT &Instr);
365
366 /// \brief Propagate divergence to all instructions in the region.
367 /// Divergence is seeded by calls to \p markDivergent.
368 void compute();
369
370 /// \brief Whether any value was marked or analyzed to be divergent.
hasDivergence()371 bool hasDivergence() const { return !DivergentValues.empty(); }
372
373 /// \brief Whether \p Val will always return a uniform value regardless of its
374 /// operands
375 bool isAlwaysUniform(const InstructionT &Instr) const;
376
377 bool hasDivergentDefs(const InstructionT &I) const;
378
isDivergent(const InstructionT & I)379 bool isDivergent(const InstructionT &I) const {
380 if (I.isTerminator()) {
381 return DivergentTermBlocks.contains(I.getParent());
382 }
383 return hasDivergentDefs(I);
384 };
385
386 /// \brief Whether \p Val is divergent at its definition.
isDivergent(ConstValueRefT V)387 bool isDivergent(ConstValueRefT V) const { return DivergentValues.count(V); }
388
389 bool isDivergentUse(const UseT &U) const;
390
hasDivergentTerminator(const BlockT & B)391 bool hasDivergentTerminator(const BlockT &B) const {
392 return DivergentTermBlocks.contains(&B);
393 }
394
395 void print(raw_ostream &out) const;
396
397 protected:
398 /// \brief Value/block pair representing a single phi input.
399 struct PhiInput {
400 ConstValueRefT value;
401 BlockT *predBlock;
402
PhiInputPhiInput403 PhiInput(ConstValueRefT value, BlockT *predBlock)
404 : value(value), predBlock(predBlock) {}
405 };
406
407 const ContextT &Context;
408 const FunctionT &F;
409 const CycleInfoT &CI;
410 const TargetTransformInfo *TTI = nullptr;
411
412 // Detected/marked divergent values.
413 std::set<ConstValueRefT> DivergentValues;
414 SmallPtrSet<const BlockT *, 32> DivergentTermBlocks;
415
416 // Internal worklist for divergence propagation.
417 std::vector<const InstructionT *> Worklist;
418
419 /// \brief Mark \p Term as divergent and push all Instructions that become
420 /// divergent as a result on the worklist.
421 void analyzeControlDivergence(const InstructionT &Term);
422
423 private:
424 const DominatorTreeT &DT;
425
426 // Recognized cycles with divergent exits.
427 SmallPtrSet<const CycleT *, 16> DivergentExitCycles;
428
429 // Cycles assumed to be divergent.
430 //
431 // We don't use a set here because every insertion needs an explicit
432 // traversal of all existing members.
433 SmallVector<const CycleT *> AssumedDivergent;
434
435 // The SDA links divergent branches to divergent control-flow joins.
436 SyncDependenceAnalysisT SDA;
437
438 // Set of known-uniform values.
439 SmallPtrSet<const InstructionT *, 32> UniformOverrides;
440
441 /// \brief Mark all nodes in \p JoinBlock as divergent and push them on
442 /// the worklist.
443 void taintAndPushAllDefs(const BlockT &JoinBlock);
444
445 /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on
446 /// the worklist.
447 void taintAndPushPhiNodes(const BlockT &JoinBlock);
448
449 /// \brief Identify all Instructions that become divergent because \p DivExit
450 /// is a divergent cycle exit of \p DivCycle. Mark those instructions as
451 /// divergent and push them on the worklist.
452 void propagateCycleExitDivergence(const BlockT &DivExit,
453 const CycleT &DivCycle);
454
455 /// Mark as divergent all external uses of values defined in \p DefCycle.
456 void analyzeCycleExitDivergence(const CycleT &DefCycle);
457
458 /// \brief Mark as divergent all uses of \p I that are outside \p DefCycle.
459 void propagateTemporalDivergence(const InstructionT &I,
460 const CycleT &DefCycle);
461
462 /// \brief Push all users of \p Val (in the region) to the worklist.
463 void pushUsers(const InstructionT &I);
464 void pushUsers(ConstValueRefT V);
465
466 bool usesValueFromCycle(const InstructionT &I, const CycleT &DefCycle) const;
467
468 /// \brief Whether \p Def is divergent when read in \p ObservingBlock.
469 bool isTemporalDivergent(const BlockT &ObservingBlock,
470 const InstructionT &Def) const;
471 };
472
// Out-of-line deleter used by GenericUniformityInfo's owning pointer to
// the implementation — presumably defined here so ImplT can remain an
// incomplete type at the declaration site; verify against
// GenericUniformityInfo.h.
template <typename ImplT>
void GenericUniformityAnalysisImplDeleter<ImplT>::operator()(ImplT *Impl) {
  delete Impl;
}
477
/// Compute divergence starting with a divergent branch.
///
/// One instance serves a single query: it propagates reaching-definition
/// labels from the successors of DivTermBlock over the modified
/// post-order until a fixed point is reached, recording divergent joins
/// and divergent cycle exits in DivDesc.
template <typename ContextT> class DivergencePropagator {
public:
  using BlockT = typename ContextT::BlockT;
  using DominatorTreeT = typename ContextT::DominatorTreeT;
  using FunctionT = typename ContextT::FunctionT;
  using ValueRefT = typename ContextT::ValueRefT;

  using CycleInfoT = GenericCycleInfo<ContextT>;
  using CycleT = typename CycleInfoT::CycleT;

  using ModifiedPO = ModifiedPostOrder<ContextT>;
  using SyncDependenceAnalysisT = GenericSyncDependenceAnalysis<ContextT>;
  using DivergenceDescriptorT =
      typename SyncDependenceAnalysisT::DivergenceDescriptor;
  using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap;

  const ModifiedPO &CyclePOT;
  const DominatorTreeT &DT;
  const CycleInfoT &CI;
  const BlockT &DivTermBlock;
  const ContextT &Context;

  // Track blocks that receive a new label. Every time we relabel a
  // cycle header, we make another pass over the modified post-order in
  // order to propagate the header label. The bit vector also allows
  // us to skip labels that have not changed.
  SparseBitVector<> FreshLabels;

  // Divergent join and cycle exit descriptor; moved out of the
  // propagator by computeJoinPoints().
  std::unique_ptr<DivergenceDescriptorT> DivDesc;
  BlockLabelMapT &BlockLabels;

  DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
                       const CycleInfoT &CI, const BlockT &DivTermBlock)
      : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
        Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
        BlockLabels(DivDesc->BlockLabels) {}

  /// Dump the current label of every block, in reverse traversal order
  /// (debugging aid).
  void printDefs(raw_ostream &Out) {
    Out << "Propagator::BlockLabels {\n";
    for (int BlockIdx = (int)CyclePOT.size() - 1; BlockIdx >= 0; --BlockIdx) {
      const auto *Block = CyclePOT[BlockIdx];
      const auto *Label = BlockLabels[Block];
      Out << Context.print(Block) << "(" << BlockIdx << ") : ";
      if (!Label) {
        Out << "<null>\n";
      } else {
        Out << Context.print(Label) << "\n";
      }
    }
    Out << "}\n";
  }

  // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
  // causes a divergent join.
  bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
    const auto *OldLabel = BlockLabels[&SuccBlock];

    LLVM_DEBUG(dbgs() << "labeling " << Context.print(&SuccBlock) << ":\n"
                      << "\tpushed label: " << Context.print(&PushedLabel)
                      << "\n"
                      << "\told label: " << Context.print(OldLabel) << "\n");

    // Early exit if there is no change in the label.
    if (OldLabel == &PushedLabel)
      return false;

    if (OldLabel != &SuccBlock) {
      auto SuccIdx = CyclePOT.getIndex(&SuccBlock);
      // Assigning a new label, mark this in FreshLabels.
      LLVM_DEBUG(dbgs() << "\tfresh label: " << SuccIdx << "\n");
      FreshLabels.set(SuccIdx);
    }

    // This is not a join if the succ was previously unlabeled.
    if (!OldLabel) {
      LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&PushedLabel)
                        << "\n");
      BlockLabels[&SuccBlock] = &PushedLabel;
      return false;
    }

    // This is a new join. Label the join block as itself, and not as
    // the pushed label.
    LLVM_DEBUG(dbgs() << "\tnew label: " << Context.print(&SuccBlock) << "\n");
    BlockLabels[&SuccBlock] = &SuccBlock;

    return true;
  }

  // Visiting a virtual cycle exit edge from the cycle header --> temporal
  // divergence on join.
  bool visitCycleExitEdge(const BlockT &ExitBlock, const BlockT &Label) {
    if (!computeJoin(ExitBlock, Label))
      return false;

    // Identified a divergent cycle exit.
    DivDesc->CycleDivBlocks.insert(&ExitBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(&ExitBlock)
                      << "\n");
    return true;
  }

  // Process \p SuccBlock with reaching definition \p Label.
  bool visitEdge(const BlockT &SuccBlock, const BlockT &Label) {
    if (!computeJoin(SuccBlock, Label))
      return false;

    // Divergent, disjoint paths join.
    DivDesc->JoinDivBlocks.insert(&SuccBlock);
    LLVM_DEBUG(dbgs() << "\tDivergent join: " << Context.print(&SuccBlock)
                      << "\n");
    return true;
  }

  /// Run label propagation to a fixed point and return the completed
  /// descriptor. Moves DivDesc out, so call at most once per instance.
  std::unique_ptr<DivergenceDescriptorT> computeJoinPoints() {
    assert(DivDesc);

    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
                      << Context.print(&DivTermBlock) << "\n");

    // Early stopping criterion: no block below FloorIdx needs revisiting.
    int FloorIdx = CyclePOT.size() - 1;
    const BlockT *FloorLabel = nullptr;
    int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);

    // Bootstrap with branch targets: each successor starts labeled with
    // itself (the "x = 0 / x = 1" seeding in the file-level commentary).
    auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
    for (const auto *SuccBlock : successors(&DivTermBlock)) {
      if (DivTermCycle && !DivTermCycle->contains(SuccBlock)) {
        // If DivTerm exits the cycle immediately, computeJoin() might
        // not reach SuccBlock with a different label. We need to
        // check for this exit now.
        DivDesc->CycleDivBlocks.insert(SuccBlock);
        LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
                          << Context.print(SuccBlock) << "\n");
      }
      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
      visitEdge(*SuccBlock, *SuccBlock);
      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
    }

    // Propagate labels in reverse order of the traversal until no
    // fresh label remains at or above the floor index.
    while (true) {
      auto BlockIdx = FreshLabels.find_last();
      if (BlockIdx == -1 || BlockIdx < FloorIdx)
        break;

      LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));

      FreshLabels.reset(BlockIdx);
      if (BlockIdx == DivTermIdx) {
        LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
        continue;
      }

      const auto *Block = CyclePOT[BlockIdx];
      LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
                        << BlockIdx << "\n");

      const auto *Label = BlockLabels[Block];
      assert(Label);

      bool CausedJoin = false;
      int LoweredFloorIdx = FloorIdx;

      // If the current block is the header of a reducible cycle that
      // contains the divergent branch, then the label should be
      // propagated to the cycle exits. Such a header is the "last
      // possible join" of any disjoint paths within this cycle. This
      // prevents detection of spurious joins at the entries of any
      // irreducible child cycles.
      //
      // This conclusion about the header is true for any choice of DFS:
      //
      //   If some DFS has a reducible cycle C with header H, then for
      //   any other DFS, H is the header of a cycle C' that is a
      //   superset of C. For a divergent branch inside the subgraph
      //   C, any join node inside C is either H, or some node
      //   encountered without passing through H.
      //
      auto getReducibleParent = [&](const BlockT *Block) -> const CycleT * {
        if (!CyclePOT.isReducibleCycleHeader(Block))
          return nullptr;
        const auto *BlockCycle = CI.getCycle(Block);
        if (BlockCycle->contains(&DivTermBlock))
          return BlockCycle;
        return nullptr;
      };

      if (const auto *BlockCycle = getReducibleParent(Block)) {
        // Virtual edges: push the header label straight to the exits.
        SmallVector<BlockT *, 4> BlockCycleExits;
        BlockCycle->getExitBlocks(BlockCycleExits);
        for (auto *BlockCycleExit : BlockCycleExits) {
          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
        }
      } else {
        for (const auto *SuccBlock : successors(Block)) {
          CausedJoin |= visitEdge(*SuccBlock, *Label);
          LoweredFloorIdx =
              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
        }
      }

      // Floor update
      if (CausedJoin) {
        // 1. Different labels pushed to successors
        FloorIdx = LoweredFloorIdx;
      } else if (FloorLabel != Label) {
        // 2. No join caused BUT we pushed a label that is different than the
        // last pushed label
        FloorIdx = LoweredFloorIdx;
        FloorLabel = Label;
      }
    }

    LLVM_DEBUG(dbgs() << "Final labeling:\n"; printDefs(dbgs()));

    // Check every cycle containing DivTermBlock for exit divergence.
    // A cycle has exit divergence if the label of an exit block does
    // not match the label of its header.
    for (const auto *Cycle = CI.getCycle(&DivTermBlock); Cycle;
         Cycle = Cycle->getParentCycle()) {
      if (Cycle->isReducible()) {
        // The exit divergence of a reducible cycle is recorded while
        // propagating labels.
        continue;
      }
      SmallVector<BlockT *> Exits;
      Cycle->getExitBlocks(Exits);
      auto *Header = Cycle->getHeader();
      auto *HeaderLabel = BlockLabels[Header];
      for (const auto *Exit : Exits) {
        if (BlockLabels[Exit] != HeaderLabel) {
          // Identified a divergent cycle exit.
          DivDesc->CycleDivBlocks.insert(Exit);
          LLVM_DEBUG(dbgs() << "\tDivergent cycle exit: " << Context.print(Exit)
                            << "\n");
        }
      }
    }

    return std::move(DivDesc);
  }
};
725
// Definition of the shared empty descriptor that getJoinBlocks() returns
// for branches with at most one successor; one instance per ContextT.
template <typename ContextT>
typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
llvm::GenericSyncDependenceAnalysis<ContextT>::EmptyDivergenceDesc;
729
// Construct the analysis and eagerly compute the modified post-order of
// the whole function; the traversal is reused by every subsequent
// getJoinBlocks() query.
template <typename ContextT>
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
    const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
    : CyclePO(Context), DT(DT), CI(CI) {
  CyclePO.compute(CI);
}
736
737 template <typename ContextT>
738 auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
739 const BlockT *DivTermBlock) -> const DivergenceDescriptor & {
740 // trivial case
741 if (succ_size(DivTermBlock) <= 1) {
742 return EmptyDivergenceDesc;
743 }
744
745 // already available in cache?
746 auto ItCached = CachedControlDivDescs.find(DivTermBlock);
747 if (ItCached != CachedControlDivDescs.end())
748 return *ItCached->second;
749
750 // compute all join points
751 DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
752 auto DivDesc = Propagator.computeJoinPoints();
753
754 auto printBlockSet = [&](ConstBlockSet &Blocks) {
755 return Printable([&](raw_ostream &Out) {
756 Out << "[";
757 ListSeparator LS;
758 for (const auto *BB : Blocks) {
759 Out << LS << CI.getSSAContext().print(BB);
760 }
761 Out << "]\n";
762 });
763 };
764
765 LLVM_DEBUG(
766 dbgs() << "\nResult (" << CI.getSSAContext().print(DivTermBlock)
767 << "):\n JoinDivBlocks: " << printBlockSet(DivDesc->JoinDivBlocks)
768 << " CycleDivBlocks: " << printBlockSet(DivDesc->CycleDivBlocks)
769 << "\n");
770 (void)printBlockSet;
771
772 auto ItInserted =
773 CachedControlDivDescs.try_emplace(DivTermBlock, std::move(DivDesc));
774 assert(ItInserted.second);
775 return *ItInserted.first->second;
776 }
777
778 template <typename ContextT>
markDivergent(const InstructionT & I)779 void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
780 const InstructionT &I) {
781 if (isAlwaysUniform(I))
782 return;
783 bool Marked = false;
784 if (I.isTerminator()) {
785 Marked = DivergentTermBlocks.insert(I.getParent()).second;
786 if (Marked) {
787 LLVM_DEBUG(dbgs() << "marked divergent term block: "
788 << Context.print(I.getParent()) << "\n");
789 }
790 } else {
791 Marked = markDefsDivergent(I);
792 }
793
794 if (Marked)
795 Worklist.push_back(&I);
796 }
797
798 template <typename ContextT>
markDivergent(ConstValueRefT Val)799 bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
800 ConstValueRefT Val) {
801 if (DivergentValues.insert(Val).second) {
802 LLVM_DEBUG(dbgs() << "marked divergent: " << Context.print(Val) << "\n");
803 return true;
804 }
805 return false;
806 }
807
// Record \p Instr in the known-uniform override set — presumably
// consulted by isAlwaysUniform() (defined in the specializing .cpp) so
// that markDivergent() refuses to taint it; verify in the specialization.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::addUniformOverride(
    const InstructionT &Instr) {
  UniformOverrides.insert(&Instr);
}
813
814 // Mark as divergent all external uses of values defined in \p DefCycle.
815 //
816 // A value V defined by a block B inside \p DefCycle may be used outside the
817 // cycle only if the use is a PHI in some exit block, or B dominates some exit
818 // block. Thus, we check uses as follows:
819 //
820 // - Check all PHIs in all exit blocks for inputs defined inside \p DefCycle.
821 // - For every block B inside \p DefCycle that dominates at least one exit
822 // block, check all uses outside \p DefCycle.
823 //
824 // FIXME: This function does not distinguish between divergent and uniform
825 // exits. For each divergent exit, only the values that are live at that exit
826 // need to be propagated as divergent at their use outside the cycle.
827 template <typename ContextT>
analyzeCycleExitDivergence(const CycleT & DefCycle)828 void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
829 const CycleT &DefCycle) {
830 SmallVector<BlockT *> Exits;
831 DefCycle.getExitBlocks(Exits);
832 for (auto *Exit : Exits) {
833 for (auto &Phi : Exit->phis()) {
834 if (usesValueFromCycle(Phi, DefCycle)) {
835 markDivergent(Phi);
836 }
837 }
838 }
839
840 for (auto *BB : DefCycle.blocks()) {
841 if (!llvm::any_of(Exits,
842 [&](BlockT *Exit) { return DT.dominates(BB, Exit); }))
843 continue;
844 for (auto &II : *BB) {
845 propagateTemporalDivergence(II, DefCycle);
846 }
847 }
848 }
849
850 template <typename ContextT>
propagateCycleExitDivergence(const BlockT & DivExit,const CycleT & InnerDivCycle)851 void GenericUniformityAnalysisImpl<ContextT>::propagateCycleExitDivergence(
852 const BlockT &DivExit, const CycleT &InnerDivCycle) {
853 LLVM_DEBUG(dbgs() << "\tpropCycleExitDiv " << Context.print(&DivExit)
854 << "\n");
855 auto *DivCycle = &InnerDivCycle;
856 auto *OuterDivCycle = DivCycle;
857 auto *ExitLevelCycle = CI.getCycle(&DivExit);
858 const unsigned CycleExitDepth =
859 ExitLevelCycle ? ExitLevelCycle->getDepth() : 0;
860
861 // Find outer-most cycle that does not contain \p DivExit
862 while (DivCycle && DivCycle->getDepth() > CycleExitDepth) {
863 LLVM_DEBUG(dbgs() << " Found exiting cycle: "
864 << Context.print(DivCycle->getHeader()) << "\n");
865 OuterDivCycle = DivCycle;
866 DivCycle = DivCycle->getParentCycle();
867 }
868 LLVM_DEBUG(dbgs() << "\tOuter-most exiting cycle: "
869 << Context.print(OuterDivCycle->getHeader()) << "\n");
870
871 if (!DivergentExitCycles.insert(OuterDivCycle).second)
872 return;
873
874 // Exit divergence does not matter if the cycle itself is assumed to
875 // be divergent.
876 for (const auto *C : AssumedDivergent) {
877 if (C->contains(OuterDivCycle))
878 return;
879 }
880
881 analyzeCycleExitDivergence(*OuterDivCycle);
882 }
883
884 template <typename ContextT>
taintAndPushAllDefs(const BlockT & BB)885 void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs(
886 const BlockT &BB) {
887 LLVM_DEBUG(dbgs() << "taintAndPushAllDefs " << Context.print(&BB) << "\n");
888 for (const auto &I : instrs(BB)) {
889 // Terminators do not produce values; they are divergent only if
890 // the condition is divergent. That is handled when the divergent
891 // condition is placed in the worklist.
892 if (I.isTerminator())
893 break;
894
895 markDivergent(I);
896 }
897 }
898
899 /// Mark divergent phi nodes in a join block
900 template <typename ContextT>
taintAndPushPhiNodes(const BlockT & JoinBlock)901 void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes(
902 const BlockT &JoinBlock) {
903 LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << Context.print(&JoinBlock)
904 << "\n");
905 for (const auto &Phi : JoinBlock.phis()) {
906 // FIXME: The non-undef value is not constant per se; it just happens to be
907 // uniform and may not dominate this PHI. So assuming that the same value
908 // reaches along all incoming edges may itself be undefined behaviour. This
909 // particular interpretation of the undef value was added to
910 // DivergenceAnalysis in the following review:
911 //
912 // https://reviews.llvm.org/D19013
913 if (ContextT::isConstantOrUndefValuePhi(Phi))
914 continue;
915 markDivergent(Phi);
916 }
917 }
918
919 /// Add \p Candidate to \p Cycles if it is not already contained in \p Cycles.
920 ///
921 /// \return true iff \p Candidate was added to \p Cycles.
922 template <typename CycleT>
insertIfNotContained(SmallVector<CycleT * > & Cycles,CycleT * Candidate)923 static bool insertIfNotContained(SmallVector<CycleT *> &Cycles,
924 CycleT *Candidate) {
925 if (llvm::any_of(Cycles,
926 [Candidate](CycleT *C) { return C->contains(Candidate); }))
927 return false;
928 Cycles.push_back(Candidate);
929 return true;
930 }
931
932 /// Return the outermost cycle made divergent by branch outside it.
933 ///
934 /// If two paths that diverged outside an irreducible cycle join
935 /// inside that cycle, then that whole cycle is assumed to be
936 /// divergent. This does not apply if the cycle is reducible.
937 template <typename CycleT, typename BlockT>
getExtDivCycle(const CycleT * Cycle,const BlockT * DivTermBlock,const BlockT * JoinBlock)938 static const CycleT *getExtDivCycle(const CycleT *Cycle,
939 const BlockT *DivTermBlock,
940 const BlockT *JoinBlock) {
941 assert(Cycle);
942 assert(Cycle->contains(JoinBlock));
943
944 if (Cycle->contains(DivTermBlock))
945 return nullptr;
946
947 const auto *OriginalCycle = Cycle;
948 const auto *Parent = Cycle->getParentCycle();
949 while (Parent && !Parent->contains(DivTermBlock)) {
950 Cycle = Parent;
951 Parent = Cycle->getParentCycle();
952 }
953
954 // If the original cycle is not the outermost cycle, then the outermost cycle
955 // is irreducible. If the outermost cycle were reducible, then external
956 // diverged paths would not reach the original inner cycle.
957 (void)OriginalCycle;
958 assert(Cycle == OriginalCycle || !Cycle->isReducible());
959
960 if (Cycle->isReducible()) {
961 assert(Cycle->getHeader() == JoinBlock);
962 return nullptr;
963 }
964
965 LLVM_DEBUG(dbgs() << "cycle made divergent by external branch\n");
966 return Cycle;
967 }
968
969 /// Return the outermost cycle made divergent by branch inside it.
970 ///
971 /// This checks the "diverged entry" criterion defined in the
972 /// docs/ConvergenceAnalysis.html.
973 template <typename ContextT, typename CycleT, typename BlockT,
974 typename DominatorTreeT>
975 static const CycleT *
getIntDivCycle(const CycleT * Cycle,const BlockT * DivTermBlock,const BlockT * JoinBlock,const DominatorTreeT & DT,ContextT & Context)976 getIntDivCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
977 const BlockT *JoinBlock, const DominatorTreeT &DT,
978 ContextT &Context) {
979 LLVM_DEBUG(dbgs() << "examine join " << Context.print(JoinBlock)
980 << " for internal branch " << Context.print(DivTermBlock)
981 << "\n");
982 if (DT.properlyDominates(DivTermBlock, JoinBlock))
983 return nullptr;
984
985 // Find the smallest common cycle, if one exists.
986 assert(Cycle && Cycle->contains(JoinBlock));
987 while (Cycle && !Cycle->contains(DivTermBlock)) {
988 Cycle = Cycle->getParentCycle();
989 }
990 if (!Cycle || Cycle->isReducible())
991 return nullptr;
992
993 if (DT.properlyDominates(Cycle->getHeader(), JoinBlock))
994 return nullptr;
995
996 LLVM_DEBUG(dbgs() << " header " << Context.print(Cycle->getHeader())
997 << " does not dominate join\n");
998
999 const auto *Parent = Cycle->getParentCycle();
1000 while (Parent && !DT.properlyDominates(Parent->getHeader(), JoinBlock)) {
1001 LLVM_DEBUG(dbgs() << " header " << Context.print(Parent->getHeader())
1002 << " does not dominate join\n");
1003 Cycle = Parent;
1004 Parent = Parent->getParentCycle();
1005 }
1006
1007 LLVM_DEBUG(dbgs() << " cycle made divergent by internal branch\n");
1008 return Cycle;
1009 }
1010
/// Return the outermost cycle made divergent by the branch in
/// \p DivTermBlock joining at \p JoinBlock, or nullptr if there is none.
///
/// Both the "external branch" and the "diverged entry" criteria are checked;
/// the internal result is preferred when both apply.
template <typename ContextT, typename CycleT, typename BlockT,
          typename DominatorTreeT>
static const CycleT *
getOutermostDivergentCycle(const CycleT *Cycle, const BlockT *DivTermBlock,
                           const BlockT *JoinBlock, const DominatorTreeT &DT,
                           ContextT &Context) {
  if (!Cycle)
    return nullptr;

  // Cycle made divergent by a branch outside it.
  const auto *Ext = getExtDivCycle(Cycle, DivTermBlock, JoinBlock);
  // Cycle made divergent by a branch inside it (diverged entry).
  const auto *Int = getIntDivCycle(Cycle, DivTermBlock, JoinBlock, DT, Context);

  return Int ? Int : Ext;
}
1031
1032 template <typename ContextT>
isTemporalDivergent(const BlockT & ObservingBlock,const InstructionT & Def)1033 bool GenericUniformityAnalysisImpl<ContextT>::isTemporalDivergent(
1034 const BlockT &ObservingBlock, const InstructionT &Def) const {
1035 const BlockT *DefBlock = Def.getParent();
1036 for (const CycleT *Cycle = CI.getCycle(DefBlock);
1037 Cycle && !Cycle->contains(&ObservingBlock);
1038 Cycle = Cycle->getParentCycle()) {
1039 if (DivergentExitCycles.contains(Cycle)) {
1040 return true;
1041 }
1042 }
1043 return false;
1044 }
1045
/// Propagate the effects of the divergent terminator \p Term: PHIs at its
/// disjoint-path joins become divergent, whole cycles may become assumed
/// divergent, and divergent cycle exits trigger temporal-divergence analysis.
template <typename ContextT>
void GenericUniformityAnalysisImpl<ContextT>::analyzeControlDivergence(
    const InstructionT &Term) {
  const auto *DivTermBlock = Term.getParent();
  DivergentTermBlocks.insert(DivTermBlock);
  LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Context.print(DivTermBlock)
                    << "\n");

  // Don't propagate divergence from unreachable blocks.
  if (!DT.isReachableFromEntry(DivTermBlock))
    return;

  const auto &DivDesc = SDA.getJoinBlocks(DivTermBlock);
  SmallVector<const CycleT *> DivCycles;

  // Iterate over all blocks now reachable by a disjoint path join. Each join
  // either makes a whole cycle divergent (collected in DivCycles) or only
  // its PHI nodes.
  for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
    const auto *Cycle = CI.getCycle(JoinBlock);
    LLVM_DEBUG(dbgs() << "visiting join block " << Context.print(JoinBlock)
                      << "\n");
    if (const auto *Outermost = getOutermostDivergentCycle(
            Cycle, DivTermBlock, JoinBlock, DT, Context)) {
      LLVM_DEBUG(dbgs() << "found divergent cycle\n");
      DivCycles.push_back(Outermost);
      continue;
    }
    taintAndPushPhiNodes(*JoinBlock);
  }

  // Sort by order of decreasing depth. This allows later cycles to be skipped
  // because they are already contained in earlier ones.
  llvm::sort(DivCycles, [](const CycleT *A, const CycleT *B) {
    return A->getDepth() > B->getDepth();
  });

  // Cycles that are assumed divergent due to the diverged entry
  // criterion potentially contain temporal divergence depending on
  // the DFS chosen. Conservatively, all values produced in such a
  // cycle are assumed divergent. "Cycle invariant" values may be
  // assumed uniform, but that requires further analysis.
  for (auto *C : DivCycles) {
    // Skip cycles already covered by an earlier (deeper or equal) entry.
    if (!insertIfNotContained(AssumedDivergent, C))
      continue;
    LLVM_DEBUG(dbgs() << "process divergent cycle\n");
    for (const BlockT *BB : C->blocks()) {
      taintAndPushAllDefs(*BB);
    }
  }

  // Finally, handle divergent exits of the cycle containing the branch, if
  // any. CycleDivBlocks is non-empty only when the branch is inside a cycle.
  const auto *BranchCycle = CI.getCycle(DivTermBlock);
  assert(DivDesc.CycleDivBlocks.empty() || BranchCycle);
  for (const auto *DivExitBlock : DivDesc.CycleDivBlocks) {
    propagateCycleExitDivergence(*DivExitBlock, *BranchCycle);
  }
}
1101
1102 template <typename ContextT>
compute()1103 void GenericUniformityAnalysisImpl<ContextT>::compute() {
1104 // Initialize worklist.
1105 auto DivValuesCopy = DivergentValues;
1106 for (const auto DivVal : DivValuesCopy) {
1107 assert(isDivergent(DivVal) && "Worklist invariant violated!");
1108 pushUsers(DivVal);
1109 }
1110
1111 // All values on the Worklist are divergent.
1112 // Their users may not have been updated yet.
1113 while (!Worklist.empty()) {
1114 const InstructionT *I = Worklist.back();
1115 Worklist.pop_back();
1116
1117 LLVM_DEBUG(dbgs() << "worklist pop: " << Context.print(I) << "\n");
1118
1119 if (I->isTerminator()) {
1120 analyzeControlDivergence(*I);
1121 continue;
1122 }
1123
1124 // propagate value divergence to users
1125 assert(isDivergent(*I) && "Worklist invariant violated!");
1126 pushUsers(*I);
1127 }
1128 }
1129
1130 template <typename ContextT>
isAlwaysUniform(const InstructionT & Instr)1131 bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
1132 const InstructionT &Instr) const {
1133 return UniformOverrides.contains(&Instr);
1134 }
1135
/// Construct the analysis wrapper and eagerly allocate the underlying
/// implementation; the actual computation is driven separately.
template <typename ContextT>
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
    const DominatorTreeT &DT, const CycleInfoT &CI,
    const TargetTransformInfo *TTI) {
  DA.reset(new ImplT{DT, CI, TTI});
}
1142
1143 template <typename ContextT>
print(raw_ostream & OS)1144 void GenericUniformityAnalysisImpl<ContextT>::print(raw_ostream &OS) const {
1145 bool haveDivergentArgs = false;
1146
1147 // Control flow instructions may be divergent even if their inputs are
1148 // uniform. Thus, although exceedingly rare, it is possible to have a program
1149 // with no divergent values but with divergent control structures.
1150 if (DivergentValues.empty() && DivergentTermBlocks.empty() &&
1151 DivergentExitCycles.empty()) {
1152 OS << "ALL VALUES UNIFORM\n";
1153 return;
1154 }
1155
1156 for (const auto &entry : DivergentValues) {
1157 const BlockT *parent = Context.getDefBlock(entry);
1158 if (!parent) {
1159 if (!haveDivergentArgs) {
1160 OS << "DIVERGENT ARGUMENTS:\n";
1161 haveDivergentArgs = true;
1162 }
1163 OS << " DIVERGENT: " << Context.print(entry) << '\n';
1164 }
1165 }
1166
1167 if (!AssumedDivergent.empty()) {
1168 OS << "CYCLES ASSSUMED DIVERGENT:\n";
1169 for (const CycleT *cycle : AssumedDivergent) {
1170 OS << " " << cycle->print(Context) << '\n';
1171 }
1172 }
1173
1174 if (!DivergentExitCycles.empty()) {
1175 OS << "CYCLES WITH DIVERGENT EXIT:\n";
1176 for (const CycleT *cycle : DivergentExitCycles) {
1177 OS << " " << cycle->print(Context) << '\n';
1178 }
1179 }
1180
1181 for (auto &block : F) {
1182 OS << "\nBLOCK " << Context.print(&block) << '\n';
1183
1184 OS << "DEFINITIONS\n";
1185 SmallVector<ConstValueRefT, 16> defs;
1186 Context.appendBlockDefs(defs, block);
1187 for (auto value : defs) {
1188 if (isDivergent(value))
1189 OS << " DIVERGENT: ";
1190 else
1191 OS << " ";
1192 OS << Context.print(value) << '\n';
1193 }
1194
1195 OS << "TERMINATORS\n";
1196 SmallVector<const InstructionT *, 8> terms;
1197 Context.appendBlockTerms(terms, block);
1198 bool divergentTerminators = hasDivergentTerminator(block);
1199 for (auto *T : terms) {
1200 if (divergentTerminators)
1201 OS << " DIVERGENT: ";
1202 else
1203 OS << " ";
1204 OS << Context.print(T) << '\n';
1205 }
1206
1207 OS << "END BLOCK\n";
1208 }
1209 }
1210
/// Whether the analysis found any divergence at all; forwards to the
/// implementation.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergence() const {
  return DA->hasDivergence();
}
1215
/// The function this uniformity info was computed for; forwards to the
/// implementation.
template <typename ContextT>
const typename ContextT::FunctionT &
GenericUniformityInfo<ContextT>::getFunction() const {
  return DA->getFunction();
}
1221
/// Whether \p V is divergent at its definition.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergent(ConstValueRefT V) const {
  return DA->isDivergent(V);
}
1227
/// Whether instruction \p I produces a divergent value; forwards to the
/// implementation.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergent(const InstructionT *I) const {
  return DA->isDivergent(*I);
}
1232
/// Whether the particular use \p U observes a divergent value (covers both
/// value divergence and temporal divergence); forwards to the implementation.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::isDivergentUse(const UseT &U) const {
  return DA->isDivergentUse(U);
}
1237
/// Whether block \p B ends in a divergent terminator; forwards to the
/// implementation.
template <typename ContextT>
bool GenericUniformityInfo<ContextT>::hasDivergentTerminator(const BlockT &B) {
  return DA->hasDivergentTerminator(B);
}
1242
/// \brief Print the analysis result; forwards to the implementation.
template <typename ContextT>
void GenericUniformityInfo<ContextT>::print(raw_ostream &out) const {
  DA->print(out);
}
1248
/// Iterative DFS that appends blocks in (modified) post order, restricted to
/// the blocks of \p Cycle (or the whole function if \p Cycle is null).
/// \p Stack holds the blocks still to visit; \p Finalized the blocks already
/// emitted. When the DFS reaches a cycle nested directly inside \p Cycle, the
/// whole nested cycle is emitted as a unit via computeCyclePO, but only after
/// all of its exits have been finalized.
template <typename ContextT>
void llvm::ModifiedPostOrder<ContextT>::computeStackPO(
    SmallVectorImpl<const BlockT *> &Stack, const CycleInfoT &CI,
    const CycleT *Cycle, SmallPtrSetImpl<const BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeStackPO\n");
  while (!Stack.empty()) {
    // Peek; the node is only popped once it is finalized or delegated.
    auto *NextBB = Stack.back();
    if (Finalized.count(NextBB)) {
      Stack.pop_back();
      continue;
    }
    LLVM_DEBUG(dbgs() << "  visiting " << CI.getSSAContext().print(NextBB)
                      << "\n");
    auto *NestedCycle = CI.getCycle(NextBB);
    if (Cycle != NestedCycle && (!Cycle || Cycle->contains(NestedCycle))) {
      // NextBB lies in a cycle strictly nested inside the current region.
      LLVM_DEBUG(dbgs() << "  found a cycle\n");
      // Normalize to the outermost nested cycle directly below Cycle.
      while (NestedCycle->getParentCycle() != Cycle)
        NestedCycle = NestedCycle->getParentCycle();

      SmallVector<BlockT *, 3> NestedExits;
      NestedCycle->getExitBlocks(NestedExits);
      bool PushedNodes = false;
      for (auto *NestedExitBB : NestedExits) {
        LLVM_DEBUG(dbgs() << "  examine exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
        // Exits outside the current region are not part of this traversal.
        if (Cycle && !Cycle->contains(NestedExitBB))
          continue;
        if (Finalized.count(NestedExitBB))
          continue;
        PushedNodes = true;
        Stack.push_back(NestedExitBB);
        LLVM_DEBUG(dbgs() << "  pushed exit: "
                          << CI.getSSAContext().print(NestedExitBB) << "\n");
      }
      if (!PushedNodes) {
        // All loop exits finalized -> finish this node
        Stack.pop_back();
        computeCyclePO(CI, NestedCycle, Finalized);
      }
      continue;
    }

    LLVM_DEBUG(dbgs() << "  no nested cycle, going into DAG\n");
    // DAG-style: push unfinalized in-region successors; finalize NextBB once
    // none remain.
    bool PushedNodes = false;
    for (auto *SuccBB : successors(NextBB)) {
      LLVM_DEBUG(dbgs() << "  examine succ: "
                        << CI.getSSAContext().print(SuccBB) << "\n");
      if (Cycle && !Cycle->contains(SuccBB))
        continue;
      if (Finalized.count(SuccBB))
        continue;
      PushedNodes = true;
      Stack.push_back(SuccBB);
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(SuccBB)
                        << "\n");
    }
    if (!PushedNodes) {
      // Never push nodes twice
      LLVM_DEBUG(dbgs() << "  finishing node: "
                        << CI.getSSAContext().print(NextBB) << "\n");
      Stack.pop_back();
      Finalized.insert(NextBB);
      appendBlock(*NextBB);
    }
  }
  LLVM_DEBUG(dbgs() << "exited computeStackPO\n");
}
1317
/// Emit the blocks of \p Cycle as one contiguous unit of the modified post
/// order. The header is finalized and appended first (it appears last in the
/// reversed order); the cycle interior is then traversed with computeStackPO
/// starting from the header's in-cycle successors.
template <typename ContextT>
void ModifiedPostOrder<ContextT>::computeCyclePO(
    const CycleInfoT &CI, const CycleT *Cycle,
    SmallPtrSetImpl<const BlockT *> &Finalized) {
  LLVM_DEBUG(dbgs() << "inside computeCyclePO\n");
  SmallVector<const BlockT *> Stack;
  auto *CycleHeader = Cycle->getHeader();

  LLVM_DEBUG(dbgs() << "  noted header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  // Finalize the header up front so the traversal below never re-enters it.
  assert(!Finalized.count(CycleHeader));
  Finalized.insert(CycleHeader);

  // Visit the header last
  LLVM_DEBUG(dbgs() << "  finishing header: "
                    << CI.getSSAContext().print(CycleHeader) << "\n");
  appendBlock(*CycleHeader, Cycle->isReducible());

  // Initialize with immediate successors
  for (auto *BB : successors(CycleHeader)) {
    LLVM_DEBUG(dbgs() << "  examine succ: " << CI.getSSAContext().print(BB)
                      << "\n");
    // Only successors inside the cycle participate; self-loops back to the
    // header are already handled above.
    if (!Cycle->contains(BB))
      continue;
    if (BB == CycleHeader)
      continue;
    if (!Finalized.count(BB)) {
      LLVM_DEBUG(dbgs() << "  pushed succ: " << CI.getSSAContext().print(BB)
                        << "\n");
      Stack.push_back(BB);
    }
  }

  // Compute PO inside region
  computeStackPO(Stack, CI, Cycle, Finalized);

  LLVM_DEBUG(dbgs() << "exited computeCyclePO\n");
}
1356
1357 /// \brief Generically compute the modified post order.
1358 template <typename ContextT>
compute(const CycleInfoT & CI)1359 void llvm::ModifiedPostOrder<ContextT>::compute(const CycleInfoT &CI) {
1360 SmallPtrSet<const BlockT *, 32> Finalized;
1361 SmallVector<const BlockT *> Stack;
1362 auto *F = CI.getFunction();
1363 Stack.reserve(24); // FIXME made-up number
1364 Stack.push_back(&F->front());
1365 computeStackPO(Stack, CI, nullptr, Finalized);
1366 }
1367
1368 } // namespace llvm
1369
1370 #undef DEBUG_TYPE
1371
1372 #endif // LLVM_ADT_GENERICUNIFORMITYIMPL_H
1373