• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/algorithm.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <vector>
21 
22 #include "tensorflow/core/platform/logging.h"
23 
24 namespace tensorflow {
25 namespace {
26 template <typename T>
DFSFromHelper(const Graph & g,gtl::ArraySlice<T> start,const std::function<void (T)> & enter,const std::function<void (T)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)27 void DFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
28                    const std::function<void(T)>& enter,
29                    const std::function<void(T)>& leave,
30                    const NodeComparator& stable_comparator,
31                    const EdgeFilter& edge_filter) {
32   // Stack of work to do.
33   struct Work {
34     T node;
35     bool leave;  // Are we entering or leaving n?
36   };
37   std::vector<Work> stack(start.size());
38   for (int i = 0; i < start.size(); ++i) {
39     stack[i] = Work{start[i], false};
40   }
41 
42   std::vector<bool> visited(g.num_node_ids(), false);
43   while (!stack.empty()) {
44     Work w = stack.back();
45     stack.pop_back();
46 
47     T n = w.node;
48     if (w.leave) {
49       leave(n);
50       continue;
51     }
52 
53     if (visited[n->id()]) continue;
54     visited[n->id()] = true;
55     if (enter) enter(n);
56 
57     // Arrange to call leave(n) when all done with descendants.
58     if (leave) stack.push_back(Work{n, true});
59 
60     auto add_work = [&visited, &stack](Node* out) {
61       if (!visited[out->id()]) {
62         // Note; we must not mark as visited until we actually process it.
63         stack.push_back(Work{out, false});
64       }
65     };
66 
67     if (stable_comparator) {
68       std::vector<Node*> nodes_sorted;
69       for (const Edge* out_edge : n->out_edges()) {
70         if (!edge_filter || edge_filter(*out_edge)) {
71           nodes_sorted.emplace_back(out_edge->dst());
72         }
73       }
74       std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
75       for (Node* out : nodes_sorted) {
76         add_work(out);
77       }
78     } else {
79       for (const Edge* out_edge : n->out_edges()) {
80         if (!edge_filter || edge_filter(*out_edge)) {
81           add_work(out_edge->dst());
82         }
83       }
84     }
85   }
86 }
87 }  // namespace
88 
// Depth-first search over `g` seeded with the graph's source node alone.
// `enter` runs when a node is first reached and `leave` after its unvisited
// successors have been explored; either callback may be empty.
// `stable_comparator`, when set, fixes the order in which successors are
// explored. `edge_filter`, when set, limits which out-edges are followed.
void DFS(const Graph& g, const std::function<void(Node*)>& enter,
         const std::function<void(Node*)>& leave,
         const NodeComparator& stable_comparator,
         const EdgeFilter& edge_filter) {
  Node* const root = g.source_node();
  DFSFromHelper<Node*>(g, {root}, enter, leave, stable_comparator,
                       edge_filter);
}
96 
DFSFrom(const Graph & g,gtl::ArraySlice<Node * > start,const std::function<void (Node *)> & enter,const std::function<void (Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)97 void DFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
98              const std::function<void(Node*)>& enter,
99              const std::function<void(Node*)>& leave,
100              const NodeComparator& stable_comparator,
101              const EdgeFilter& edge_filter) {
102   DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
103 }
104 
DFSFrom(const Graph & g,gtl::ArraySlice<const Node * > start,const std::function<void (const Node *)> & enter,const std::function<void (const Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)105 void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
106              const std::function<void(const Node*)>& enter,
107              const std::function<void(const Node*)>& leave,
108              const NodeComparator& stable_comparator,
109              const EdgeFilter& edge_filter) {
110   DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
111 }
112 
// Depth-first search that walks in-edges backwards, seeded with the graph's
// sink node alone. Callback and comparator semantics match ReverseDFSFrom().
void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
                const std::function<void(Node*)>& leave,
                const NodeComparator& stable_comparator) {
  Node* const root = g.sink_node();
  ReverseDFSFrom(g, {root}, enter, leave, stable_comparator);
}
118 
119 namespace {
120 
121 template <typename T>
ReverseDFSFromHelper(const Graph & g,gtl::ArraySlice<T> start,const std::function<void (T)> & enter,const std::function<void (T)> & leave,const NodeComparator & stable_comparator)122 void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
123                           const std::function<void(T)>& enter,
124                           const std::function<void(T)>& leave,
125                           const NodeComparator& stable_comparator) {
126   // Stack of work to do.
127   struct Work {
128     T node;
129     bool leave;  // Are we entering or leaving n?
130   };
131   std::vector<Work> stack(start.size());
132   for (int i = 0; i < start.size(); ++i) {
133     stack[i] = Work{start[i], false};
134   }
135 
136   std::vector<bool> visited(g.num_node_ids(), false);
137   while (!stack.empty()) {
138     Work w = stack.back();
139     stack.pop_back();
140 
141     T n = w.node;
142     if (w.leave) {
143       leave(n);
144       continue;
145     }
146 
147     if (visited[n->id()]) continue;
148     visited[n->id()] = true;
149     if (enter) enter(n);
150 
151     // Arrange to call leave(n) when all done with descendants.
152     if (leave) stack.push_back(Work{n, true});
153 
154     auto add_work = [&visited, &stack](T out) {
155       if (!visited[out->id()]) {
156         // Note; we must not mark as visited until we actually process it.
157         stack.push_back(Work{out, false});
158       }
159     };
160 
161     if (stable_comparator) {
162       std::vector<T> nodes_sorted;
163       for (const Edge* in_edge : n->in_edges()) {
164         nodes_sorted.emplace_back(in_edge->src());
165       }
166       std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
167       for (T in : nodes_sorted) {
168         add_work(in);
169       }
170     } else {
171       for (const Edge* in_edge : n->in_edges()) {
172         add_work(in_edge->src());
173       }
174     }
175   }
176 }
177 
178 }  // namespace
179 
ReverseDFSFrom(const Graph & g,gtl::ArraySlice<const Node * > start,const std::function<void (const Node *)> & enter,const std::function<void (const Node *)> & leave,const NodeComparator & stable_comparator)180 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
181                     const std::function<void(const Node*)>& enter,
182                     const std::function<void(const Node*)>& leave,
183                     const NodeComparator& stable_comparator) {
184   ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
185 }
186 
ReverseDFSFrom(const Graph & g,gtl::ArraySlice<Node * > start,const std::function<void (Node *)> & enter,const std::function<void (Node *)> & leave,const NodeComparator & stable_comparator)187 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
188                     const std::function<void(Node*)>& enter,
189                     const std::function<void(Node*)>& leave,
190                     const NodeComparator& stable_comparator) {
191   ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
192 }
193 
// Replaces the contents of `*order` with the nodes reached by DFS() from the
// source node, in post-order: each node is appended when the traversal
// leaves it. `stable_comparator` and `edge_filter` are forwarded to DFS().
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
                  const NodeComparator& stable_comparator,
                  const EdgeFilter& edge_filter) {
  order->clear();
  const std::function<void(Node*)> append_on_leave = [order](Node* node) {
    order->push_back(node);
  };
  DFS(g, /*enter=*/nullptr, append_on_leave, stable_comparator, edge_filter);
}
201 
// Replaces the contents of `*order` with the post-order produced by
// GetPostOrder(), read back to front.
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
                         const NodeComparator& stable_comparator,
                         const EdgeFilter& edge_filter) {
  GetPostOrder(g, order, stable_comparator, edge_filter);
  std::vector<Node*>& nodes = *order;
  std::reverse(nodes.begin(), nodes.end());
}
208 
PruneForReverseReachability(Graph * g,std::unordered_set<const Node * > visited)209 bool PruneForReverseReachability(Graph* g,
210                                  std::unordered_set<const Node*> visited) {
211   // Compute set of nodes that we need to traverse in order to reach
212   // the nodes in "nodes" by performing a breadth-first search from those
213   // nodes, and accumulating the visited nodes.
214   std::deque<const Node*> queue;
215   for (const Node* n : visited) {
216     VLOG(2) << "Reverse reach init: " << n->name();
217     queue.push_back(n);
218   }
219   while (!queue.empty()) {
220     const Node* n = queue.front();
221     queue.pop_front();
222     for (const Node* in : n->in_nodes()) {
223       if (visited.insert(in).second) {
224         queue.push_back(in);
225         VLOG(2) << "Reverse reach : " << n->name() << " from " << in->name();
226       }
227     }
228   }
229 
230   // Make a pass over the graph to remove nodes not in "visited"
231   std::vector<Node*> all_nodes;
232   all_nodes.reserve(g->num_nodes());
233   for (Node* n : g->nodes()) {
234     all_nodes.push_back(n);
235   }
236 
237   bool any_removed = false;
238   for (Node* n : all_nodes) {
239     if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) {
240       g->RemoveNode(n);
241       any_removed = true;
242     }
243   }
244 
245   return any_removed;
246 }
247 
FixupSourceAndSinkEdges(Graph * g)248 bool FixupSourceAndSinkEdges(Graph* g) {
249   // Connect all nodes with no incoming edges to source.
250   // Connect all nodes with no outgoing edges to sink.
251   bool changed = false;
252   for (Node* n : g->nodes()) {
253     if (!n->IsSource() && n->in_edges().empty()) {
254       g->AddControlEdge(g->source_node(), n,
255                         true /* skip test for duplicates */);
256       changed = true;
257     }
258     if (!n->IsSink() && n->out_edges().empty()) {
259       g->AddControlEdge(n, g->sink_node(), true /* skip test for duplicates */);
260       changed = true;
261     }
262   }
263   return changed;
264 }
265 
266 }  // namespace tensorflow
267