• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_
17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_
18 
#include <cstddef>
#include <deque>
#include <memory>
#include <set>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/graph_analyzer/map_tools.h"
#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
#include "tensorflow/core/lib/core/status.h"
27 
28 namespace tensorflow {
29 namespace grappler {
30 namespace graph_analyzer {
31 
// Forward declaration only: the test fixture is named in a friend declaration
// inside GraphAnalyzer, so its name has to be visible here.
namespace test {
class GraphAnalyzerTest;
}  // namespace test
35 
36 // Finds all the subgraphs of a given size and groups them by equivalence.
37 class GraphAnalyzer {
38  public:
39   // Makes a copy of the graph.
40   GraphAnalyzer(const GraphDef& graph, int subgraph_size);
41 
42   virtual ~GraphAnalyzer();
43 
44   // Performs the analysis and collects the subgraphs.
45   Status Run();
46 
47   // Returns the subgraphs found in Run() printed to text.
48   std::vector<string> DumpSubgraphs();
49 
50   // Prints the subgraphs found in Run() to stdout.
51   Status OutputSubgraphs();
52 
53   // TODO(babkin): add a way to extract the subgraphs as direct data
54   // structures and as protobufs, and to write protobufs to a RecordIO.
55 
56  private:
57   GraphAnalyzer() = delete;
58   GraphAnalyzer(const GraphAnalyzer&) = delete;
59   void operator=(const GraphAnalyzer&) = delete;
60 
61   friend class tensorflow::grappler::graph_analyzer::test::GraphAnalyzerTest;
62 
63   // Builds the map of nodes from the original graph definition.
64   Status BuildMap();
65 
66   // Using nodes_, finds all the subgraphs of size subgraph_size_ and places
67   // them into result_.
68   void FindSubgraphs();
69 
70   // Deletes from result_ the unacceptable subgraphs. Those include the
71   // subgraphs where not all the inputs at a multi-input port are included (this
72   // could happen if some of these inputs were reached and included through
73   // different paths).
74   void DropInvalidSubgraphs();
75 
76   // Deletes from result_ duplicate entries of equivalent topology.
77   Status CollateResult();
78 
79   // Returns the raw subgraphs found in FindSubgraphs() printed to text.
80   std::vector<string> DumpRawSubgraphs();
81 
82   // Finds and adds appropriately to either partial_ or result_ all the
83   // subgraphs that can be created by extending the parent subgraph by one node.
84   // Ignores the duplicates.
85   void ExtendSubgraph(Subgraph* parent);
86 
87   // Extends the parent subgraph by adding another node (if it wasn't already
88   // added) and all its non-control inputs in the link map range at once.
89   // If the subgraph would grow over subgraph_size_, it gets ignored.
90   void ExtendSubgraphAllOrNone(Subgraph* parent, const GenNode* node);
91   // Same but adds one specific inbound port (even control) all-or-none.
92   void ExtendSubgraphPortAllOrNone(Subgraph* parent, const GenNode* node,
93                                    GenNode::Port port);
94   // The common final step called by ExtendSubgraph*AllOrNone() methods.
95   void AddExtendedSubgraph(Subgraph* parent, const Subgraph::Identity& id);
96 
97   // Returns true if this subgraph has any multi-inputs that aren't all-in or
98   // all-out.
99   bool HasInvalidMultiInputs(Subgraph* sg);
100 
101   // Graph to run the analysis on.
102   GraphDef graph_;
103   int subgraph_size_;
104 
105   // The enriched graph of parsed nodes and connections.
106   GenNodeMap nodes_;
107   // The resulting set of subgraphs.
108   SubgraphPtrSet result_;
109   // The subgraphs of partial size, stored while finding the result.
110   SubgraphPtrSet partial_;
111   // The subgraphs of partial size (stored in partial_) that are still waiting
112   // to be extended.
113   //
114   // TODO(babkin): This is rather simple-minded, each subgraph is examined from
115   // scratch, which means that all its internal links get iterated too. But it's
116   // OK for the small subgraphs. This can be improved by keeping not just
117   // subgraphs but iterators on the list, each of them having the list not-yet
118   // examined nodes (and the link position of the next link to be examined for
119   // the first node). This would add extra constant overhead, so the break-even
120   // subgraph size is not clear yet.
121   std::deque<Subgraph*> todo_;
122 
123   // The collation map by signature is designed to allow the removal of entries
124   // and moving of the signature references from the keys of this map to the
125   // outside world. Must be careful at inserting and removal: make sure that
126   // when a new entry is inserted, its signature reference gets populated with
127   // the same data as the key of the map, and that if a reference is moved out,
128   // the map entry gets removed before that reference gets destroyed.
129   struct CollationEntry {
130     std::shared_ptr<Signature> sig;
131     size_t count = 0;
132   };
133   using CollationMap =
134       std::unordered_map<Signature*, CollationEntry, HashAtPtr<Signature*>,
135                          EqAtPtr<Signature*> >;
136   CollationMap collation_map_;
137 
138   // The entries are owned by collation_map_, so must be removed from
139   // ordered_collation_ before removing them from collation_map_.
140   struct ReverseLessByCount {
operatorReverseLessByCount141     bool operator()(CollationEntry* left, CollationEntry* right) const {
142       return left->count > right->count;  // Reverse order.
143     }
144   };
145   using CollationOrderByCount =
146       std::multiset<CollationEntry*, ReverseLessByCount>;
147   CollationOrderByCount ordered_collation_;
148 };
149 
150 }  // end namespace graph_analyzer
151 }  // end namespace grappler
152 }  // end namespace tensorflow
153 
154 #endif  // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_
155