1 //===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements Analysis functions specific to slicing in Function.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/SetVector.h"
21
22 ///
23 /// Implements Analysis functions specific to slicing in Function.
24 ///
25
26 using namespace mlir;
27
28 using llvm::SetVector;
29
getForwardSliceImpl(Operation * op,SetVector<Operation * > * forwardSlice,TransitiveFilter filter)30 static void getForwardSliceImpl(Operation *op,
31 SetVector<Operation *> *forwardSlice,
32 TransitiveFilter filter) {
33 if (!op) {
34 return;
35 }
36
37 // Evaluate whether we should keep this use.
38 // This is useful in particular to implement scoping; i.e. return the
39 // transitive forwardSlice in the current scope.
40 if (!filter(op)) {
41 return;
42 }
43
44 if (auto forOp = dyn_cast<AffineForOp>(op)) {
45 for (Operation *userOp : forOp.getInductionVar().getUsers())
46 if (forwardSlice->count(userOp) == 0)
47 getForwardSliceImpl(userOp, forwardSlice, filter);
48 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
49 for (Operation *userOp : forOp.getInductionVar().getUsers())
50 if (forwardSlice->count(userOp) == 0)
51 getForwardSliceImpl(userOp, forwardSlice, filter);
52 for (Value result : forOp.getResults())
53 for (Operation *userOp : result.getUsers())
54 if (forwardSlice->count(userOp) == 0)
55 getForwardSliceImpl(userOp, forwardSlice, filter);
56 } else {
57 assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
58 for (Value result : op->getResults()) {
59 for (Operation *userOp : result.getUsers())
60 if (forwardSlice->count(userOp) == 0)
61 getForwardSliceImpl(userOp, forwardSlice, filter);
62 }
63 }
64
65 forwardSlice->insert(op);
66 }
67
getForwardSlice(Operation * op,SetVector<Operation * > * forwardSlice,TransitiveFilter filter)68 void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
69 TransitiveFilter filter) {
70 getForwardSliceImpl(op, forwardSlice, filter);
71 // Don't insert the top level operation, we just queried on it and don't
72 // want it in the results.
73 forwardSlice->remove(op);
74
75 // Reverse to get back the actual topological order.
76 // std::reverse does not work out of the box on SetVector and I want an
77 // in-place swap based thing (the real std::reverse, not the LLVM adapter).
78 std::vector<Operation *> v(forwardSlice->takeVector());
79 forwardSlice->insert(v.rbegin(), v.rend());
80 }
81
getBackwardSliceImpl(Operation * op,SetVector<Operation * > * backwardSlice,TransitiveFilter filter)82 static void getBackwardSliceImpl(Operation *op,
83 SetVector<Operation *> *backwardSlice,
84 TransitiveFilter filter) {
85 if (!op)
86 return;
87
88 assert((op->getNumRegions() == 0 ||
89 isa<AffineForOp, scf::ForOp, linalg::LinalgOp>(op)) &&
90 "unexpected generic op with regions");
91
92 // Evaluate whether we should keep this def.
93 // This is useful in particular to implement scoping; i.e. return the
94 // transitive forwardSlice in the current scope.
95 if (!filter(op)) {
96 return;
97 }
98
99 for (auto en : llvm::enumerate(op->getOperands())) {
100 auto operand = en.value();
101 if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
102 if (auto affIv = getForInductionVarOwner(operand)) {
103 auto *affOp = affIv.getOperation();
104 if (backwardSlice->count(affOp) == 0)
105 getBackwardSliceImpl(affOp, backwardSlice, filter);
106 } else if (auto loopIv = scf::getForInductionVarOwner(operand)) {
107 auto *loopOp = loopIv.getOperation();
108 if (backwardSlice->count(loopOp) == 0)
109 getBackwardSliceImpl(loopOp, backwardSlice, filter);
110 } else if (blockArg.getOwner() !=
111 &op->getParentOfType<FuncOp>().getBody().front()) {
112 op->emitError("unsupported CF for operand ") << en.index();
113 llvm_unreachable("Unsupported control flow");
114 }
115 continue;
116 }
117 auto *op = operand.getDefiningOp();
118 if (backwardSlice->count(op) == 0) {
119 getBackwardSliceImpl(op, backwardSlice, filter);
120 }
121 }
122
123 backwardSlice->insert(op);
124 }
125
getBackwardSlice(Operation * op,SetVector<Operation * > * backwardSlice,TransitiveFilter filter)126 void mlir::getBackwardSlice(Operation *op,
127 SetVector<Operation *> *backwardSlice,
128 TransitiveFilter filter) {
129 getBackwardSliceImpl(op, backwardSlice, filter);
130
131 // Don't insert the top level operation, we just queried on it and don't
132 // want it in the results.
133 backwardSlice->remove(op);
134 }
135
getSlice(Operation * op,TransitiveFilter backwardFilter,TransitiveFilter forwardFilter)136 SetVector<Operation *> mlir::getSlice(Operation *op,
137 TransitiveFilter backwardFilter,
138 TransitiveFilter forwardFilter) {
139 SetVector<Operation *> slice;
140 slice.insert(op);
141
142 unsigned currentIndex = 0;
143 SetVector<Operation *> backwardSlice;
144 SetVector<Operation *> forwardSlice;
145 while (currentIndex != slice.size()) {
146 auto *currentOp = (slice)[currentIndex];
147 // Compute and insert the backwardSlice starting from currentOp.
148 backwardSlice.clear();
149 getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
150 slice.insert(backwardSlice.begin(), backwardSlice.end());
151
152 // Compute and insert the forwardSlice starting from currentOp.
153 forwardSlice.clear();
154 getForwardSlice(currentOp, &forwardSlice, forwardFilter);
155 slice.insert(forwardSlice.begin(), forwardSlice.end());
156 ++currentIndex;
157 }
158 return topologicalSort(slice);
159 }
160
161 namespace {
162 /// DFS post-order implementation that maintains a global count to work across
163 /// multiple invocations, to help implement topological sort on multi-root DAGs.
164 /// We traverse all operations but only record the ones that appear in
165 /// `toSort` for the final result.
166 struct DFSState {
DFSState__anon167741270111::DFSState167 DFSState(const SetVector<Operation *> &set)
168 : toSort(set), topologicalCounts(), seen() {}
169 const SetVector<Operation *> &toSort;
170 SmallVector<Operation *, 16> topologicalCounts;
171 DenseSet<Operation *> seen;
172 };
173 } // namespace
174
DFSPostorder(Operation * current,DFSState * state)175 static void DFSPostorder(Operation *current, DFSState *state) {
176 for (Value result : current->getResults()) {
177 for (Operation *op : result.getUsers())
178 DFSPostorder(op, state);
179 }
180 bool inserted;
181 using IterTy = decltype(state->seen.begin());
182 IterTy iter;
183 std::tie(iter, inserted) = state->seen.insert(current);
184 if (inserted) {
185 if (state->toSort.count(current) > 0) {
186 state->topologicalCounts.push_back(current);
187 }
188 }
189 }
190
191 SetVector<Operation *>
topologicalSort(const SetVector<Operation * > & toSort)192 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
193 if (toSort.empty()) {
194 return toSort;
195 }
196
197 // Run from each root with global count and `seen` set.
198 DFSState state(toSort);
199 for (auto *s : toSort) {
200 assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
201 DFSPostorder(s, &state);
202 }
203
204 // Reorder and return.
205 SetVector<Operation *> res;
206 for (auto it = state.topologicalCounts.rbegin(),
207 eit = state.topologicalCounts.rend();
208 it != eit; ++it) {
209 res.insert(*it);
210 }
211 return res;
212 }
213