1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_ 18 19 #include <initializer_list> 20 #include <set> 21 #include <unordered_set> 22 23 #include "tensorflow/core/grappler/graph_analyzer/gen_node.h" 24 #include "tensorflow/core/grappler/graph_analyzer/map_tools.h" 25 #include "tensorflow/core/grappler/graph_analyzer/sig_node.h" 26 #include "tensorflow/core/lib/gtl/flatset.h" 27 28 namespace tensorflow { 29 namespace grappler { 30 namespace graph_analyzer { 31 32 // The description of a single subgraph for processing. 33 class Subgraph { 34 public: 35 // Identity of a single subgraph as a set of nodes. 36 class Identity : public gtl::FlatSet<const GenNode*> { 37 public: 38 using InitializerList = std::initializer_list<GenNode*>; 39 40 Identity() = default; 41 Identity(InitializerList init); 42 bool operator<(const Identity& other) const; 43 bool operator==(const Identity& other) const; 44 45 // Compute the hash. 46 size_t Hash() const; 47 }; 48 Subgraph(Identity id)49 explicit Subgraph(Identity id) : id_(std::move(id)), hash_(id_.Hash()) {} 50 51 // Construct by extending the parent identity with an extra node. 52 Subgraph(const Identity& parent_id, GenNode* add_node); 53 54 Subgraph() = delete; 55 Subgraph(const Subgraph& other) = delete; 56 void operator=(const Subgraph& other) = delete; 57 58 // Order for building sets of subgraphs. 59 bool operator<(const Subgraph& other) const { return this->id_ < other.id_; } 60 // Support for hashed sets. 61 bool operator==(const Subgraph& other) const { 62 return this->id_ == other.id_; 63 } Hash()64 size_t Hash() const { return hash_; } 65 66 // Dump the subgraph information to a string. 67 string Dump(); 68 69 // Extract this subgraph into a separate graph representation for signature 70 // building, that includes only the links between the nodes in the subgraph 71 // and drops all the external links. The result map should be clear before the 72 // call. 73 void ExtractForSignature(SigNodeMap* result); 74 id()75 const Identity& id() const { return id_; } specific()76 bool specific() const { return specific_; } SetSpecific(bool value)77 void SetSpecific(bool value) { specific_ = value; } collation_count()78 int32_t collation_count() const { return collation_count_; } 79 void AddCollation(int32_t n = 1) { collation_count_ += n; } ResetCollation()80 void ResetCollation() { collation_count_ = 1; } MergeCollation(const Subgraph & other)81 void MergeCollation(const Subgraph& other) { 82 collation_count_ += other.collation_count_; 83 } 84 85 private: 86 // Identity also serves as the list of nodes. It never changes throughout the 87 // life of subgraph. 88 Identity id_; 89 size_t hash_; // Cached from the identity. 90 // Whether the dump should include the specific names of the nodes. The 91 // non-specific (i.e. generic) subgraphs represent a collation of multiple 92 // subgraphs. 93 bool specific_ = true; 94 // How many collated subgraphs are represented by this subgraph. 95 int32_t collation_count_ = 1; 96 }; 97 98 // Iteration of all links in a subgraph. This is more like Java iterators than 99 // the normal C++ iterators. It's simpler this way and there seems to be no 100 // major reason to make it a proper C++ iterator. 101 class SubgraphIterator { 102 public: 103 // Obviously an iterator is valid only until the original object 104 // gets destroyed. 105 explicit SubgraphIterator(const Subgraph::Identity* id); SubgraphIterator(const Subgraph * sg)106 explicit SubgraphIterator(const Subgraph* sg) : SubgraphIterator(&sg->id()) {} 107 108 // Check whether the built-in iterator is at the end. AtEnd()109 bool AtEnd() const { return id_it_ == id_->end(); } 110 111 // Get the neighbor at the current iterator. 112 // MUST NOT be called when AtEnd(); GetNeighbor()113 const GenNode::LinkTarget& GetNeighbor() const { 114 return link_map_it_->second[link_idx_]; 115 } 116 117 // Get the node at the current iterator. 118 // MUST NOT be called when AtEnd(); GetNode()119 const GenNode* GetNode() const { return *id_it_; } 120 121 // Get the port leading to the neighbor at the current iterator. 122 // MUST NOT be called when AtEnd(); GetPort()123 GenNode::Port GetPort() const { return link_map_it_->first; } 124 125 // Increases the iterator. 126 // Returns true if NOT AtEnd() after increasing the iterator. 127 // Safe to call if already AtEnd(). 128 bool Next(); 129 130 // If there are more links at the same port, increases the iterator and 131 // returns true. Otherwise leaves the iterator unchanged and returns false. 132 bool NextIfSamePort(); 133 134 // Increases the iterator directly to the last position on the current port 135 // (or if already there then doesn't increase). Equivalent to calling 136 // NextIfSamePort() while it returns true, but faster. 137 // Safe to call if already AtEnd(). 138 void SkipPort(); 139 140 // Increases the iterator directly to the last position on the current node. 141 // Safe to call if already AtEnd(). 142 void SkipNode(); 143 144 // Returns true if the iterators are exactly the same. 145 bool operator==(const SubgraphIterator& other) const; 146 bool operator!=(const SubgraphIterator& other) const { 147 return !(*this == other); 148 } 149 150 private: 151 // After link_idx_ has been increased, make sure that it points to the 152 // next valid element (or end) by increasing the higher levels of iteration if 153 // needed. 154 // Returns true if NOT AtEnd() after increasing the iterator. 155 // NOT safe to call if already AtEnd(). 156 bool PropagateNext(); 157 158 // Identity of the subgraph being iterated over. 159 const Subgraph::Identity* id_; 160 161 // The current position, allowing to iterate through the links (see the 162 // reasoning for it in the public section). 163 // 164 // (1) Iterator of the nodes in the subgraph. 165 Subgraph::Identity::const_iterator id_it_; 166 // (2) Iterator in the link map of the node. 167 GenNode::LinkMap::const_iterator link_map_it_; 168 // (3) Index in the vector of the links. 169 int32_t link_idx_; 170 }; 171 172 // A convenient way to store subgraphs: in a set of unique_ptrs. This way the 173 // addresses of subgraph objects will stay stable, and the objects themselves 174 // won't be copied. 175 class SubgraphPtrSet 176 : public std::unordered_set<std::unique_ptr<Subgraph>, 177 HashAtPtr<std::unique_ptr<Subgraph>>, 178 EqAtPtr<std::unique_ptr<Subgraph>>> { 179 public: 180 // Attempts to extend the set by adding a new subgraph that gets created by 181 // adding one node to the parent subgraph. If such a subgraph already exists, 182 // returns nullptr, otherwise returns the pointer to the new subgraph. 183 Subgraph* ExtendParent(const Subgraph::Identity& parent_id, GenNode* node); 184 }; 185 186 } // end namespace graph_analyzer 187 } // end namespace grappler 188 } // end namespace tensorflow 189 190 #endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_ 191