1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/algorithm.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <vector>
21
22 #include "tensorflow/core/platform/logging.h"
23
24 namespace tensorflow {
25 namespace {
26 template <typename T>
DFSFromHelper(const Graph & g,gtl::ArraySlice<T> start,const std::function<void (T)> & enter,const std::function<void (T)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)27 void DFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
28 const std::function<void(T)>& enter,
29 const std::function<void(T)>& leave,
30 const NodeComparator& stable_comparator,
31 const EdgeFilter& edge_filter) {
32 // Stack of work to do.
33 struct Work {
34 T node;
35 bool leave; // Are we entering or leaving n?
36 };
37 std::vector<Work> stack(start.size());
38 for (int i = 0; i < start.size(); ++i) {
39 stack[i] = Work{start[i], false};
40 }
41
42 std::vector<bool> visited(g.num_node_ids(), false);
43 while (!stack.empty()) {
44 Work w = stack.back();
45 stack.pop_back();
46
47 T n = w.node;
48 if (w.leave) {
49 leave(n);
50 continue;
51 }
52
53 if (visited[n->id()]) continue;
54 visited[n->id()] = true;
55 if (enter) enter(n);
56
57 // Arrange to call leave(n) when all done with descendants.
58 if (leave) stack.push_back(Work{n, true});
59
60 auto add_work = [&visited, &stack](Node* out) {
61 if (!visited[out->id()]) {
62 // Note; we must not mark as visited until we actually process it.
63 stack.push_back(Work{out, false});
64 }
65 };
66
67 if (stable_comparator) {
68 std::vector<Node*> nodes_sorted;
69 for (const Edge* out_edge : n->out_edges()) {
70 if (!edge_filter || edge_filter(*out_edge)) {
71 nodes_sorted.emplace_back(out_edge->dst());
72 }
73 }
74 std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
75 for (Node* out : nodes_sorted) {
76 add_work(out);
77 }
78 } else {
79 for (const Edge* out_edge : n->out_edges()) {
80 if (!edge_filter || edge_filter(*out_edge)) {
81 add_work(out_edge->dst());
82 }
83 }
84 }
85 }
86 }
87 } // namespace
88
void DFS(const Graph& g, const std::function<void(Node*)>& enter,
         const std::function<void(Node*)>& leave,
         const NodeComparator& stable_comparator,
         const EdgeFilter& edge_filter) {
  // Forward DFS over `g`, rooted at the graph's source node. All callback,
  // ordering and filtering semantics are those of DFSFromHelper.
  Node* const root = g.source_node();
  DFSFromHelper(g, {root}, enter, leave, stable_comparator, edge_filter);
}
96
DFSFrom(const Graph & g,gtl::ArraySlice<Node * > start,const std::function<void (Node *)> & enter,const std::function<void (Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)97 void DFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
98 const std::function<void(Node*)>& enter,
99 const std::function<void(Node*)>& leave,
100 const NodeComparator& stable_comparator,
101 const EdgeFilter& edge_filter) {
102 DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
103 }
104
DFSFrom(const Graph & g,gtl::ArraySlice<const Node * > start,const std::function<void (const Node *)> & enter,const std::function<void (const Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)105 void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
106 const std::function<void(const Node*)>& enter,
107 const std::function<void(const Node*)>& leave,
108 const NodeComparator& stable_comparator,
109 const EdgeFilter& edge_filter) {
110 DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
111 }
112
void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
                const std::function<void(Node*)>& leave,
                const NodeComparator& stable_comparator,
                const EdgeFilter& edge_filter) {
  // Backward DFS (along in-edges) over `g`, rooted at the graph's sink node.
  Node* const sink = g.sink_node();
  ReverseDFSFrom(g, {sink}, enter, leave, stable_comparator, edge_filter);
}
120
121 namespace {
122
123 template <typename T>
ReverseDFSFromHelper(const Graph & g,gtl::ArraySlice<T> start,const std::function<void (T)> & enter,const std::function<void (T)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)124 void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
125 const std::function<void(T)>& enter,
126 const std::function<void(T)>& leave,
127 const NodeComparator& stable_comparator,
128 const EdgeFilter& edge_filter) {
129 // Stack of work to do.
130 struct Work {
131 T node;
132 bool leave; // Are we entering or leaving n?
133 };
134 std::vector<Work> stack(start.size());
135 for (int i = 0; i < start.size(); ++i) {
136 stack[i] = Work{start[i], false};
137 }
138
139 std::vector<bool> visited(g.num_node_ids(), false);
140 while (!stack.empty()) {
141 Work w = stack.back();
142 stack.pop_back();
143
144 T n = w.node;
145 if (w.leave) {
146 leave(n);
147 continue;
148 }
149
150 if (visited[n->id()]) continue;
151 visited[n->id()] = true;
152 if (enter) enter(n);
153
154 // Arrange to call leave(n) when all done with descendants.
155 if (leave) stack.push_back(Work{n, true});
156
157 auto add_work = [&visited, &stack](T out) {
158 if (!visited[out->id()]) {
159 // Note; we must not mark as visited until we actually process it.
160 stack.push_back(Work{out, false});
161 }
162 };
163
164 if (stable_comparator) {
165 std::vector<T> nodes_sorted;
166 for (const Edge* in_edge : n->in_edges()) {
167 if (!edge_filter || edge_filter(*in_edge)) {
168 nodes_sorted.emplace_back(in_edge->src());
169 }
170 }
171 std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
172 for (T in : nodes_sorted) {
173 add_work(in);
174 }
175 } else {
176 for (const Edge* in_edge : n->in_edges()) {
177 if (!edge_filter || edge_filter(*in_edge)) {
178 add_work(in_edge->src());
179 }
180 }
181 }
182 }
183 }
184
185 } // namespace
186
ReverseDFSFrom(const Graph & g,gtl::ArraySlice<const Node * > start,const std::function<void (const Node *)> & enter,const std::function<void (const Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)187 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
188 const std::function<void(const Node*)>& enter,
189 const std::function<void(const Node*)>& leave,
190 const NodeComparator& stable_comparator,
191 const EdgeFilter& edge_filter) {
192 ReverseDFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
193 }
194
ReverseDFSFrom(const Graph & g,gtl::ArraySlice<Node * > start,const std::function<void (Node *)> & enter,const std::function<void (Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)195 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
196 const std::function<void(Node*)>& enter,
197 const std::function<void(Node*)>& leave,
198 const NodeComparator& stable_comparator,
199 const EdgeFilter& edge_filter) {
200 ReverseDFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
201 }
202
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
                  const NodeComparator& stable_comparator,
                  const EdgeFilter& edge_filter) {
  // Replaces the contents of `*order` with the nodes of `g` in DFS post-order:
  // a node is appended at the moment the DFS leaves it.
  order->clear();
  auto append_on_leave = [order](Node* n) { order->push_back(n); };
  DFS(g, /*enter=*/nullptr, /*leave=*/append_on_leave, stable_comparator,
      edge_filter);
}
210
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
                         const NodeComparator& stable_comparator,
                         const EdgeFilter& edge_filter) {
  // Reverse post-order is the DFS post-order read back to front.
  GetPostOrder(g, order, stable_comparator, edge_filter);
  std::vector<Node*>& result = *order;
  std::reverse(result.begin(), result.end());
}
217
PruneForReverseReachability(Graph * g,std::unordered_set<const Node * > start)218 bool PruneForReverseReachability(Graph* g,
219 std::unordered_set<const Node*> start) {
220 // Compute set of nodes that we need to traverse in order to reach
221 // the nodes in "start" by performing a breadth-first search from those
222 // nodes, and accumulating the visited nodes.
223 std::vector<bool> visited(g->num_node_ids());
224 for (auto node : start) {
225 visited[node->id()] = true;
226 }
227 std::deque<const Node*> queue(start.begin(), start.end());
228 while (!queue.empty()) {
229 const Node* n = queue.front();
230 queue.pop_front();
231 for (const Node* in : n->in_nodes()) {
232 if (!visited[in->id()]) {
233 visited[in->id()] = true;
234 queue.push_back(in);
235 VLOG(2) << "Reverse reach : " << n->name() << " from " << in->name();
236 }
237 }
238 }
239
240 // Make a pass over the graph to remove nodes not in "visited".
241 bool any_removed = false;
242 for (int i = 0; i < visited.size(); ++i) {
243 if (!visited[i]) {
244 Node* n = g->FindNodeId(i);
245 if (n != nullptr && !n->IsSource() && !n->IsSink()) {
246 g->RemoveNode(n);
247 any_removed = true;
248 }
249 }
250 }
251 return any_removed;
252 }
253
FixupSourceAndSinkEdges(Graph * g)254 bool FixupSourceAndSinkEdges(Graph* g) {
255 // Connect all nodes with no incoming edges to source.
256 // Connect all nodes with no outgoing edges to sink.
257 bool changed = false;
258 for (Node* n : g->nodes()) {
259 if (!n->IsSource() && n->in_edges().empty()) {
260 g->AddControlEdge(g->source_node(), n,
261 true /* skip test for duplicates */);
262 changed = true;
263 }
264 if (!n->IsSink() && n->out_edges().empty()) {
265 g->AddControlEdge(n, g->sink_node(), true /* skip test for duplicates */);
266 changed = true;
267 }
268 }
269 return changed;
270 }
271
272 } // namespace tensorflow
273