1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/utils/topological_sort.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <unordered_map>
21
22 #include "absl/types/span.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/grappler/graph_topology_view.h"
25 #include "tensorflow/core/grappler/graph_view.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/lib/core/status.h"
29
30 namespace tensorflow {
31 namespace grappler {
32
33 namespace {
34
MakeEphemeralEdges(const absl::Span<const TopologicalDependency> extra_dependencies)35 std::vector<GraphView::Edge> MakeEphemeralEdges(
36 const absl::Span<const TopologicalDependency> extra_dependencies) {
37 std::vector<GraphView::Edge> ephemeral_edges;
38 ephemeral_edges.reserve(extra_dependencies.size());
39 for (const auto& dep : extra_dependencies) {
40 ephemeral_edges.emplace_back(
41 GraphView::OutputPort(dep.from, Graph::kControlSlot),
42 GraphView::InputPort(dep.to, Graph::kControlSlot));
43 }
44 return ephemeral_edges;
45 }
46
47 // Kahn's algorithm is implemented.
48 // For details, see https://en.wikipedia.org/wiki/Topological_sorting
ComputeTopologicalOrder(const GraphDef & graph,const absl::Span<const TopologicalDependency> extra_dependencies,std::vector<int> * ready_nodes)49 Status ComputeTopologicalOrder(
50 const GraphDef& graph,
51 const absl::Span<const TopologicalDependency> extra_dependencies,
52 std::vector<int>* ready_nodes) {
53 GraphTopologyView graph_view;
54 TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(
55 graph, MakeEphemeralEdges(extra_dependencies)));
56
57 // Keep track of how many inputs are ready for the given node.
58 std::vector<int> num_ready_inputs(graph.node_size(), 0);
59
60 // We'll push index of ready nodes to this output vector.
61 ready_nodes->reserve(graph.node_size());
62
63 int front = 0;
64 int back = 0;
65
66 for (int i = 0; i < graph.node_size(); i++) {
67 if (graph_view.GetFanin(i).empty()) {
68 ready_nodes->push_back(i);
69 back++;
70 }
71 if (IsMerge(graph.node(i))) {
72 for (int input : graph_view.GetFanin(i)) {
73 if (IsNextIteration(graph.node(input))) {
74 num_ready_inputs[i]++;
75 }
76 }
77 }
78 }
79
80 while (front != back) {
81 int ready_node = (*ready_nodes)[front];
82 for (int fanout : graph_view.GetFanout(ready_node)) {
83 ++num_ready_inputs[fanout];
84 if (num_ready_inputs[fanout] == graph_view.GetFanin(fanout).size()) {
85 ready_nodes->push_back(fanout);
86 ++back;
87 }
88 }
89 ++front;
90 }
91
92 if (back != graph_view.num_nodes()) {
93 return errors::InvalidArgument(
94 "The graph couldn't be sorted in topological order.");
95 }
96 return Status::OK();
97 }
98
99 } // namespace
100
ComputeTopologicalOrder(const GraphDef & graph,const absl::Span<const TopologicalDependency> extra_dependencies,std::vector<const NodeDef * > * topo_order)101 Status ComputeTopologicalOrder(
102 const GraphDef& graph,
103 const absl::Span<const TopologicalDependency> extra_dependencies,
104 std::vector<const NodeDef*>* topo_order) {
105 std::vector<int> ready_nodes;
106 TF_RETURN_IF_ERROR(
107 ComputeTopologicalOrder(graph, extra_dependencies, &ready_nodes));
108
109 topo_order->reserve(ready_nodes.size());
110 for (int ready_node_idx : ready_nodes) {
111 topo_order->emplace_back(&graph.node(ready_node_idx));
112 }
113
114 return Status::OK();
115 }
116
ComputeTopologicalOrder(const GraphDef & graph,std::vector<const NodeDef * > * topo_order)117 Status ComputeTopologicalOrder(const GraphDef& graph,
118 std::vector<const NodeDef*>* topo_order) {
119 return ComputeTopologicalOrder(graph, {}, topo_order);
120 }
121
ReversedTopologicalSort(GraphDef * graph)122 Status ReversedTopologicalSort(GraphDef* graph) {
123 std::vector<int> ready_nodes;
124 TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes));
125 std::reverse(ready_nodes.begin(), ready_nodes.end());
126 PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
127 return Status::OK();
128 }
129
TopologicalSort(GraphDef * graph)130 Status TopologicalSort(GraphDef* graph) {
131 std::vector<int> ready_nodes;
132 TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes));
133 PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
134 return Status::OK();
135 }
136
137 } // namespace grappler
138 } // namespace tensorflow
139