1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/utils/scc.h"
17 #include <stack>
18 #include <unordered_map>
19 #include <unordered_set>
20 #include <vector>
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/grappler/op_types.h"
23 #include "tensorflow/core/grappler/utils.h"
24
25 namespace tensorflow {
26 namespace grappler {
27
28 // Data structure used to store data for Tarjan's Strongly Connected
29 // Components algorithm.
30 struct SCCNodeData {
SCCNodeDatatensorflow::grappler::SCCNodeData31 SCCNodeData()
32 : node(nullptr),
33 index(-1),
34 lowlink(-1),
35 onstack(false),
36 caller(nullptr),
37 caller_loop_location(-1) {}
ResetStacktensorflow::grappler::SCCNodeData38 void ResetStack(int new_index, SCCNodeData* new_caller) {
39 index = new_index;
40 lowlink = new_index;
41 onstack = true;
42 caller = new_caller;
43 caller_loop_location = 0;
44 }
45 const NodeDef* node;
46 int index;
47 int lowlink;
48 bool onstack;
49 std::vector<SCCNodeData*> children;
50 // StrongConnect "call stack" storage.
51 SCCNodeData* caller; // Node calling StrongConnect
52 int caller_loop_location; // Index in parent StrongConnect for loop
53 };
54
55 // Core DFS step of Tarjan's Strongly Connected Component algorithm
56 // (implemented using iteration instead of recursion).
StrongConnect(SCCNodeData * v,std::stack<SCCNodeData * > * stack,int * index,std::unordered_map<const NodeDef *,int> * components,int * scc_index)57 void StrongConnect(SCCNodeData* v, std::stack<SCCNodeData*>* stack, int* index,
58 std::unordered_map<const NodeDef*, int>* components,
59 int* scc_index) {
60 // Iterative version of Tarjan's StrongConnect function.
61 // The "call stack" state is composed of a SCCNodeData's caller and
62 // caller_loop_location properties.
63 v->ResetStack(*index /* index */, nullptr /* caller */);
64 ++*index;
65 stack->push(v);
66
67 // No one put v on a StrongConnect call stack, reset caller values.
68 v->caller = nullptr;
69 v->caller_loop_location = 0;
70
71 SCCNodeData* last = v;
72 while (true) {
73 if (last->caller_loop_location < last->children.size()) {
74 // Recursive equivalent: Looping over the children of v (possibly
75 // continuing at v->caller_loop_location after having finished a
76 // recursive call.
77 SCCNodeData* w = last->children[last->caller_loop_location];
78 ++(last->caller_loop_location); // For loop iterator increment
79 if (w->index == -1) {
80 w->ResetStack(*index /* index */, last /* caller */);
81 ++*index;
82 stack->push(w);
83 last = w;
84 } else if (w->onstack == true) {
85 last->lowlink = std::min(last->lowlink, w->index);
86 }
87 } else {
88 // At the end of v's children
89 if (last->lowlink == last->index) {
90 // v is the root of a strongly connected component
91 SCCNodeData* top;
92 while (true) {
93 top = stack->top();
94 stack->pop();
95 top->onstack = false;
96 (*components)[top->node] = *scc_index;
97 if (top == last) {
98 break;
99 }
100 }
101 ++*scc_index;
102 }
103
104 // Go up the recursive call stack
105 SCCNodeData* next_last = last->caller;
106 if (next_last == nullptr) {
107 // All nodes have been seen; finished.
108 break;
109 } else {
110 next_last->lowlink = std::min(next_last->lowlink, last->lowlink);
111 last = next_last;
112 }
113 }
114 }
115 }
116
117 // This is an implementation of Tarjan's Strongly Connected Components
118 // DFS algorithm. Most of the hard work is done in the function
119 // StrongConnect, which is an iterative reimplementation of the
120 // recursive version described here:
121 // https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
122 //
123 // The edges for the purpose of this algorithm are directed from input
124 // to op (the reverse of the declarations of the NodeDef, which
125 // contain in-edges)
StronglyConnectedComponents(const GraphDef & graph,std::unordered_map<const NodeDef *,int> * components,int * num_components)126 void StronglyConnectedComponents(
127 const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components,
128 int* num_components) {
129 std::stack<SCCNodeData*> stack;
130 std::unordered_map<string, SCCNodeData*> name_to_data;
131 std::vector<SCCNodeData> node_data_container;
132 node_data_container.reserve(graph.node_size());
133 std::unordered_map<const NodeDef*, SCCNodeData*> node_to_data;
134
135 for (const NodeDef& node : graph.node()) {
136 SCCNodeData node_data;
137 node_data.node = &node;
138 node_data_container.push_back(node_data);
139 name_to_data[node.name()] = &(*node_data_container.rbegin());
140 node_to_data[&node] = &(*node_data_container.rbegin());
141 }
142
143 // Create a list of top-level parents (add them to object queue)
144 // Also create a mapping from nodes to their children.
145 // Inputs might not be present if called on a subgraph.
146 for (const NodeDef& node : graph.node()) {
147 for (const string& input : node.input()) {
148 auto it = name_to_data.find(NodeName(input));
149 if (it != name_to_data.end()) {
150 it->second->children.push_back(node_to_data[&node]);
151 }
152 }
153 }
154
155 components->clear();
156 *num_components = 0;
157 int index = 0;
158 for (auto& v : node_data_container) {
159 if (v.index == -1) {
160 // Node has not yet been visited. Start a DFS at v.
161 StrongConnect(&v, &stack, &index, components, num_components);
162 }
163 }
164
165 std::vector<int> counts_per_component(*num_components, 0);
166 for (auto& component : *components) {
167 DCHECK(component.second >= 0);
168 DCHECK(component.second < *num_components);
169 counts_per_component[component.second]++;
170 }
171 bool has_single_element_component = false;
172 for (auto& component : *components) {
173 if (counts_per_component[component.second] == 1) {
174 component.second = -1;
175 (*num_components)--;
176 has_single_element_component = true;
177 }
178 }
179 if (has_single_element_component) {
180 (*num_components) += 1;
181 }
182 }
183
IdentifyLoops(const GraphDef & graph,std::unordered_map<const NodeDef *,std::vector<int>> * loops)184 int IdentifyLoops(const GraphDef& graph,
185 std::unordered_map<const NodeDef*, std::vector<int>>* loops) {
186 int num_components = 0;
187 std::unordered_map<const NodeDef*, int> components;
188 StronglyConnectedComponents(graph, &components, &num_components);
189 if (num_components <= 1) {
190 if (!components.empty() && components.begin()->second == -1) {
191 return 0;
192 }
193 }
194
195 std::unordered_map<int, std::vector<const NodeDef*>> component_ids;
196 for (const auto it : components) {
197 int id = it.second;
198 if (id < 0) {
199 continue;
200 }
201 component_ids[id].push_back(it.first);
202 }
203
204 int loop_id = 0;
205 for (const auto& component : component_ids) {
206 const std::vector<const NodeDef*>& component_nodes = component.second;
207 std::vector<std::pair<NodeDef*, string>> next_iter_nodes;
208 GraphDef subgraph;
209 std::unordered_map<const NodeDef*, const NodeDef*> subgraph_mapping;
210
211 for (const auto& component_node : component_nodes) {
212 NodeDef* node = subgraph.add_node();
213 *node = *component_node;
214 subgraph_mapping[node] = component_node;
215 if (IsNextIteration(*node)) {
216 CHECK_EQ(1, node->input_size());
217 next_iter_nodes.emplace_back(node, node->input(0));
218 }
219 }
220 if (next_iter_nodes.size() == 1) {
221 for (const auto& component_node : component_nodes) {
222 (*loops)[component_node].push_back(loop_id);
223 }
224 ++loop_id;
225 } else {
226 for (int i = 0; i < next_iter_nodes.size(); ++i) {
227 for (int j = 0; j < next_iter_nodes.size(); ++j) {
228 next_iter_nodes[j].first->clear_input();
229 if (i == j) {
230 *next_iter_nodes[j].first->add_input() = next_iter_nodes[j].second;
231 }
232 }
233 int num_components = 0;
234 std::unordered_map<const NodeDef*, int> components;
235 StronglyConnectedComponents(subgraph, &components, &num_components);
236 CHECK_GE(num_components, 1);
237 for (const auto it : components) {
238 int id = it.second;
239 if (id < 0) {
240 continue;
241 }
242 (*loops)[subgraph_mapping[it.first]].push_back(loop_id);
243 }
244 ++loop_id;
245 }
246 }
247 }
248
249 return loop_id;
250 }
251
252 } // namespace grappler
253 } // namespace tensorflow
254