• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/topological_sort.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <unordered_map>
21 
22 #include "absl/types/span.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/grappler/graph_topology_view.h"
25 #include "tensorflow/core/grappler/graph_view.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/lib/core/status.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 
33 namespace {
34 
MakeEphemeralEdges(const absl::Span<const TopologicalDependency> extra_dependencies)35 std::vector<GraphView::Edge> MakeEphemeralEdges(
36     const absl::Span<const TopologicalDependency> extra_dependencies) {
37   std::vector<GraphView::Edge> ephemeral_edges;
38   ephemeral_edges.reserve(extra_dependencies.size());
39   for (const auto& dep : extra_dependencies) {
40     ephemeral_edges.emplace_back(
41         GraphView::OutputPort(dep.from, Graph::kControlSlot),
42         GraphView::InputPort(dep.to, Graph::kControlSlot));
43   }
44   return ephemeral_edges;
45 }
46 
47 // Kahn's algorithm is implemented.
48 // For details, see https://en.wikipedia.org/wiki/Topological_sorting
ComputeTopologicalOrder(const GraphDef & graph,const absl::Span<const TopologicalDependency> extra_dependencies,std::vector<int> * ready_nodes)49 Status ComputeTopologicalOrder(
50     const GraphDef& graph,
51     const absl::Span<const TopologicalDependency> extra_dependencies,
52     std::vector<int>* ready_nodes) {
53   GraphTopologyView graph_view;
54   TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(
55       graph, MakeEphemeralEdges(extra_dependencies)));
56 
57   // Keep track of how many inputs are ready for the given node.
58   std::vector<int> num_ready_inputs(graph.node_size(), 0);
59 
60   // We'll push index of ready nodes to this output vector.
61   ready_nodes->reserve(graph.node_size());
62 
63   int front = 0;
64   int back = 0;
65 
66   for (int i = 0; i < graph.node_size(); i++) {
67     if (graph_view.GetFanin(i).empty()) {
68       ready_nodes->push_back(i);
69       back++;
70     }
71     if (IsMerge(graph.node(i))) {
72       for (int input : graph_view.GetFanin(i)) {
73         if (IsNextIteration(graph.node(input))) {
74           num_ready_inputs[i]++;
75         }
76       }
77     }
78   }
79 
80   while (front != back) {
81     int ready_node = (*ready_nodes)[front];
82     for (int fanout : graph_view.GetFanout(ready_node)) {
83       ++num_ready_inputs[fanout];
84       if (num_ready_inputs[fanout] == graph_view.GetFanin(fanout).size()) {
85         ready_nodes->push_back(fanout);
86         ++back;
87       }
88     }
89     ++front;
90   }
91 
92   if (back != graph_view.num_nodes()) {
93     return errors::InvalidArgument(
94         "The graph couldn't be sorted in topological order.");
95   }
96   return Status::OK();
97 }
98 
99 }  // namespace
100 
ComputeTopologicalOrder(const GraphDef & graph,const absl::Span<const TopologicalDependency> extra_dependencies,std::vector<const NodeDef * > * topo_order)101 Status ComputeTopologicalOrder(
102     const GraphDef& graph,
103     const absl::Span<const TopologicalDependency> extra_dependencies,
104     std::vector<const NodeDef*>* topo_order) {
105   std::vector<int> ready_nodes;
106   TF_RETURN_IF_ERROR(
107       ComputeTopologicalOrder(graph, extra_dependencies, &ready_nodes));
108 
109   topo_order->reserve(ready_nodes.size());
110   for (int ready_node_idx : ready_nodes) {
111     topo_order->emplace_back(&graph.node(ready_node_idx));
112   }
113 
114   return Status::OK();
115 }
116 
ComputeTopologicalOrder(const GraphDef & graph,std::vector<const NodeDef * > * topo_order)117 Status ComputeTopologicalOrder(const GraphDef& graph,
118                                std::vector<const NodeDef*>* topo_order) {
119   return ComputeTopologicalOrder(graph, {}, topo_order);
120 }
121 
ReversedTopologicalSort(GraphDef * graph)122 Status ReversedTopologicalSort(GraphDef* graph) {
123   std::vector<int> ready_nodes;
124   TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes));
125   std::reverse(ready_nodes.begin(), ready_nodes.end());
126   PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
127   return Status::OK();
128 }
129 
TopologicalSort(GraphDef * graph)130 Status TopologicalSort(GraphDef* graph) {
131   std::vector<int> ready_nodes;
132   TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes));
133   PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
134   return Status::OK();
135 }
136 
137 }  // namespace grappler
138 }  // namespace tensorflow
139