• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/algorithm.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <vector>
21 
22 #include "tensorflow/core/platform/logging.h"
23 
24 namespace tensorflow {
25 namespace {
26 template <typename T>
DFSFromHelper(const Graph & g,gtl::ArraySlice<T> start,const std::function<void (T)> & enter,const std::function<void (T)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)27 void DFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
28                    const std::function<void(T)>& enter,
29                    const std::function<void(T)>& leave,
30                    const NodeComparator& stable_comparator,
31                    const EdgeFilter& edge_filter) {
32   // Stack of work to do.
33   struct Work {
34     T node;
35     bool leave;  // Are we entering or leaving n?
36   };
37   std::vector<Work> stack(start.size());
38   for (int i = 0; i < start.size(); ++i) {
39     stack[i] = Work{start[i], false};
40   }
41 
42   std::vector<bool> visited(g.num_node_ids(), false);
43   while (!stack.empty()) {
44     Work w = stack.back();
45     stack.pop_back();
46 
47     T n = w.node;
48     if (w.leave) {
49       leave(n);
50       continue;
51     }
52 
53     if (visited[n->id()]) continue;
54     visited[n->id()] = true;
55     if (enter) enter(n);
56 
57     // Arrange to call leave(n) when all done with descendants.
58     if (leave) stack.push_back(Work{n, true});
59 
60     auto add_work = [&visited, &stack](Node* out) {
61       if (!visited[out->id()]) {
62         // Note; we must not mark as visited until we actually process it.
63         stack.push_back(Work{out, false});
64       }
65     };
66 
67     if (stable_comparator) {
68       std::vector<Node*> nodes_sorted;
69       for (const Edge* out_edge : n->out_edges()) {
70         if (!edge_filter || edge_filter(*out_edge)) {
71           nodes_sorted.emplace_back(out_edge->dst());
72         }
73       }
74       std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
75       for (Node* out : nodes_sorted) {
76         add_work(out);
77       }
78     } else {
79       for (const Edge* out_edge : n->out_edges()) {
80         if (!edge_filter || edge_filter(*out_edge)) {
81           add_work(out_edge->dst());
82         }
83       }
84     }
85   }
86 }
87 }  // namespace
88 
// Forward depth-first search over the whole graph, rooted at the graph's
// source node. Callbacks, ordering and edge filtering are forwarded unchanged
// to DFSFromHelper.
void DFS(const Graph& g, const std::function<void(Node*)>& enter,
         const std::function<void(Node*)>& leave,
         const NodeComparator& stable_comparator,
         const EdgeFilter& edge_filter) {
  Node* const root = g.source_node();
  DFSFromHelper<Node*>(g, {root}, enter, leave, stable_comparator,
                       edge_filter);
}
96 
DFSFrom(const Graph & g,gtl::ArraySlice<Node * > start,const std::function<void (Node *)> & enter,const std::function<void (Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)97 void DFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
98              const std::function<void(Node*)>& enter,
99              const std::function<void(Node*)>& leave,
100              const NodeComparator& stable_comparator,
101              const EdgeFilter& edge_filter) {
102   DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
103 }
104 
DFSFrom(const Graph & g,gtl::ArraySlice<const Node * > start,const std::function<void (const Node *)> & enter,const std::function<void (const Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)105 void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
106              const std::function<void(const Node*)>& enter,
107              const std::function<void(const Node*)>& leave,
108              const NodeComparator& stable_comparator,
109              const EdgeFilter& edge_filter) {
110   DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
111 }
112 
// Reverse depth-first search (along in-edges) over the whole graph, rooted at
// the graph's sink node.
void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
                const std::function<void(Node*)>& leave,
                const NodeComparator& stable_comparator) {
  Node* const root = g.sink_node();
  ReverseDFSFrom(g, {root}, enter, leave, stable_comparator);
}
118 
119 namespace {
120 
121 template <typename T>
ReverseDFSFromHelper(const Graph & g,gtl::ArraySlice<T> start,const std::function<void (T)> & enter,const std::function<void (T)> & leave,const NodeComparator & stable_comparator)122 void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
123                           const std::function<void(T)>& enter,
124                           const std::function<void(T)>& leave,
125                           const NodeComparator& stable_comparator) {
126   // Stack of work to do.
127   struct Work {
128     T node;
129     bool leave;  // Are we entering or leaving n?
130   };
131   std::vector<Work> stack(start.size());
132   for (int i = 0; i < start.size(); ++i) {
133     stack[i] = Work{start[i], false};
134   }
135 
136   std::vector<bool> visited(g.num_node_ids(), false);
137   while (!stack.empty()) {
138     Work w = stack.back();
139     stack.pop_back();
140 
141     T n = w.node;
142     if (w.leave) {
143       leave(n);
144       continue;
145     }
146 
147     if (visited[n->id()]) continue;
148     visited[n->id()] = true;
149     if (enter) enter(n);
150 
151     // Arrange to call leave(n) when all done with descendants.
152     if (leave) stack.push_back(Work{n, true});
153 
154     auto add_work = [&visited, &stack](T out) {
155       if (!visited[out->id()]) {
156         // Note; we must not mark as visited until we actually process it.
157         stack.push_back(Work{out, false});
158       }
159     };
160 
161     if (stable_comparator) {
162       std::vector<T> nodes_sorted;
163       for (const Edge* in_edge : n->in_edges()) {
164         nodes_sorted.emplace_back(in_edge->src());
165       }
166       std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
167       for (T in : nodes_sorted) {
168         add_work(in);
169       }
170     } else {
171       for (const Edge* in_edge : n->in_edges()) {
172         add_work(in_edge->src());
173       }
174     }
175   }
176 }
177 
178 }  // namespace
179 
ReverseDFSFrom(const Graph & g,gtl::ArraySlice<const Node * > start,const std::function<void (const Node *)> & enter,const std::function<void (const Node *)> & leave,const NodeComparator & stable_comparator)180 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
181                     const std::function<void(const Node*)>& enter,
182                     const std::function<void(const Node*)>& leave,
183                     const NodeComparator& stable_comparator) {
184   ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
185 }
186 
ReverseDFSFrom(const Graph & g,gtl::ArraySlice<Node * > start,const std::function<void (Node *)> & enter,const std::function<void (Node *)> & leave,const NodeComparator & stable_comparator)187 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
188                     const std::function<void(Node*)>& enter,
189                     const std::function<void(Node*)>& leave,
190                     const NodeComparator& stable_comparator) {
191   ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
192 }
193 
// Clears `*order`, then fills it with the nodes reached by DFS from the
// source node, in post-order: a node is appended when its DFS `leave`
// callback fires, i.e. after the nodes explored beneath it.
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
                  const NodeComparator& stable_comparator,
                  const EdgeFilter& edge_filter) {
  order->clear();
  // Only the leave callback is needed; nothing happens on enter.
  const std::function<void(Node*)> record_finished = [order](Node* node) {
    order->push_back(node);
  };
  DFS(g, /*enter=*/nullptr, /*leave=*/record_finished, stable_comparator,
      edge_filter);
}
201 
// Fills `*order` with the DFS post-order computed by GetPostOrder, reversed.
// Any previous contents of `*order` are discarded.
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
                         const NodeComparator& stable_comparator,
                         const EdgeFilter& edge_filter) {
  std::vector<Node*>& nodes = *order;
  GetPostOrder(g, &nodes, stable_comparator, edge_filter);
  std::reverse(nodes.begin(), nodes.end());
}
208 
PruneForReverseReachability(Graph * g,std::unordered_set<const Node * > start)209 bool PruneForReverseReachability(Graph* g,
210                                  std::unordered_set<const Node*> start) {
211   // Compute set of nodes that we need to traverse in order to reach
212   // the nodes in "start" by performing a breadth-first search from those
213   // nodes, and accumulating the visited nodes.
214   std::vector<bool> visited(g->num_node_ids());
215   for (auto node : start) {
216     visited[node->id()] = true;
217   }
218   std::deque<const Node*> queue(start.begin(), start.end());
219   while (!queue.empty()) {
220     const Node* n = queue.front();
221     queue.pop_front();
222     for (const Node* in : n->in_nodes()) {
223       if (!visited[in->id()]) {
224         visited[in->id()] = true;
225         queue.push_back(in);
226         VLOG(2) << "Reverse reach : " << n->name() << " from " << in->name();
227       }
228     }
229   }
230 
231   // Make a pass over the graph to remove nodes not in "visited".
232   bool any_removed = false;
233   for (int i = 0; i < visited.size(); ++i) {
234     if (!visited[i]) {
235       Node* n = g->FindNodeId(i);
236       if (n != nullptr && !n->IsSource() && !n->IsSink()) {
237         g->RemoveNode(n);
238         any_removed = true;
239       }
240     }
241   }
242   return any_removed;
243 }
244 
FixupSourceAndSinkEdges(Graph * g)245 bool FixupSourceAndSinkEdges(Graph* g) {
246   // Connect all nodes with no incoming edges to source.
247   // Connect all nodes with no outgoing edges to sink.
248   bool changed = false;
249   for (Node* n : g->nodes()) {
250     if (!n->IsSource() && n->in_edges().empty()) {
251       g->AddControlEdge(g->source_node(), n,
252                         true /* skip test for duplicates */);
253       changed = true;
254     }
255     if (!n->IsSink() && n->out_edges().empty()) {
256       g->AddControlEdge(n, g->sink_node(), true /* skip test for duplicates */);
257       changed = true;
258     }
259   }
260   return changed;
261 }
262 
263 }  // namespace tensorflow
264