• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/utils/cycle_detector.h"
17 
18 #include <algorithm>
19 
20 #include "llvm/ADT/DenseSet.h"
21 #include "llvm/ADT/SmallVector.h"
22 
23 namespace mlir {
24 
25 namespace {
26 
27 using NodeSet = llvm::DenseSet<int32_t>;
28 using OrderedNodeSet = OrderedSet<int32_t>;
29 
30 template <typename T>
31 struct VecStruct {
32   typedef llvm::SmallVector<T, 4> type;
33 };
34 template <typename T>
35 using Vec = typename VecStruct<T>::type;
36 
37 struct Node {
38   // rank number assigned by Pearce-Kelly algorithm
39   int32_t rank;
40   // Temporary marker used by depth-first-search
41   bool visited;
42   // User-supplied data
43   void* data;
44   // List of immediate predecessor nodes in graph
45   OrderedNodeSet in;
46   // List of immediate successor nodes in graph
47   OrderedNodeSet out;
48 };
49 
50 }  // namespace
51 
52 struct GraphCycles::Rep {
53   Vec<Node*> nodes;
54   // Indices for unused entries in nodes
55   Vec<int32_t> free_nodes;
56 
57   // Temporary state.
58   // Results of forward DFS
59   Vec<int32_t> deltaf;
60   // Results of backward DFS
61   Vec<int32_t> deltab;
62   // All nodes to reprocess
63   Vec<int32_t> list;
64   // Rank values to assign to list entries
65   Vec<int32_t> merged;
66   // Emulates recursion stack when doing depth first search
67   Vec<int32_t> stack;
68 };
69 
GraphCycles(int32_t num_nodes)70 GraphCycles::GraphCycles(int32_t num_nodes) : rep_(new Rep) {
71   rep_->nodes.reserve(num_nodes);
72   for (int32_t i = 0; i < num_nodes; ++i) {
73     Node* n = new Node;
74     n->visited = false;
75     n->data = nullptr;
76     n->rank = rep_->nodes.size();
77     rep_->nodes.push_back(n);
78   }
79 }
80 
~GraphCycles()81 GraphCycles::~GraphCycles() {
82   for (Vec<Node*>::size_type i = 0, e = rep_->nodes.size(); i < e; ++i) {
83     delete rep_->nodes[i];
84   }
85   delete rep_;
86 }
87 
HasEdge(int32_t x,int32_t y) const88 bool GraphCycles::HasEdge(int32_t x, int32_t y) const {
89   return rep_->nodes[x]->out.Contains(y);
90 }
91 
RemoveEdge(int32_t x,int32_t y)92 void GraphCycles::RemoveEdge(int32_t x, int32_t y) {
93   rep_->nodes[x]->out.Erase(y);
94   rep_->nodes[y]->in.Erase(x);
95   // No need to update the rank assignment since a previous valid
96   // rank assignment remains valid after an edge deletion.
97 }
98 
99 static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound);
100 static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound);
101 static void Reorder(GraphCycles::Rep* r);
102 static void Sort(const Vec<Node*>&, Vec<int32_t>* delta);
103 static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
104                        Vec<int32_t>* dst);
105 static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes);
106 
InsertEdge(int32_t x,int32_t y)107 bool GraphCycles::InsertEdge(int32_t x, int32_t y) {
108   if (x == y) return false;
109   Rep* r = rep_;
110   Node* nx = r->nodes[x];
111   if (!nx->out.Insert(y)) {
112     // Edge already exists.
113     return true;
114   }
115 
116   Node* ny = r->nodes[y];
117   ny->in.Insert(x);
118 
119   if (nx->rank <= ny->rank) {
120     // New edge is consistent with existing rank assignment.
121     return true;
122   }
123 
124   // Current rank assignments are incompatible with the new edge.  Recompute.
125   // We only need to consider nodes that fall in the range [ny->rank,nx->rank].
126   if (ForwardDFS(r, y, nx->rank)) {
127     // Found a cycle.  Undo the insertion and tell caller.
128     nx->out.Erase(y);
129     ny->in.Erase(x);
130     // Since we do not call Reorder() on this path, clear any visited
131     // markers left by ForwardDFS.
132     ClearVisitedBits(r, r->deltaf);
133     return false;
134   }
135   BackwardDFS(r, x, ny->rank);
136   Reorder(r);
137   return true;
138 }
139 
140 // Follows the edges from producer to consumer and searchs if the node having
141 // rank `n` can reach the node having rank `upper_bound` using a DFS search.
142 // When doing DFS search, We only consider the pathes that satisfy the ranks
143 // of the nodes of the path are all smaller than `upper_bound`.
144 //
145 // Returns true if such path exists.
ForwardDFS(GraphCycles::Rep * r,int32_t n,int32_t upper_bound)146 static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound) {
147   // Avoid recursion since stack space might be limited.
148   // We instead keep a stack of nodes to visit.
149   r->deltaf.clear();
150   r->stack.clear();
151   r->stack.push_back(n);
152   while (!r->stack.empty()) {
153     n = r->stack.back();
154     r->stack.pop_back();
155     Node* nn = r->nodes[n];
156     if (nn->visited) continue;
157 
158     nn->visited = true;
159     r->deltaf.push_back(n);
160 
161     for (auto w : nn->out.GetSequence()) {
162       Node* nw = r->nodes[w];
163       if (nw->rank == upper_bound) {
164         return true;
165       }
166       if (!nw->visited && nw->rank < upper_bound) {
167         r->stack.push_back(w);
168       }
169     }
170   }
171   return false;
172 }
173 
174 // Follows the edges from consumer to producer and visit all the nodes that
175 // is reachable from node `n` and have rank larger than `lower_bound`.
BackwardDFS(GraphCycles::Rep * r,int32_t n,int32_t lower_bound)176 static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound) {
177   r->deltab.clear();
178   r->stack.clear();
179   r->stack.push_back(n);
180   while (!r->stack.empty()) {
181     n = r->stack.back();
182     r->stack.pop_back();
183     Node* nn = r->nodes[n];
184     if (nn->visited) continue;
185 
186     nn->visited = true;
187     r->deltab.push_back(n);
188 
189     for (auto w : nn->in.GetSequence()) {
190       Node* nw = r->nodes[w];
191       if (!nw->visited && lower_bound < nw->rank) {
192         r->stack.push_back(w);
193       }
194     }
195   }
196 }
197 
198 // Recomputes rank assignments to make them compatible with the edges (producer
199 // has smaller rank than its consumer)
Reorder(GraphCycles::Rep * r)200 static void Reorder(GraphCycles::Rep* r) {
201   Sort(r->nodes, &r->deltab);
202   Sort(r->nodes, &r->deltaf);
203 
204   // Adds contents of delta lists to list (backwards deltas first).
205   r->list.clear();
206   MoveToList(r, &r->deltab, &r->list);
207   MoveToList(r, &r->deltaf, &r->list);
208 
209   // Produce sorted list of all ranks that will be reassigned.
210   r->merged.resize(r->deltab.size() + r->deltaf.size());
211   std::merge(r->deltab.begin(), r->deltab.end(), r->deltaf.begin(),
212              r->deltaf.end(), r->merged.begin());
213 
214   // Assign the ranks in order to the collected list.
215   for (Vec<int32_t>::size_type i = 0, e = r->list.size(); i < e; ++i) {
216     r->nodes[r->list[i]]->rank = r->merged[i];
217   }
218 }
219 
220 // Sorts nodes in the vector according to their ranks. Small rank first.
Sort(const Vec<Node * > & nodes,Vec<int32_t> * delta)221 static void Sort(const Vec<Node*>& nodes, Vec<int32_t>* delta) {
222   struct ByRank {
223     const Vec<Node*>* nodes;
224     bool operator()(int32_t a, int32_t b) const {
225       return (*nodes)[a]->rank < (*nodes)[b]->rank;
226     }
227   };
228   ByRank cmp;
229   cmp.nodes = &nodes;
230   std::sort(delta->begin(), delta->end(), cmp);
231 }
232 
233 // Collects ranks of nodes in vector `src` to vector `dst`
MoveToList(GraphCycles::Rep * r,Vec<int32_t> * src,Vec<int32_t> * dst)234 static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
235                        Vec<int32_t>* dst) {
236   for (Vec<int32_t>::size_type i = 0, e = src->size(); i < e; i++) {
237     int32_t w = (*src)[i];
238     // Replace src entry with its rank
239     (*src)[i] = r->nodes[w]->rank;
240     // Prepare for future DFS calls
241     r->nodes[w]->visited = false;
242     dst->push_back(w);
243   }
244 }
245 
246 // Clears bookkeeping fileds used during the last DFS process.
ClearVisitedBits(GraphCycles::Rep * r,const Vec<int32_t> & nodes)247 static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes) {
248   for (Vec<int32_t>::size_type i = 0, e = nodes.size(); i < e; i++) {
249     r->nodes[nodes[i]]->visited = false;
250   }
251 }
252 
IsReachable(int32_t x,int32_t y)253 bool GraphCycles::IsReachable(int32_t x, int32_t y) {
254   if (x == y) return true;
255   Rep* r = rep_;
256   Node* nx = r->nodes[x];
257   Node* ny = r->nodes[y];
258 
259   if (nx->rank >= ny->rank) {
260     // x cannot reach y since it is after it in the topological ordering
261     return false;
262   }
263 
264   // See if x can reach y using a DFS search that is limited to y's rank
265   bool reachable = ForwardDFS(r, x, ny->rank);
266 
267   // Clear any visited markers left by ForwardDFS.
268   ClearVisitedBits(r, r->deltaf);
269   return reachable;
270 }
271 
ContractEdge(int32_t a,int32_t b)272 llvm::Optional<int32_t> GraphCycles::ContractEdge(int32_t a, int32_t b) {
273   assert(HasEdge(a, b));
274   RemoveEdge(a, b);
275 
276   if (IsReachable(a, b)) {
277     // Restore the graph to its original state.
278     InsertEdge(a, b);
279     return {};
280   }
281 
282   if (rep_->nodes[b]->in.Size() + rep_->nodes[b]->out.Size() >
283       rep_->nodes[a]->in.Size() + rep_->nodes[a]->out.Size()) {
284     // Swap "a" and "b" to minimize copying.
285     std::swap(a, b);
286   }
287 
288   Node* nb = rep_->nodes[b];
289   OrderedNodeSet out = std::move(nb->out);
290   OrderedNodeSet in = std::move(nb->in);
291   for (int32_t y : out.GetSequence()) {
292     rep_->nodes[y]->in.Erase(b);
293   }
294   for (int32_t y : in.GetSequence()) {
295     rep_->nodes[y]->out.Erase(b);
296   }
297   rep_->free_nodes.push_back(b);
298 
299   rep_->nodes[a]->out.Reserve(rep_->nodes[a]->out.Size() + out.Size());
300   for (int32_t y : out.GetSequence()) {
301     InsertEdge(a, y);
302   }
303 
304   rep_->nodes[a]->in.Reserve(rep_->nodes[a]->in.Size() + in.Size());
305   for (int32_t y : in.GetSequence()) {
306     InsertEdge(y, a);
307   }
308 
309   // Note, if the swap happened it might be what originally was called "b".
310   return a;
311 }
312 
SuccessorsCopy(int32_t node) const313 std::vector<int32_t> GraphCycles::SuccessorsCopy(int32_t node) const {
314   return rep_->nodes[node]->out.GetSequence();
315 }
316 
317 namespace {
SortInPostOrder(const Vec<Node * > & nodes,std::vector<int32_t> * to_sort)318 void SortInPostOrder(const Vec<Node*>& nodes, std::vector<int32_t>* to_sort) {
319   std::sort(to_sort->begin(), to_sort->end(), [&](int32_t a, int32_t b) {
320     return nodes[a]->rank > nodes[b]->rank;
321   });
322 }
323 }  // namespace
324 
AllNodesInPostOrder() const325 std::vector<int32_t> GraphCycles::AllNodesInPostOrder() const {
326   llvm::DenseSet<int32_t> free_nodes_set;
327   for (int32_t n : rep_->free_nodes) free_nodes_set.insert(n);
328 
329   std::vector<int32_t> all_nodes;
330   all_nodes.reserve(rep_->nodes.size() - free_nodes_set.size());
331   for (size_t i = 0, e = rep_->nodes.size(); i < e; i++) {
332     if (!free_nodes_set.count(i)) {
333       all_nodes.push_back(i);
334     }
335   }
336 
337   SortInPostOrder(rep_->nodes, &all_nodes);
338   return all_nodes;
339 }
340 
341 }  // namespace mlir
342