1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/algorithm.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <vector>
21
22 #include "tensorflow/core/platform/logging.h"
23
24 namespace tensorflow {
25 namespace {
26 template <typename T>
DFSFromHelper(const Graph & g,gtl::ArraySlice<T> start,const std::function<void (T)> & enter,const std::function<void (T)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)27 void DFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
28 const std::function<void(T)>& enter,
29 const std::function<void(T)>& leave,
30 const NodeComparator& stable_comparator,
31 const EdgeFilter& edge_filter) {
32 // Stack of work to do.
33 struct Work {
34 T node;
35 bool leave; // Are we entering or leaving n?
36 };
37 std::vector<Work> stack(start.size());
38 for (int i = 0; i < start.size(); ++i) {
39 stack[i] = Work{start[i], false};
40 }
41
42 std::vector<bool> visited(g.num_node_ids(), false);
43 while (!stack.empty()) {
44 Work w = stack.back();
45 stack.pop_back();
46
47 T n = w.node;
48 if (w.leave) {
49 leave(n);
50 continue;
51 }
52
53 if (visited[n->id()]) continue;
54 visited[n->id()] = true;
55 if (enter) enter(n);
56
57 // Arrange to call leave(n) when all done with descendants.
58 if (leave) stack.push_back(Work{n, true});
59
60 auto add_work = [&visited, &stack](Node* out) {
61 if (!visited[out->id()]) {
62 // Note; we must not mark as visited until we actually process it.
63 stack.push_back(Work{out, false});
64 }
65 };
66
67 if (stable_comparator) {
68 std::vector<Node*> nodes_sorted;
69 for (const Edge* out_edge : n->out_edges()) {
70 if (!edge_filter || edge_filter(*out_edge)) {
71 nodes_sorted.emplace_back(out_edge->dst());
72 }
73 }
74 std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
75 for (Node* out : nodes_sorted) {
76 add_work(out);
77 }
78 } else {
79 for (const Edge* out_edge : n->out_edges()) {
80 if (!edge_filter || edge_filter(*out_edge)) {
81 add_work(out_edge->dst());
82 }
83 }
84 }
85 }
86 }
87 } // namespace
88
// Depth-first traversal over the out-edges of `g`, seeded from its source
// node. This is the single-start case of the `Node*` overload of DFSFrom().
void DFS(const Graph& g, const std::function<void(Node*)>& enter,
         const std::function<void(Node*)>& leave,
         const NodeComparator& stable_comparator,
         const EdgeFilter& edge_filter) {
  DFSFrom(g, {g.source_node()}, enter, leave, stable_comparator, edge_filter);
}
96
DFSFrom(const Graph & g,gtl::ArraySlice<Node * > start,const std::function<void (Node *)> & enter,const std::function<void (Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)97 void DFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
98 const std::function<void(Node*)>& enter,
99 const std::function<void(Node*)>& leave,
100 const NodeComparator& stable_comparator,
101 const EdgeFilter& edge_filter) {
102 DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
103 }
104
DFSFrom(const Graph & g,gtl::ArraySlice<const Node * > start,const std::function<void (const Node *)> & enter,const std::function<void (const Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)105 void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
106 const std::function<void(const Node*)>& enter,
107 const std::function<void(const Node*)>& leave,
108 const NodeComparator& stable_comparator,
109 const EdgeFilter& edge_filter) {
110 DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
111 }
112
// Depth-first traversal over the in-edges of `g` (i.e. walking edges
// backwards), seeded from its sink node.
void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
                const std::function<void(Node*)>& leave,
                const NodeComparator& stable_comparator) {
  auto* const sink = g.sink_node();
  ReverseDFSFrom(g, {sink}, enter, leave, stable_comparator);
}
118
119 namespace {
120
121 template <typename T>
ReverseDFSFromHelper(const Graph & g,gtl::ArraySlice<T> start,const std::function<void (T)> & enter,const std::function<void (T)> & leave,const NodeComparator & stable_comparator)122 void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
123 const std::function<void(T)>& enter,
124 const std::function<void(T)>& leave,
125 const NodeComparator& stable_comparator) {
126 // Stack of work to do.
127 struct Work {
128 T node;
129 bool leave; // Are we entering or leaving n?
130 };
131 std::vector<Work> stack(start.size());
132 for (int i = 0; i < start.size(); ++i) {
133 stack[i] = Work{start[i], false};
134 }
135
136 std::vector<bool> visited(g.num_node_ids(), false);
137 while (!stack.empty()) {
138 Work w = stack.back();
139 stack.pop_back();
140
141 T n = w.node;
142 if (w.leave) {
143 leave(n);
144 continue;
145 }
146
147 if (visited[n->id()]) continue;
148 visited[n->id()] = true;
149 if (enter) enter(n);
150
151 // Arrange to call leave(n) when all done with descendants.
152 if (leave) stack.push_back(Work{n, true});
153
154 auto add_work = [&visited, &stack](T out) {
155 if (!visited[out->id()]) {
156 // Note; we must not mark as visited until we actually process it.
157 stack.push_back(Work{out, false});
158 }
159 };
160
161 if (stable_comparator) {
162 std::vector<T> nodes_sorted;
163 for (const Edge* in_edge : n->in_edges()) {
164 nodes_sorted.emplace_back(in_edge->src());
165 }
166 std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
167 for (T in : nodes_sorted) {
168 add_work(in);
169 }
170 } else {
171 for (const Edge* in_edge : n->in_edges()) {
172 add_work(in_edge->src());
173 }
174 }
175 }
176 }
177
178 } // namespace
179
ReverseDFSFrom(const Graph & g,gtl::ArraySlice<const Node * > start,const std::function<void (const Node *)> & enter,const std::function<void (const Node *)> & leave,const NodeComparator & stable_comparator)180 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
181 const std::function<void(const Node*)>& enter,
182 const std::function<void(const Node*)>& leave,
183 const NodeComparator& stable_comparator) {
184 ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
185 }
186
ReverseDFSFrom(const Graph & g,gtl::ArraySlice<Node * > start,const std::function<void (Node *)> & enter,const std::function<void (Node *)> & leave,const NodeComparator & stable_comparator)187 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
188 const std::function<void(Node*)>& enter,
189 const std::function<void(Node*)>& leave,
190 const NodeComparator& stable_comparator) {
191 ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
192 }
193
// Replaces the contents of `*order` with the nodes visited by DFS() from the
// source node, in the order in which the traversal leaves them (post-order).
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
                  const NodeComparator& stable_comparator,
                  const EdgeFilter& edge_filter) {
  order->clear();
  // Post-order only needs to record nodes on leave; entry needs no action.
  const auto record_on_leave = [order](Node* n) { order->push_back(n); };
  DFS(g, /*enter=*/nullptr, /*leave=*/record_on_leave, stable_comparator,
      edge_filter);
}
201
// Replaces the contents of `*order` with GetPostOrder()'s result, reversed.
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
                         const NodeComparator& stable_comparator,
                         const EdgeFilter& edge_filter) {
  GetPostOrder(g, order, stable_comparator, edge_filter);
  std::vector<Node*>& nodes = *order;
  std::reverse(nodes.begin(), nodes.end());
}
208
PruneForReverseReachability(Graph * g,std::unordered_set<const Node * > visited)209 bool PruneForReverseReachability(Graph* g,
210 std::unordered_set<const Node*> visited) {
211 // Compute set of nodes that we need to traverse in order to reach
212 // the nodes in "nodes" by performing a breadth-first search from those
213 // nodes, and accumulating the visited nodes.
214 std::deque<const Node*> queue;
215 for (const Node* n : visited) {
216 VLOG(2) << "Reverse reach init: " << n->name();
217 queue.push_back(n);
218 }
219 while (!queue.empty()) {
220 const Node* n = queue.front();
221 queue.pop_front();
222 for (const Node* in : n->in_nodes()) {
223 if (visited.insert(in).second) {
224 queue.push_back(in);
225 VLOG(2) << "Reverse reach : " << n->name() << " from " << in->name();
226 }
227 }
228 }
229
230 // Make a pass over the graph to remove nodes not in "visited"
231 std::vector<Node*> all_nodes;
232 all_nodes.reserve(g->num_nodes());
233 for (Node* n : g->nodes()) {
234 all_nodes.push_back(n);
235 }
236
237 bool any_removed = false;
238 for (Node* n : all_nodes) {
239 if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) {
240 g->RemoveNode(n);
241 any_removed = true;
242 }
243 }
244
245 return any_removed;
246 }
247
FixupSourceAndSinkEdges(Graph * g)248 bool FixupSourceAndSinkEdges(Graph* g) {
249 // Connect all nodes with no incoming edges to source.
250 // Connect all nodes with no outgoing edges to sink.
251 bool changed = false;
252 for (Node* n : g->nodes()) {
253 if (!n->IsSource() && n->in_edges().empty()) {
254 g->AddControlEdge(g->source_node(), n,
255 true /* skip test for duplicates */);
256 changed = true;
257 }
258 if (!n->IsSink() && n->out_edges().empty()) {
259 g->AddControlEdge(n, g->sink_node(), true /* skip test for duplicates */);
260 changed = true;
261 }
262 }
263 return changed;
264 }
265
266 } // namespace tensorflow
267