1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_ 18 19 #include <deque> 20 #include <vector> 21 22 #include "tensorflow/core/framework/graph.pb.h" 23 #include "tensorflow/core/grappler/graph_analyzer/map_tools.h" 24 #include "tensorflow/core/grappler/graph_analyzer/sig_node.h" 25 #include "tensorflow/core/grappler/graph_analyzer/subgraph.h" 26 #include "tensorflow/core/lib/core/status.h" 27 28 namespace tensorflow { 29 namespace grappler { 30 namespace graph_analyzer { 31 32 namespace test { 33 class GraphAnalyzerTest; 34 } // end namespace test 35 36 // Finds all the subgraphs of a given size and groups them by equivalence. 37 class GraphAnalyzer { 38 public: 39 // Makes a copy of the graph. 40 GraphAnalyzer(const GraphDef& graph, int subgraph_size); 41 42 virtual ~GraphAnalyzer(); 43 44 // Performs the analysis and collects the subgraphs. 45 Status Run(); 46 47 // Returns the subgraphs found in Run() printed to text. 48 std::vector<string> DumpSubgraphs(); 49 50 // Prints the subgraphs found in Run() to stdout. 51 Status OutputSubgraphs(); 52 53 // TODO(babkin): add a way to extract the subgraphs as direct data 54 // structures and as protobufs, and to write protobufs to a RecordIO. 55 56 private: 57 GraphAnalyzer() = delete; 58 GraphAnalyzer(const GraphAnalyzer&) = delete; 59 void operator=(const GraphAnalyzer&) = delete; 60 61 friend class tensorflow::grappler::graph_analyzer::test::GraphAnalyzerTest; 62 63 // Builds the map of nodes from the original graph definition. 64 Status BuildMap(); 65 66 // Using nodes_, finds all the subgraphs of size subgraph_size_ and places 67 // them into result_. 68 void FindSubgraphs(); 69 70 // Deletes from result_ the unacceptable subgraphs. Those include the 71 // subgraphs where not all the inputs at a multi-input port are included (this 72 // could happen if some of these inputs were reached and included through 73 // different paths). 74 void DropInvalidSubgraphs(); 75 76 // Deletes from result_ duplicate entries of equivalent topology. 77 Status CollateResult(); 78 79 // Returns the raw subgraphs found in FindSubgraphs() printed to text. 80 std::vector<string> DumpRawSubgraphs(); 81 82 // Finds and adds appropriately to either partial_ or result_ all the 83 // subgraphs that can be created by extending the parent subgraph by one node. 84 // Ignores the duplicates. 85 void ExtendSubgraph(Subgraph* parent); 86 87 // Extends the parent subgraph by adding another node (if it wasn't already 88 // added) and all its non-control inputs in the link map range at once. 89 // If the subgraph would grow over subgraph_size_, it gets ignored. 90 void ExtendSubgraphAllOrNone(Subgraph* parent, const GenNode* node); 91 // Same but adds one specific inbound port (even control) all-or-none. 92 void ExtendSubgraphPortAllOrNone(Subgraph* parent, const GenNode* node, 93 GenNode::Port port); 94 // The common final step called by ExtendSubgraph*AllOrNone() methods. 95 void AddExtendedSubgraph(Subgraph* parent, const Subgraph::Identity& id); 96 97 // Returns true if this subgraph has any multi-inputs that aren't all-in or 98 // all-out. 99 bool HasInvalidMultiInputs(Subgraph* sg); 100 101 // Graph to run the analysis on. 102 GraphDef graph_; 103 int subgraph_size_; 104 105 // The enriched graph of parsed nodes and connections. 106 GenNodeMap nodes_; 107 // The resulting set of subgraphs. 108 SubgraphPtrSet result_; 109 // The subgraphs of partial size, stored while finding the result. 110 SubgraphPtrSet partial_; 111 // The subgraphs of partial size (stored in partial_) that are still waiting 112 // to be extended. 113 // 114 // TODO(babkin): This is rather simple-minded, each subgraph is examined from 115 // scratch, which means that all its internal links get iterated too. But it's 116 // OK for the small subgraphs. This can be improved by keeping not just 117 // subgraphs but iterators on the list, each of them having the list not-yet 118 // examined nodes (and the link position of the next link to be examined for 119 // the first node). This would add extra constant overhead, so the break-even 120 // subgraph size is not clear yet. 121 std::deque<Subgraph*> todo_; 122 123 // The collation map by signature is designed to allow the removal of entries 124 // and moving of the signature references from the keys of this map to the 125 // outside world. Must be careful at inserting and removal: make sure that 126 // when a new entry is inserted, its signature reference gets populated with 127 // the same data as the key of the map, and that if a reference is moved out, 128 // the map entry gets removed before that reference gets destroyed. 129 struct CollationEntry { 130 std::shared_ptr<Signature> sig; 131 size_t count = 0; 132 }; 133 using CollationMap = 134 std::unordered_map<Signature*, CollationEntry, HashAtPtr<Signature*>, 135 EqAtPtr<Signature*> >; 136 CollationMap collation_map_; 137 138 // The entries are owned by collation_map_, so must be removed from 139 // ordered_collation_ before removing them from collation_map_. 140 struct ReverseLessByCount { operatorReverseLessByCount141 bool operator()(CollationEntry* left, CollationEntry* right) const { 142 return left->count > right->count; // Reverse order. 143 } 144 }; 145 using CollationOrderByCount = 146 std::multiset<CollationEntry*, ReverseLessByCount>; 147 CollationOrderByCount ordered_collation_; 148 }; 149 150 } // end namespace graph_analyzer 151 } // end namespace grappler 152 } // end namespace tensorflow 153 154 #endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_ 155