1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/utils/cycle_detector.h"
17
18 #include <algorithm>
19
20 #include "llvm/ADT/DenseSet.h"
21 #include "llvm/ADT/SmallVector.h"
22
23 namespace mlir {
24
25 namespace {
26
27 using NodeSet = llvm::DenseSet<int32_t>;
28 using OrderedNodeSet = OrderedSet<int32_t>;
29
30 template <typename T>
31 struct VecStruct {
32 typedef llvm::SmallVector<T, 4> type;
33 };
34 template <typename T>
35 using Vec = typename VecStruct<T>::type;
36
37 struct Node {
38 // rank number assigned by Pearce-Kelly algorithm
39 int32_t rank;
40 // Temporary marker used by depth-first-search
41 bool visited;
42 // User-supplied data
43 void* data;
44 // List of immediate predecessor nodes in graph
45 OrderedNodeSet in;
46 // List of immediate successor nodes in graph
47 OrderedNodeSet out;
48 };
49
50 } // namespace
51
52 struct GraphCycles::Rep {
53 Vec<Node*> nodes;
54 // Indices for unused entries in nodes
55 Vec<int32_t> free_nodes;
56
57 // Temporary state.
58 // Results of forward DFS
59 Vec<int32_t> deltaf;
60 // Results of backward DFS
61 Vec<int32_t> deltab;
62 // All nodes to reprocess
63 Vec<int32_t> list;
64 // Rank values to assign to list entries
65 Vec<int32_t> merged;
66 // Emulates recursion stack when doing depth first search
67 Vec<int32_t> stack;
68 };
69
GraphCycles(int32_t num_nodes)70 GraphCycles::GraphCycles(int32_t num_nodes) : rep_(new Rep) {
71 rep_->nodes.reserve(num_nodes);
72 for (int32_t i = 0; i < num_nodes; ++i) {
73 Node* n = new Node;
74 n->visited = false;
75 n->data = nullptr;
76 n->rank = rep_->nodes.size();
77 rep_->nodes.push_back(n);
78 }
79 }
80
~GraphCycles()81 GraphCycles::~GraphCycles() {
82 for (Vec<Node*>::size_type i = 0, e = rep_->nodes.size(); i < e; ++i) {
83 delete rep_->nodes[i];
84 }
85 delete rep_;
86 }
87
HasEdge(int32_t x,int32_t y) const88 bool GraphCycles::HasEdge(int32_t x, int32_t y) const {
89 return rep_->nodes[x]->out.Contains(y);
90 }
91
RemoveEdge(int32_t x,int32_t y)92 void GraphCycles::RemoveEdge(int32_t x, int32_t y) {
93 rep_->nodes[x]->out.Erase(y);
94 rep_->nodes[y]->in.Erase(x);
95 // No need to update the rank assignment since a previous valid
96 // rank assignment remains valid after an edge deletion.
97 }
98
99 static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound);
100 static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound);
101 static void Reorder(GraphCycles::Rep* r);
102 static void Sort(const Vec<Node*>&, Vec<int32_t>* delta);
103 static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
104 Vec<int32_t>* dst);
105 static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes);
106
InsertEdge(int32_t x,int32_t y)107 bool GraphCycles::InsertEdge(int32_t x, int32_t y) {
108 if (x == y) return false;
109 Rep* r = rep_;
110 Node* nx = r->nodes[x];
111 if (!nx->out.Insert(y)) {
112 // Edge already exists.
113 return true;
114 }
115
116 Node* ny = r->nodes[y];
117 ny->in.Insert(x);
118
119 if (nx->rank <= ny->rank) {
120 // New edge is consistent with existing rank assignment.
121 return true;
122 }
123
124 // Current rank assignments are incompatible with the new edge. Recompute.
125 // We only need to consider nodes that fall in the range [ny->rank,nx->rank].
126 if (ForwardDFS(r, y, nx->rank)) {
127 // Found a cycle. Undo the insertion and tell caller.
128 nx->out.Erase(y);
129 ny->in.Erase(x);
130 // Since we do not call Reorder() on this path, clear any visited
131 // markers left by ForwardDFS.
132 ClearVisitedBits(r, r->deltaf);
133 return false;
134 }
135 BackwardDFS(r, x, ny->rank);
136 Reorder(r);
137 return true;
138 }
139
140 // Follows the edges from producer to consumer and searchs if the node having
141 // rank `n` can reach the node having rank `upper_bound` using a DFS search.
142 // When doing DFS search, We only consider the pathes that satisfy the ranks
143 // of the nodes of the path are all smaller than `upper_bound`.
144 //
145 // Returns true if such path exists.
ForwardDFS(GraphCycles::Rep * r,int32_t n,int32_t upper_bound)146 static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound) {
147 // Avoid recursion since stack space might be limited.
148 // We instead keep a stack of nodes to visit.
149 r->deltaf.clear();
150 r->stack.clear();
151 r->stack.push_back(n);
152 while (!r->stack.empty()) {
153 n = r->stack.back();
154 r->stack.pop_back();
155 Node* nn = r->nodes[n];
156 if (nn->visited) continue;
157
158 nn->visited = true;
159 r->deltaf.push_back(n);
160
161 for (auto w : nn->out.GetSequence()) {
162 Node* nw = r->nodes[w];
163 if (nw->rank == upper_bound) {
164 return true;
165 }
166 if (!nw->visited && nw->rank < upper_bound) {
167 r->stack.push_back(w);
168 }
169 }
170 }
171 return false;
172 }
173
174 // Follows the edges from consumer to producer and visit all the nodes that
175 // is reachable from node `n` and have rank larger than `lower_bound`.
BackwardDFS(GraphCycles::Rep * r,int32_t n,int32_t lower_bound)176 static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound) {
177 r->deltab.clear();
178 r->stack.clear();
179 r->stack.push_back(n);
180 while (!r->stack.empty()) {
181 n = r->stack.back();
182 r->stack.pop_back();
183 Node* nn = r->nodes[n];
184 if (nn->visited) continue;
185
186 nn->visited = true;
187 r->deltab.push_back(n);
188
189 for (auto w : nn->in.GetSequence()) {
190 Node* nw = r->nodes[w];
191 if (!nw->visited && lower_bound < nw->rank) {
192 r->stack.push_back(w);
193 }
194 }
195 }
196 }
197
198 // Recomputes rank assignments to make them compatible with the edges (producer
199 // has smaller rank than its consumer)
Reorder(GraphCycles::Rep * r)200 static void Reorder(GraphCycles::Rep* r) {
201 Sort(r->nodes, &r->deltab);
202 Sort(r->nodes, &r->deltaf);
203
204 // Adds contents of delta lists to list (backwards deltas first).
205 r->list.clear();
206 MoveToList(r, &r->deltab, &r->list);
207 MoveToList(r, &r->deltaf, &r->list);
208
209 // Produce sorted list of all ranks that will be reassigned.
210 r->merged.resize(r->deltab.size() + r->deltaf.size());
211 std::merge(r->deltab.begin(), r->deltab.end(), r->deltaf.begin(),
212 r->deltaf.end(), r->merged.begin());
213
214 // Assign the ranks in order to the collected list.
215 for (Vec<int32_t>::size_type i = 0, e = r->list.size(); i < e; ++i) {
216 r->nodes[r->list[i]]->rank = r->merged[i];
217 }
218 }
219
220 // Sorts nodes in the vector according to their ranks. Small rank first.
Sort(const Vec<Node * > & nodes,Vec<int32_t> * delta)221 static void Sort(const Vec<Node*>& nodes, Vec<int32_t>* delta) {
222 struct ByRank {
223 const Vec<Node*>* nodes;
224 bool operator()(int32_t a, int32_t b) const {
225 return (*nodes)[a]->rank < (*nodes)[b]->rank;
226 }
227 };
228 ByRank cmp;
229 cmp.nodes = &nodes;
230 std::sort(delta->begin(), delta->end(), cmp);
231 }
232
233 // Collects ranks of nodes in vector `src` to vector `dst`
MoveToList(GraphCycles::Rep * r,Vec<int32_t> * src,Vec<int32_t> * dst)234 static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
235 Vec<int32_t>* dst) {
236 for (Vec<int32_t>::size_type i = 0, e = src->size(); i < e; i++) {
237 int32_t w = (*src)[i];
238 // Replace src entry with its rank
239 (*src)[i] = r->nodes[w]->rank;
240 // Prepare for future DFS calls
241 r->nodes[w]->visited = false;
242 dst->push_back(w);
243 }
244 }
245
246 // Clears bookkeeping fileds used during the last DFS process.
ClearVisitedBits(GraphCycles::Rep * r,const Vec<int32_t> & nodes)247 static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes) {
248 for (Vec<int32_t>::size_type i = 0, e = nodes.size(); i < e; i++) {
249 r->nodes[nodes[i]]->visited = false;
250 }
251 }
252
IsReachable(int32_t x,int32_t y)253 bool GraphCycles::IsReachable(int32_t x, int32_t y) {
254 if (x == y) return true;
255 Rep* r = rep_;
256 Node* nx = r->nodes[x];
257 Node* ny = r->nodes[y];
258
259 if (nx->rank >= ny->rank) {
260 // x cannot reach y since it is after it in the topological ordering
261 return false;
262 }
263
264 // See if x can reach y using a DFS search that is limited to y's rank
265 bool reachable = ForwardDFS(r, x, ny->rank);
266
267 // Clear any visited markers left by ForwardDFS.
268 ClearVisitedBits(r, r->deltaf);
269 return reachable;
270 }
271
ContractEdge(int32_t a,int32_t b)272 llvm::Optional<int32_t> GraphCycles::ContractEdge(int32_t a, int32_t b) {
273 assert(HasEdge(a, b));
274 RemoveEdge(a, b);
275
276 if (IsReachable(a, b)) {
277 // Restore the graph to its original state.
278 InsertEdge(a, b);
279 return {};
280 }
281
282 if (rep_->nodes[b]->in.Size() + rep_->nodes[b]->out.Size() >
283 rep_->nodes[a]->in.Size() + rep_->nodes[a]->out.Size()) {
284 // Swap "a" and "b" to minimize copying.
285 std::swap(a, b);
286 }
287
288 Node* nb = rep_->nodes[b];
289 OrderedNodeSet out = std::move(nb->out);
290 OrderedNodeSet in = std::move(nb->in);
291 for (int32_t y : out.GetSequence()) {
292 rep_->nodes[y]->in.Erase(b);
293 }
294 for (int32_t y : in.GetSequence()) {
295 rep_->nodes[y]->out.Erase(b);
296 }
297 rep_->free_nodes.push_back(b);
298
299 rep_->nodes[a]->out.Reserve(rep_->nodes[a]->out.Size() + out.Size());
300 for (int32_t y : out.GetSequence()) {
301 InsertEdge(a, y);
302 }
303
304 rep_->nodes[a]->in.Reserve(rep_->nodes[a]->in.Size() + in.Size());
305 for (int32_t y : in.GetSequence()) {
306 InsertEdge(y, a);
307 }
308
309 // Note, if the swap happened it might be what originally was called "b".
310 return a;
311 }
312
SuccessorsCopy(int32_t node) const313 std::vector<int32_t> GraphCycles::SuccessorsCopy(int32_t node) const {
314 return rep_->nodes[node]->out.GetSequence();
315 }
316
317 namespace {
SortInPostOrder(const Vec<Node * > & nodes,std::vector<int32_t> * to_sort)318 void SortInPostOrder(const Vec<Node*>& nodes, std::vector<int32_t>* to_sort) {
319 std::sort(to_sort->begin(), to_sort->end(), [&](int32_t a, int32_t b) {
320 return nodes[a]->rank > nodes[b]->rank;
321 });
322 }
323 } // namespace
324
AllNodesInPostOrder() const325 std::vector<int32_t> GraphCycles::AllNodesInPostOrder() const {
326 llvm::DenseSet<int32_t> free_nodes_set;
327 for (int32_t n : rep_->free_nodes) free_nodes_set.insert(n);
328
329 std::vector<int32_t> all_nodes;
330 all_nodes.reserve(rep_->nodes.size() - free_nodes_set.size());
331 for (size_t i = 0, e = rep_->nodes.size(); i < e; i++) {
332 if (!free_nodes_set.count(i)) {
333 all_nodes.push_back(i);
334 }
335 }
336
337 SortInPostOrder(rep_->nodes, &all_nodes);
338 return all_nodes;
339 }
340
341 } // namespace mlir
342