• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <deque>
17 #include <iostream>
18 
19 #include "absl/memory/memory.h"
20 #include "absl/strings/str_format.h"
21 #include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
22 #include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
23 #include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
24 
25 namespace tensorflow {
26 namespace grappler {
27 namespace graph_analyzer {
28 
GraphAnalyzer(const GraphDef & graph,int subgraph_size)29 GraphAnalyzer::GraphAnalyzer(const GraphDef& graph, int subgraph_size)
30     : graph_(graph), subgraph_size_(subgraph_size) {}
31 
~GraphAnalyzer()32 GraphAnalyzer::~GraphAnalyzer() {}
33 
Run()34 Status GraphAnalyzer::Run() {
35   // The signature computation code would detect this too, but better
36   // to report it up front than spend time computing all the graphs first.
37   if (subgraph_size_ > Signature::kMaxGraphSize) {
38     return Status(error::INVALID_ARGUMENT,
39                   absl::StrFormat("Subgraphs of %d nodes are not supported, "
40                                   "the maximal supported node count is %d.",
41                                   subgraph_size_, Signature::kMaxGraphSize));
42   }
43 
44   Status st = BuildMap();
45   if (!st.ok()) {
46     return st;
47   }
48 
49   FindSubgraphs();
50   DropInvalidSubgraphs();
51   st = CollateResult();
52   if (!st.ok()) {
53     return st;
54   }
55 
56   return OkStatus();
57 }
58 
BuildMap()59 Status GraphAnalyzer::BuildMap() {
60   nodes_.clear();
61   return GenNode::BuildGraphInMap(graph_, &nodes_);
62 }
63 
FindSubgraphs()64 void GraphAnalyzer::FindSubgraphs() {
65   result_.clear();
66 
67   if (subgraph_size_ < 1) {
68     return;
69   }
70 
71   partial_.clear();
72   todo_.clear();  // Just in case.
73 
74   // Start with all subgraphs of size 1.
75   const Subgraph::Identity empty_parent;
76   for (const auto& node : nodes_) {
77     if (subgraph_size_ == 1) {
78       result_.ExtendParent(empty_parent, node.second.get());
79     } else {
80       // At this point ExtendParent() is guaranteed to not return nullptr.
81       todo_.push_back(partial_.ExtendParent(empty_parent, node.second.get()));
82     }
83   }
84 
85   // Then extend the subgraphs until no more extensions are possible.
86   while (!todo_.empty()) {
87     ExtendSubgraph(todo_.front());
88     todo_.pop_front();
89   }
90 
91   partial_.clear();
92 }
93 
ExtendSubgraph(Subgraph * parent)94 void GraphAnalyzer::ExtendSubgraph(Subgraph* parent) {
95   const int next_parent_id = parent->id().size() + 1;
96   bool will_complete = (next_parent_id == subgraph_size_);
97   SubgraphPtrSet& sg_set = will_complete ? result_ : partial_;
98 
99   const GenNode* last_all_or_none_node = nullptr;
100   for (SubgraphIterator sit(parent); !sit.AtEnd(); sit.Next()) {
101     const GenNode* node = sit.GetNode();
102     GenNode::Port port = sit.GetPort();
103     const GenNode::LinkTarget& neighbor = sit.GetNeighbor();
104 
105     if (node->AllInputsOrNone() && port.IsInbound() && !port.IsControl()) {
106       if (node != last_all_or_none_node) {
107         ExtendSubgraphAllOrNone(parent, node);
108         last_all_or_none_node = node;
109       }
110       sit.SkipPort();
111     } else if (neighbor.node->AllInputsOrNone() && !port.IsInbound() &&
112                !port.IsControl()) {
113       if (parent->id().find(neighbor.node) == parent->id().end()) {
114         // Not added yet.
115         ExtendSubgraphAllOrNone(parent, neighbor.node);
116       }
117     } else if (node->IsMultiInput(port)) {
118       ExtendSubgraphPortAllOrNone(parent, node, port);
119       sit.SkipPort();
120     } else if (neighbor.node->IsMultiInput(neighbor.port)) {
121       // Would need to add all inputs of the neighbor node at this port at
122       // once.
123       if (parent->id().find(neighbor.node) != parent->id().end()) {
124         continue;  // Already added.
125       }
126       ExtendSubgraphPortAllOrNone(parent, neighbor.node, neighbor.port);
127     } else {
128       Subgraph* sg = sg_set.ExtendParent(parent->id(), neighbor.node);
129       if (!will_complete && sg != nullptr) {
130         todo_.push_back(sg);
131       }
132     }
133   }
134 }
135 
ExtendSubgraphAllOrNone(Subgraph * parent,const GenNode * node)136 void GraphAnalyzer::ExtendSubgraphAllOrNone(Subgraph* parent,
137                                             const GenNode* node) {
138   Subgraph::Identity id = parent->id();
139   id.insert(node);
140 
141   auto range_end = node->links().end();
142 
143   for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) {
144     auto port = nbit->first;
145     if (!port.IsInbound() || port.IsControl()) {
146       continue;
147     }
148 
149     // Since there might be multiple links to the same nodes,
150     // have to add all links one-by-one to check whether the subgraph
151     // would grow too large. But if it does grow too large, there is no
152     // point in growing it more, can just skip over the rest of the links.
153     for (const auto& link : nbit->second) {
154       id.insert(link.node);
155       const int id_size = id.size();
156       if (id_size > subgraph_size_) {
157         return;  // Too big.
158       }
159     }
160   }
161 
162   AddExtendedSubgraph(parent, id);
163 }
164 
ExtendSubgraphPortAllOrNone(Subgraph * parent,const GenNode * node,GenNode::Port port)165 void GraphAnalyzer::ExtendSubgraphPortAllOrNone(Subgraph* parent,
166                                                 const GenNode* node,
167                                                 GenNode::Port port) {
168   auto nbit = node->links().find(port);
169   if (nbit == node->links().end()) {
170     return;  // Should never happen.
171   }
172 
173   Subgraph::Identity id = parent->id();
174   id.insert(node);
175 
176   // Since there might be multiple links to the same nodes,
177   // have to add all links one-by-one to check whether the subgraph
178   // would grow too large. But if it does grow too large, there is no
179   // point in growing it more, can just skip over the rest of the links.
180   for (const auto& link : nbit->second) {
181     id.insert(link.node);
182     const int id_size = id.size();
183     if (id_size > subgraph_size_) {
184       return;  // Too big.
185     }
186   }
187 
188   AddExtendedSubgraph(parent, id);
189 }
190 
AddExtendedSubgraph(Subgraph * parent,const Subgraph::Identity & id)191 void GraphAnalyzer::AddExtendedSubgraph(Subgraph* parent,
192                                         const Subgraph::Identity& id) {
193   if (id.size() == parent->id().size()) {
194     return;  // Nothing new was added.
195   }
196 
197   auto sg = std::make_unique<Subgraph>(id);
198   SubgraphPtrSet& spec_sg_set =
199       (id.size() == subgraph_size_) ? result_ : partial_;
200   if (spec_sg_set.find(sg) != spec_sg_set.end()) {
201     // This subgraph was already found by extending from a different path.
202     return;
203   }
204   const int id_size = id.size();
205   if (id_size != subgraph_size_) {
206     todo_.push_back(sg.get());
207   }
208   spec_sg_set.insert(std::move(sg));
209 }
210 
DropInvalidSubgraphs()211 void GraphAnalyzer::DropInvalidSubgraphs() {
212   auto resit = result_.begin();
213   while (resit != result_.end()) {
214     if (HasInvalidMultiInputs(resit->get())) {
215       auto delit = resit;
216       ++resit;
217       result_.erase(delit);
218     } else {
219       ++resit;
220     }
221   }
222 }
223 
HasInvalidMultiInputs(Subgraph * sg)224 bool GraphAnalyzer::HasInvalidMultiInputs(Subgraph* sg) {
225   // Do the all-or-none-input nodes.
226   for (auto const& node : sg->id()) {
227     if (!node->AllInputsOrNone()) {
228       continue;
229     }
230 
231     bool anyIn = false;
232     bool anyOut = false;
233 
234     auto range_end = node->links().end();
235     for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) {
236       auto port = nbit->first;
237       if (!port.IsInbound() || port.IsControl()) {
238         continue;
239       }
240 
241       // Since there might be multiple links to the same nodes,
242       // have to add all links one-by-one to check whether the subgraph
243       // would grow too large. But if it does grow too large, there is no
244       // point in growing it more, can just skip over the rest of the links.
245       for (const auto& link : nbit->second) {
246         if (sg->id().find(link.node) == sg->id().end()) {
247           anyOut = true;
248         } else {
249           anyIn = true;
250         }
251       }
252     }
253 
254     if (anyIn && anyOut) {
255       return true;
256     }
257   }
258 
259   // Do the multi-input ports.
260   for (SubgraphIterator sit(sg); !sit.AtEnd(); sit.Next()) {
261     if (sit.GetNode()->IsMultiInput(sit.GetPort())) {
262       bool anyIn = false;
263       bool anyOut = false;
264       do {
265         GenNode* peer = sit.GetNeighbor().node;
266         if (sg->id().find(peer) == sg->id().end()) {
267           anyOut = true;
268         } else {
269           anyIn = true;
270         }
271       } while (sit.NextIfSamePort());
272 
273       if (anyIn && anyOut) {
274         return true;
275       }
276     }
277   }
278   return false;
279 }
280 
CollateResult()281 Status GraphAnalyzer::CollateResult() {
282   ordered_collation_.clear();
283   collation_map_.clear();
284 
285   // Collate by the signatures of the graphs.
286   for (const auto& it : result_) {
287     auto sig = std::make_unique<Signature>();
288     it->ExtractForSignature(&sig->map);
289     Status status = sig->Compute();
290     if (!status.ok()) {
291       return status;
292     }
293 
294     auto& coll_entry = collation_map_[sig.get()];
295     if (coll_entry.sig == nullptr) {
296       coll_entry.sig = std::move(sig);
297     }
298     ++coll_entry.count;
299   }
300 
301   // Then order them by the count.
302   for (auto& entry : collation_map_) {
303     ordered_collation_.insert(&entry.second);
304   }
305 
306   result_.clear();  // Not needed after collation.
307 
308   return OkStatus();
309 }
310 
DumpRawSubgraphs()311 std::vector<string> GraphAnalyzer::DumpRawSubgraphs() {
312   std::vector<string> result;
313   for (const auto& it : result_) {
314     result.emplace_back(it->Dump());
315   }
316   return result;
317 }
318 
DumpSubgraphs()319 std::vector<string> GraphAnalyzer::DumpSubgraphs() {
320   std::vector<string> result;
321   for (auto ptr : ordered_collation_) {
322     result.emplace_back(
323         absl::StrFormat("%d %s", ptr->count, ptr->sig->ToString()));
324   }
325   return result;
326 }
327 
OutputSubgraphs()328 Status GraphAnalyzer::OutputSubgraphs() {
329   size_t total = 0;
330   for (auto ptr : ordered_collation_) {
331     std::cout << ptr->count << ' ' << ptr->sig->ToString() << '\n';
332     total += ptr->count;
333   }
334   std::cout << "Total: " << total << '\n';
335   if (std::cout.fail()) {
336     return Status(error::DATA_LOSS, "Failed to write to stdout");
337   } else {
338     return OkStatus();
339   }
340 }
341 
342 }  // end namespace graph_analyzer
343 }  // end namespace grappler
344 }  // end namespace tensorflow
345