1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/algorithm.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <vector>
21
22 #include "tensorflow/core/platform/logging.h"
23
24 namespace tensorflow {
25 namespace {
26 template <typename T>
DFSFromHelper(const Graph & g,gtl::ArraySlice<T> start,const std::function<void (T)> & enter,const std::function<void (T)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)27 void DFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
28 const std::function<void(T)>& enter,
29 const std::function<void(T)>& leave,
30 const NodeComparator& stable_comparator,
31 const EdgeFilter& edge_filter) {
32 // Stack of work to do.
33 struct Work {
34 T node;
35 bool leave; // Are we entering or leaving n?
36 };
37 std::vector<Work> stack(start.size());
38 for (int i = 0; i < start.size(); ++i) {
39 stack[i] = Work{start[i], false};
40 }
41
42 std::vector<bool> visited(g.num_node_ids(), false);
43 while (!stack.empty()) {
44 Work w = stack.back();
45 stack.pop_back();
46
47 T n = w.node;
48 if (w.leave) {
49 leave(n);
50 continue;
51 }
52
53 if (visited[n->id()]) continue;
54 visited[n->id()] = true;
55 if (enter) enter(n);
56
57 // Arrange to call leave(n) when all done with descendants.
58 if (leave) stack.push_back(Work{n, true});
59
60 auto add_work = [&visited, &stack](Node* out) {
61 if (!visited[out->id()]) {
62 // Note; we must not mark as visited until we actually process it.
63 stack.push_back(Work{out, false});
64 }
65 };
66
67 if (stable_comparator) {
68 std::vector<Node*> nodes_sorted;
69 for (const Edge* out_edge : n->out_edges()) {
70 if (!edge_filter || edge_filter(*out_edge)) {
71 nodes_sorted.emplace_back(out_edge->dst());
72 }
73 }
74 std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
75 for (Node* out : nodes_sorted) {
76 add_work(out);
77 }
78 } else {
79 for (const Edge* out_edge : n->out_edges()) {
80 if (!edge_filter || edge_filter(*out_edge)) {
81 add_work(out_edge->dst());
82 }
83 }
84 }
85 }
86 }
87 } // namespace
88
// Depth-first traversal of `g` rooted at its source node.
// `enter` / `leave` (each optional) are invoked on first visit and after a
// node's descendants are done; `stable_comparator` (optional) fixes successor
// order; `edge_filter` (optional) restricts which out-edges are followed.
void DFS(const Graph& g, const std::function<void(Node*)>& enter,
         const std::function<void(Node*)>& leave,
         const NodeComparator& stable_comparator,
         const EdgeFilter& edge_filter) {
  Node* const root = g.source_node();
  DFSFromHelper<Node*>(g, {root}, enter, leave, stable_comparator,
                       edge_filter);
}
96
DFSFrom(const Graph & g,gtl::ArraySlice<Node * > start,const std::function<void (Node *)> & enter,const std::function<void (Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)97 void DFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
98 const std::function<void(Node*)>& enter,
99 const std::function<void(Node*)>& leave,
100 const NodeComparator& stable_comparator,
101 const EdgeFilter& edge_filter) {
102 DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
103 }
104
DFSFrom(const Graph & g,gtl::ArraySlice<const Node * > start,const std::function<void (const Node *)> & enter,const std::function<void (const Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)105 void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
106 const std::function<void(const Node*)>& enter,
107 const std::function<void(const Node*)>& leave,
108 const NodeComparator& stable_comparator,
109 const EdgeFilter& edge_filter) {
110 DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
111 }
112
// Depth-first traversal of `g` that walks edges backwards (from each node to
// the sources of its in-edges), rooted at the graph's sink node.
void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
                const std::function<void(Node*)>& leave,
                const NodeComparator& stable_comparator) {
  Node* const root = g.sink_node();
  ReverseDFSFrom(g, {root}, enter, leave, stable_comparator);
}
118
119 namespace {
120
121 template <typename T>
ReverseDFSFromHelper(const Graph & g,gtl::ArraySlice<T> start,const std::function<void (T)> & enter,const std::function<void (T)> & leave,const NodeComparator & stable_comparator)122 void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
123 const std::function<void(T)>& enter,
124 const std::function<void(T)>& leave,
125 const NodeComparator& stable_comparator) {
126 // Stack of work to do.
127 struct Work {
128 T node;
129 bool leave; // Are we entering or leaving n?
130 };
131 std::vector<Work> stack(start.size());
132 for (int i = 0; i < start.size(); ++i) {
133 stack[i] = Work{start[i], false};
134 }
135
136 std::vector<bool> visited(g.num_node_ids(), false);
137 while (!stack.empty()) {
138 Work w = stack.back();
139 stack.pop_back();
140
141 T n = w.node;
142 if (w.leave) {
143 leave(n);
144 continue;
145 }
146
147 if (visited[n->id()]) continue;
148 visited[n->id()] = true;
149 if (enter) enter(n);
150
151 // Arrange to call leave(n) when all done with descendants.
152 if (leave) stack.push_back(Work{n, true});
153
154 auto add_work = [&visited, &stack](T out) {
155 if (!visited[out->id()]) {
156 // Note; we must not mark as visited until we actually process it.
157 stack.push_back(Work{out, false});
158 }
159 };
160
161 if (stable_comparator) {
162 std::vector<T> nodes_sorted;
163 for (const Edge* in_edge : n->in_edges()) {
164 nodes_sorted.emplace_back(in_edge->src());
165 }
166 std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
167 for (T in : nodes_sorted) {
168 add_work(in);
169 }
170 } else {
171 for (const Edge* in_edge : n->in_edges()) {
172 add_work(in_edge->src());
173 }
174 }
175 }
176 }
177
178 } // namespace
179
ReverseDFSFrom(const Graph & g,gtl::ArraySlice<const Node * > start,const std::function<void (const Node *)> & enter,const std::function<void (const Node *)> & leave,const NodeComparator & stable_comparator)180 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
181 const std::function<void(const Node*)>& enter,
182 const std::function<void(const Node*)>& leave,
183 const NodeComparator& stable_comparator) {
184 ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
185 }
186
ReverseDFSFrom(const Graph & g,gtl::ArraySlice<Node * > start,const std::function<void (Node *)> & enter,const std::function<void (Node *)> & leave,const NodeComparator & stable_comparator)187 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
188 const std::function<void(Node*)>& enter,
189 const std::function<void(Node*)>& leave,
190 const NodeComparator& stable_comparator) {
191 ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
192 }
193
// Replaces the contents of `*order` with the nodes reached by DFS() from the
// source node, in post-order: a node is appended only after DFS is done with
// its descendants.
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
                  const NodeComparator& stable_comparator,
                  const EdgeFilter& edge_filter) {
  order->clear();
  const auto append_on_leave = [order](Node* n) { order->push_back(n); };
  DFS(g, /*enter=*/nullptr, /*leave=*/append_on_leave, stable_comparator,
      edge_filter);
}
201
// Replaces the contents of `*order` with the reverse of GetPostOrder() for the
// same comparator and edge filter. For an acyclic graph this is a topological
// order of the nodes reached from the source.
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
                         const NodeComparator& stable_comparator,
                         const EdgeFilter& edge_filter) {
  std::vector<Node*>& result = *order;
  GetPostOrder(g, &result, stable_comparator, edge_filter);
  std::reverse(result.begin(), result.end());
}
208
PruneForReverseReachability(Graph * g,std::unordered_set<const Node * > start)209 bool PruneForReverseReachability(Graph* g,
210 std::unordered_set<const Node*> start) {
211 // Compute set of nodes that we need to traverse in order to reach
212 // the nodes in "start" by performing a breadth-first search from those
213 // nodes, and accumulating the visited nodes.
214 std::vector<bool> visited(g->num_node_ids());
215 for (auto node : start) {
216 visited[node->id()] = true;
217 }
218 std::deque<const Node*> queue(start.begin(), start.end());
219 while (!queue.empty()) {
220 const Node* n = queue.front();
221 queue.pop_front();
222 for (const Node* in : n->in_nodes()) {
223 if (!visited[in->id()]) {
224 visited[in->id()] = true;
225 queue.push_back(in);
226 VLOG(2) << "Reverse reach : " << n->name() << " from " << in->name();
227 }
228 }
229 }
230
231 // Make a pass over the graph to remove nodes not in "visited".
232 bool any_removed = false;
233 for (int i = 0; i < visited.size(); ++i) {
234 if (!visited[i]) {
235 Node* n = g->FindNodeId(i);
236 if (n != nullptr && !n->IsSource() && !n->IsSink()) {
237 g->RemoveNode(n);
238 any_removed = true;
239 }
240 }
241 }
242 return any_removed;
243 }
244
FixupSourceAndSinkEdges(Graph * g)245 bool FixupSourceAndSinkEdges(Graph* g) {
246 // Connect all nodes with no incoming edges to source.
247 // Connect all nodes with no outgoing edges to sink.
248 bool changed = false;
249 for (Node* n : g->nodes()) {
250 if (!n->IsSource() && n->in_edges().empty()) {
251 g->AddControlEdge(g->source_node(), n,
252 true /* skip test for duplicates */);
253 changed = true;
254 }
255 if (!n->IsSink() && n->out_edges().empty()) {
256 g->AddControlEdge(n, g->sink_node(), true /* skip test for duplicates */);
257 changed = true;
258 }
259 }
260 return changed;
261 }
262
263 } // namespace tensorflow
264