• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/scc.h"
17 #include <stack>
18 #include <unordered_map>
19 #include <unordered_set>
20 #include <vector>
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/grappler/op_types.h"
23 #include "tensorflow/core/grappler/utils.h"
24 
25 namespace tensorflow {
26 namespace grappler {
27 
28 // Data structure used to store data for Tarjan's Strongly Connected
29 // Components algorithm.
30 struct SCCNodeData {
SCCNodeDatatensorflow::grappler::SCCNodeData31   SCCNodeData()
32       : node(nullptr),
33         index(-1),
34         lowlink(-1),
35         onstack(false),
36         caller(nullptr),
37         caller_loop_location(-1) {}
ResetStacktensorflow::grappler::SCCNodeData38   void ResetStack(int new_index, SCCNodeData* new_caller) {
39     index = new_index;
40     lowlink = new_index;
41     onstack = true;
42     caller = new_caller;
43     caller_loop_location = 0;
44   }
45   const NodeDef* node;
46   int index;
47   int lowlink;
48   bool onstack;
49   std::vector<SCCNodeData*> children;
50   // StrongConnect "call stack" storage.
51   SCCNodeData* caller;       // Node calling StrongConnect
52   int caller_loop_location;  // Index in parent StrongConnect for loop
53 };
54 
55 // Core DFS step of Tarjan's Strongly Connected Component algorithm
56 // (implemented using iteration instead of recursion).
StrongConnect(SCCNodeData * v,std::stack<SCCNodeData * > * stack,int * index,std::unordered_map<const NodeDef *,int> * components,int * scc_index)57 void StrongConnect(SCCNodeData* v, std::stack<SCCNodeData*>* stack, int* index,
58                    std::unordered_map<const NodeDef*, int>* components,
59                    int* scc_index) {
60   // Iterative version of Tarjan's StrongConnect function.
61   // The "call stack" state is composed of a SCCNodeData's caller and
62   // caller_loop_location properties.
63   v->ResetStack(*index /* index */, nullptr /* caller */);
64   ++*index;
65   stack->push(v);
66 
67   // No one put v on a StrongConnect call stack, reset caller values.
68   v->caller = nullptr;
69   v->caller_loop_location = 0;
70 
71   SCCNodeData* last = v;
72   while (true) {
73     if (last->caller_loop_location < last->children.size()) {
74       // Recursive equivalent: Looping over the children of v (possibly
75       // continuing at v->caller_loop_location after having finished a
76       // recursive call.
77       SCCNodeData* w = last->children[last->caller_loop_location];
78       ++(last->caller_loop_location);  // For loop iterator increment
79       if (w->index == -1) {
80         w->ResetStack(*index /* index */, last /* caller */);
81         ++*index;
82         stack->push(w);
83         last = w;
84       } else if (w->onstack == true) {
85         last->lowlink = std::min(last->lowlink, w->index);
86       }
87     } else {
88       // At the end of v's children
89       if (last->lowlink == last->index) {
90         // v is the root of a strongly connected component
91         SCCNodeData* top;
92         while (true) {
93           top = stack->top();
94           stack->pop();
95           top->onstack = false;
96           (*components)[top->node] = *scc_index;
97           if (top == last) {
98             break;
99           }
100         }
101         ++*scc_index;
102       }
103 
104       // Go up the recursive call stack
105       SCCNodeData* next_last = last->caller;
106       if (next_last == nullptr) {
107         // All nodes have been seen; finished.
108         break;
109       } else {
110         next_last->lowlink = std::min(next_last->lowlink, last->lowlink);
111         last = next_last;
112       }
113     }
114   }
115 }
116 
117 // This is an implementation of Tarjan's Strongly Connected Components
118 // DFS algorithm.  Most of the hard work is done in the function
119 // StrongConnect, which is an iterative reimplementation of the
120 // recursive version described here:
121 //   https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
122 //
123 // The edges for the purpose of this algorithm are directed from input
124 // to op (the reverse of the declarations of the NodeDef, which
125 // contain in-edges)
StronglyConnectedComponents(const GraphDef & graph,std::unordered_map<const NodeDef *,int> * components,int * num_components)126 void StronglyConnectedComponents(
127     const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components,
128     int* num_components) {
129   std::stack<SCCNodeData*> stack;
130   std::unordered_map<string, SCCNodeData*> name_to_data;
131   std::vector<SCCNodeData> node_data_container;
132   node_data_container.reserve(graph.node_size());
133   std::unordered_map<const NodeDef*, SCCNodeData*> node_to_data;
134 
135   for (const NodeDef& node : graph.node()) {
136     SCCNodeData node_data;
137     node_data.node = &node;
138     node_data_container.push_back(node_data);
139     name_to_data[node.name()] = &(*node_data_container.rbegin());
140     node_to_data[&node] = &(*node_data_container.rbegin());
141   }
142 
143   // Create a list of top-level parents (add them to object queue)
144   // Also create a mapping from nodes to their children.
145   // Inputs might not be present if called on a subgraph.
146   for (const NodeDef& node : graph.node()) {
147     for (const string& input : node.input()) {
148       auto it = name_to_data.find(NodeName(input));
149       if (it != name_to_data.end()) {
150         it->second->children.push_back(node_to_data[&node]);
151       }
152     }
153   }
154 
155   components->clear();
156   *num_components = 0;
157   int index = 0;
158   for (auto& v : node_data_container) {
159     if (v.index == -1) {
160       // Node has not yet been visited.  Start a DFS at v.
161       StrongConnect(&v, &stack, &index, components, num_components);
162     }
163   }
164 
165   std::vector<int> counts_per_component(*num_components, 0);
166   for (auto& component : *components) {
167     DCHECK(component.second >= 0);
168     DCHECK(component.second < *num_components);
169     counts_per_component[component.second]++;
170   }
171   bool has_single_element_component = false;
172   for (auto& component : *components) {
173     if (counts_per_component[component.second] == 1) {
174       component.second = -1;
175       (*num_components)--;
176       has_single_element_component = true;
177     }
178   }
179   if (has_single_element_component) {
180     (*num_components) += 1;
181   }
182 }
183 
IdentifyLoops(const GraphDef & graph,std::unordered_map<const NodeDef *,std::vector<int>> * loops)184 int IdentifyLoops(const GraphDef& graph,
185                   std::unordered_map<const NodeDef*, std::vector<int>>* loops) {
186   int num_components = 0;
187   std::unordered_map<const NodeDef*, int> components;
188   StronglyConnectedComponents(graph, &components, &num_components);
189   if (num_components <= 1) {
190     if (!components.empty() && components.begin()->second == -1) {
191       return 0;
192     }
193   }
194 
195   std::unordered_map<int, std::vector<const NodeDef*>> component_ids;
196   for (const auto it : components) {
197     int id = it.second;
198     if (id < 0) {
199       continue;
200     }
201     component_ids[id].push_back(it.first);
202   }
203 
204   int loop_id = 0;
205   for (const auto& component : component_ids) {
206     const std::vector<const NodeDef*>& component_nodes = component.second;
207     std::vector<std::pair<NodeDef*, string>> next_iter_nodes;
208     GraphDef subgraph;
209     std::unordered_map<const NodeDef*, const NodeDef*> subgraph_mapping;
210 
211     for (const auto& component_node : component_nodes) {
212       NodeDef* node = subgraph.add_node();
213       *node = *component_node;
214       subgraph_mapping[node] = component_node;
215       if (IsNextIteration(*node)) {
216         CHECK_EQ(1, node->input_size());
217         next_iter_nodes.emplace_back(node, node->input(0));
218       }
219     }
220     if (next_iter_nodes.size() == 1) {
221       for (const auto& component_node : component_nodes) {
222         (*loops)[component_node].push_back(loop_id);
223       }
224       ++loop_id;
225     } else {
226       for (int i = 0; i < next_iter_nodes.size(); ++i) {
227         for (int j = 0; j < next_iter_nodes.size(); ++j) {
228           next_iter_nodes[j].first->clear_input();
229           if (i == j) {
230             *next_iter_nodes[j].first->add_input() = next_iter_nodes[j].second;
231           }
232         }
233         int num_components = 0;
234         std::unordered_map<const NodeDef*, int> components;
235         StronglyConnectedComponents(subgraph, &components, &num_components);
236         CHECK_GE(num_components, 1);
237         for (const auto it : components) {
238           int id = it.second;
239           if (id < 0) {
240             continue;
241           }
242           (*loops)[subgraph_mapping[it.first]].push_back(loop_id);
243         }
244         ++loop_id;
245       }
246     }
247   }
248 
249   return loop_id;
250 }
251 
252 }  // namespace grappler
253 }  // namespace tensorflow
254