1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <deque>
17 #include <iostream>
18
19 #include "absl/memory/memory.h"
20 #include "absl/strings/str_format.h"
21 #include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
22 #include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
23 #include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
24
25 namespace tensorflow {
26 namespace grappler {
27 namespace graph_analyzer {
28
GraphAnalyzer(const GraphDef & graph,int subgraph_size)29 GraphAnalyzer::GraphAnalyzer(const GraphDef& graph, int subgraph_size)
30 : graph_(graph), subgraph_size_(subgraph_size) {}
31
~GraphAnalyzer()32 GraphAnalyzer::~GraphAnalyzer() {}
33
Run()34 Status GraphAnalyzer::Run() {
35 // The signature computation code would detect this too, but better
36 // to report it up front than spend time computing all the graphs first.
37 if (subgraph_size_ > Signature::kMaxGraphSize) {
38 return Status(error::INVALID_ARGUMENT,
39 absl::StrFormat("Subgraphs of %d nodes are not supported, "
40 "the maximal supported node count is %d.",
41 subgraph_size_, Signature::kMaxGraphSize));
42 }
43
44 Status st = BuildMap();
45 if (!st.ok()) {
46 return st;
47 }
48
49 FindSubgraphs();
50 DropInvalidSubgraphs();
51 st = CollateResult();
52 if (!st.ok()) {
53 return st;
54 }
55
56 return OkStatus();
57 }
58
BuildMap()59 Status GraphAnalyzer::BuildMap() {
60 nodes_.clear();
61 return GenNode::BuildGraphInMap(graph_, &nodes_);
62 }
63
FindSubgraphs()64 void GraphAnalyzer::FindSubgraphs() {
65 result_.clear();
66
67 if (subgraph_size_ < 1) {
68 return;
69 }
70
71 partial_.clear();
72 todo_.clear(); // Just in case.
73
74 // Start with all subgraphs of size 1.
75 const Subgraph::Identity empty_parent;
76 for (const auto& node : nodes_) {
77 if (subgraph_size_ == 1) {
78 result_.ExtendParent(empty_parent, node.second.get());
79 } else {
80 // At this point ExtendParent() is guaranteed to not return nullptr.
81 todo_.push_back(partial_.ExtendParent(empty_parent, node.second.get()));
82 }
83 }
84
85 // Then extend the subgraphs until no more extensions are possible.
86 while (!todo_.empty()) {
87 ExtendSubgraph(todo_.front());
88 todo_.pop_front();
89 }
90
91 partial_.clear();
92 }
93
ExtendSubgraph(Subgraph * parent)94 void GraphAnalyzer::ExtendSubgraph(Subgraph* parent) {
95 const int next_parent_id = parent->id().size() + 1;
96 bool will_complete = (next_parent_id == subgraph_size_);
97 SubgraphPtrSet& sg_set = will_complete ? result_ : partial_;
98
99 const GenNode* last_all_or_none_node = nullptr;
100 for (SubgraphIterator sit(parent); !sit.AtEnd(); sit.Next()) {
101 const GenNode* node = sit.GetNode();
102 GenNode::Port port = sit.GetPort();
103 const GenNode::LinkTarget& neighbor = sit.GetNeighbor();
104
105 if (node->AllInputsOrNone() && port.IsInbound() && !port.IsControl()) {
106 if (node != last_all_or_none_node) {
107 ExtendSubgraphAllOrNone(parent, node);
108 last_all_or_none_node = node;
109 }
110 sit.SkipPort();
111 } else if (neighbor.node->AllInputsOrNone() && !port.IsInbound() &&
112 !port.IsControl()) {
113 if (parent->id().find(neighbor.node) == parent->id().end()) {
114 // Not added yet.
115 ExtendSubgraphAllOrNone(parent, neighbor.node);
116 }
117 } else if (node->IsMultiInput(port)) {
118 ExtendSubgraphPortAllOrNone(parent, node, port);
119 sit.SkipPort();
120 } else if (neighbor.node->IsMultiInput(neighbor.port)) {
121 // Would need to add all inputs of the neighbor node at this port at
122 // once.
123 if (parent->id().find(neighbor.node) != parent->id().end()) {
124 continue; // Already added.
125 }
126 ExtendSubgraphPortAllOrNone(parent, neighbor.node, neighbor.port);
127 } else {
128 Subgraph* sg = sg_set.ExtendParent(parent->id(), neighbor.node);
129 if (!will_complete && sg != nullptr) {
130 todo_.push_back(sg);
131 }
132 }
133 }
134 }
135
ExtendSubgraphAllOrNone(Subgraph * parent,const GenNode * node)136 void GraphAnalyzer::ExtendSubgraphAllOrNone(Subgraph* parent,
137 const GenNode* node) {
138 Subgraph::Identity id = parent->id();
139 id.insert(node);
140
141 auto range_end = node->links().end();
142
143 for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) {
144 auto port = nbit->first;
145 if (!port.IsInbound() || port.IsControl()) {
146 continue;
147 }
148
149 // Since there might be multiple links to the same nodes,
150 // have to add all links one-by-one to check whether the subgraph
151 // would grow too large. But if it does grow too large, there is no
152 // point in growing it more, can just skip over the rest of the links.
153 for (const auto& link : nbit->second) {
154 id.insert(link.node);
155 const int id_size = id.size();
156 if (id_size > subgraph_size_) {
157 return; // Too big.
158 }
159 }
160 }
161
162 AddExtendedSubgraph(parent, id);
163 }
164
ExtendSubgraphPortAllOrNone(Subgraph * parent,const GenNode * node,GenNode::Port port)165 void GraphAnalyzer::ExtendSubgraphPortAllOrNone(Subgraph* parent,
166 const GenNode* node,
167 GenNode::Port port) {
168 auto nbit = node->links().find(port);
169 if (nbit == node->links().end()) {
170 return; // Should never happen.
171 }
172
173 Subgraph::Identity id = parent->id();
174 id.insert(node);
175
176 // Since there might be multiple links to the same nodes,
177 // have to add all links one-by-one to check whether the subgraph
178 // would grow too large. But if it does grow too large, there is no
179 // point in growing it more, can just skip over the rest of the links.
180 for (const auto& link : nbit->second) {
181 id.insert(link.node);
182 const int id_size = id.size();
183 if (id_size > subgraph_size_) {
184 return; // Too big.
185 }
186 }
187
188 AddExtendedSubgraph(parent, id);
189 }
190
AddExtendedSubgraph(Subgraph * parent,const Subgraph::Identity & id)191 void GraphAnalyzer::AddExtendedSubgraph(Subgraph* parent,
192 const Subgraph::Identity& id) {
193 if (id.size() == parent->id().size()) {
194 return; // Nothing new was added.
195 }
196
197 auto sg = std::make_unique<Subgraph>(id);
198 SubgraphPtrSet& spec_sg_set =
199 (id.size() == subgraph_size_) ? result_ : partial_;
200 if (spec_sg_set.find(sg) != spec_sg_set.end()) {
201 // This subgraph was already found by extending from a different path.
202 return;
203 }
204 const int id_size = id.size();
205 if (id_size != subgraph_size_) {
206 todo_.push_back(sg.get());
207 }
208 spec_sg_set.insert(std::move(sg));
209 }
210
DropInvalidSubgraphs()211 void GraphAnalyzer::DropInvalidSubgraphs() {
212 auto resit = result_.begin();
213 while (resit != result_.end()) {
214 if (HasInvalidMultiInputs(resit->get())) {
215 auto delit = resit;
216 ++resit;
217 result_.erase(delit);
218 } else {
219 ++resit;
220 }
221 }
222 }
223
HasInvalidMultiInputs(Subgraph * sg)224 bool GraphAnalyzer::HasInvalidMultiInputs(Subgraph* sg) {
225 // Do the all-or-none-input nodes.
226 for (auto const& node : sg->id()) {
227 if (!node->AllInputsOrNone()) {
228 continue;
229 }
230
231 bool anyIn = false;
232 bool anyOut = false;
233
234 auto range_end = node->links().end();
235 for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) {
236 auto port = nbit->first;
237 if (!port.IsInbound() || port.IsControl()) {
238 continue;
239 }
240
241 // Since there might be multiple links to the same nodes,
242 // have to add all links one-by-one to check whether the subgraph
243 // would grow too large. But if it does grow too large, there is no
244 // point in growing it more, can just skip over the rest of the links.
245 for (const auto& link : nbit->second) {
246 if (sg->id().find(link.node) == sg->id().end()) {
247 anyOut = true;
248 } else {
249 anyIn = true;
250 }
251 }
252 }
253
254 if (anyIn && anyOut) {
255 return true;
256 }
257 }
258
259 // Do the multi-input ports.
260 for (SubgraphIterator sit(sg); !sit.AtEnd(); sit.Next()) {
261 if (sit.GetNode()->IsMultiInput(sit.GetPort())) {
262 bool anyIn = false;
263 bool anyOut = false;
264 do {
265 GenNode* peer = sit.GetNeighbor().node;
266 if (sg->id().find(peer) == sg->id().end()) {
267 anyOut = true;
268 } else {
269 anyIn = true;
270 }
271 } while (sit.NextIfSamePort());
272
273 if (anyIn && anyOut) {
274 return true;
275 }
276 }
277 }
278 return false;
279 }
280
CollateResult()281 Status GraphAnalyzer::CollateResult() {
282 ordered_collation_.clear();
283 collation_map_.clear();
284
285 // Collate by the signatures of the graphs.
286 for (const auto& it : result_) {
287 auto sig = std::make_unique<Signature>();
288 it->ExtractForSignature(&sig->map);
289 Status status = sig->Compute();
290 if (!status.ok()) {
291 return status;
292 }
293
294 auto& coll_entry = collation_map_[sig.get()];
295 if (coll_entry.sig == nullptr) {
296 coll_entry.sig = std::move(sig);
297 }
298 ++coll_entry.count;
299 }
300
301 // Then order them by the count.
302 for (auto& entry : collation_map_) {
303 ordered_collation_.insert(&entry.second);
304 }
305
306 result_.clear(); // Not needed after collation.
307
308 return OkStatus();
309 }
310
DumpRawSubgraphs()311 std::vector<string> GraphAnalyzer::DumpRawSubgraphs() {
312 std::vector<string> result;
313 for (const auto& it : result_) {
314 result.emplace_back(it->Dump());
315 }
316 return result;
317 }
318
DumpSubgraphs()319 std::vector<string> GraphAnalyzer::DumpSubgraphs() {
320 std::vector<string> result;
321 for (auto ptr : ordered_collation_) {
322 result.emplace_back(
323 absl::StrFormat("%d %s", ptr->count, ptr->sig->ToString()));
324 }
325 return result;
326 }
327
OutputSubgraphs()328 Status GraphAnalyzer::OutputSubgraphs() {
329 size_t total = 0;
330 for (auto ptr : ordered_collation_) {
331 std::cout << ptr->count << ' ' << ptr->sig->ToString() << '\n';
332 total += ptr->count;
333 }
334 std::cout << "Total: " << total << '\n';
335 if (std::cout.fail()) {
336 return Status(error::DATA_LOSS, "Failed to write to stdout");
337 } else {
338 return OkStatus();
339 }
340 }
341
342 } // end namespace graph_analyzer
343 } // end namespace grappler
344 } // end namespace tensorflow
345