1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_ 18 19 #include <map> 20 #include <memory> 21 #include <vector> 22 23 #include "tensorflow/core/framework/graph.pb.h" 24 #include "tensorflow/core/framework/node_def.pb.h" 25 #include "tensorflow/core/grappler/graph_analyzer/gen_node.h" 26 #include "tensorflow/core/grappler/graph_analyzer/hash_tools.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/protobuf/meta_graph.pb.h" 29 30 namespace tensorflow { 31 namespace grappler { 32 namespace graph_analyzer { 33 34 namespace test { 35 class SigBaseTest; 36 } // end namespace test 37 38 class SigNode; 39 40 // To find nodes by name. Having the map ordered makes the tests easier, 41 // and it isn't used in production code often enough to get any win from 42 // using an unordered map. 43 using SigNodeMap = std::map<string, std::unique_ptr<SigNode>>; 44 45 // One node in the graph, in the form convenient for generation of the signature 46 // of the graph, and comparison of two (sub)graphs for equivalence. It refers to 47 // the original NodeDef protobuf for most information and adds the extra 48 // enrichment. 49 // 50 // The graph building is 2-stage: first match a SigNode with each NodeDef and 51 // collect them into a map that finds them by name, then process the map, 52 // deep-parse the underlying NodeDefs and connect the SigNodes together. 53 class SigNode { 54 public: 55 friend struct Signature; 56 57 // Will keep the pointer to the underlying NodeDef, so that 58 // underlying object must not be deleted while SigNode is alive. 59 explicit SigNode(const NodeDef* node); 60 61 // Access wrappers. name()62 const string& name() const { return node_->name(); } opcode()63 const string& opcode() const { return node_->op(); } node_def()64 const NodeDef* node_def() const { return node_; } 65 66 // For extraction of subgraphs into a separate SigNodeMap, copies the links 67 // that point inside the subgraph from a full-graph SigNode to a subgraph 68 // SigNode. The translation map defines the subgraph and gives the mapping 69 // from the nodes in the full graph to the matching nodes in subgraph. 70 using TranslationMap = 71 std::unordered_map<const GenNode* /*full_graph*/, SigNode* /*subgraph*/>; 72 void CopyLinks(const GenNode& from, const TranslationMap& map); 73 74 // A link is an edge of the graph that connects 2 nodes. Each of the connected 75 // nodes has its own perspective on the link, seeing its local port, remote 76 // port and the remote node. The direction of the link is encoded in the 77 // ports, one port is always incoming and another one outgoing. 78 // 79 // The link tag here contains both ports of the link viewed from the 80 // perspective of this node; consisting of both the local port (i.e. at this 81 // node) and remote port (i.e. on the other node), the local one going first. 82 struct LinkTag { 83 struct Hasher { operatorLinkTag::Hasher84 size_t operator()(const LinkTag& tag) const noexcept { 85 size_t hval = port_hasher(tag.local); 86 CombineHash(port_hasher(tag.remote), &hval); 87 return hval; 88 } 89 GenNode::Port::Hasher port_hasher; 90 }; 91 LinkTagLinkTag92 LinkTag(GenNode::Port a_local, GenNode::Port a_remote) 93 : local(a_local), remote(a_remote) {} 94 95 // The default constructor is used for the default values in maps. 96 // (false, 99) is an arbitrary value that makes the uninitialized 97 // links easy to tell when debugging (they should never happen). LinkTagLinkTag98 LinkTag() : local(false, 99), remote(false, 99) {} 99 100 // Port of the link on the local node. 101 GenNode::Port local; 102 // Port of the link on the remote node. 103 GenNode::Port remote; 104 105 bool operator==(const LinkTag& other) const { 106 return local == other.local && remote == other.remote; 107 } 108 bool operator<(const LinkTag& other) const { 109 return local < other.local || 110 (local == other.local && remote < other.remote); 111 } 112 }; 113 114 // Since the signature logic doesn't differentiate between the links 115 // with the same tag (other than by the "peer" nodes on their other ends), 116 // all the links with the same tag are grouped into a single structure. 117 struct Link { 118 LinkTag tag; 119 size_t unique_hash; // Hash of the tag after conflict resolution. 120 // The remote node(s) on the other side on the link(s). 121 using PeerVector = std::vector<SigNode*>; 122 PeerVector peers; 123 }; 124 125 // A way to look up the link description by its hash. 126 using LinkHashMap = std::map<size_t, Link>; hash_to_link()127 const LinkHashMap& hash_to_link() const { return hash_to_link_; } 128 129 // The enumeration of all the peer nodes in a predictable order. 130 // Before the signature generation, only the link values determine the 131 // order, after the signature generation the entries at the same 132 // links get further sorted by their peer node ranks. 133 struct HashedPeer { HashedPeerHashedPeer134 HashedPeer(size_t l, SigNode* p) : link_hash(l), peer(p) {} 135 136 struct LessByRank { operatorHashedPeer::LessByRank137 bool operator()(const SigNode::HashedPeer& left, 138 const SigNode::HashedPeer& right) { 139 return left.peer->unique_rank_ < right.peer->unique_rank_; 140 } 141 }; 142 143 size_t link_hash; 144 SigNode* peer; 145 }; 146 using HashedPeerVector = std::vector<HashedPeer>; hashed_peers()147 const HashedPeerVector& hashed_peers() const { return hashed_peers_; } 148 149 // Compares two nodes in two different graphs for equivalence (two nodes in 150 // the same graph would never be equivalent). Expects that the signatures of 151 // the graphs have already been computed, so unique_rank_ is filled in and 152 // the hashed_peers_ properly ordered. 153 bool operator==(const SigNode& other) const; 154 155 bool operator!=(const SigNode& other) const { return !(*this == other); } 156 157 private: 158 friend class test::SigBaseTest; 159 160 // The CopyLinks code is split into 2 parts for testability. 161 // The first pass builds a map ordered by LinkTag for predictability. 162 void CopyLinksPass1(const GenNode& from, const TranslationMap& map, 163 std::map<LinkTag, Link>* link_map); 164 // The second pass converts to the map by hash value, 165 // resolves any hash conflicts, and builds the hashed peer vector. 166 void CopyLinksPass2(std::map<LinkTag, Link>* link_map); 167 168 // Computes the topological hash at distance 0. Resets the topo_hash_ vector 169 // and hashed_nodes_; 170 void ComputeTopoHash0(); 171 172 // Compute the topological has at the given distance. The hashes for all the 173 // lower distances must be already computed for all the nodes in the graph. 174 // Also computes next_hashed_nodes_ from last_hashed_nodes_. 175 void ComputeTopoHash(int distance); 176 177 // Get the hash value for a particular distance. It must be previously 178 // computed. 179 size_t GetTopoHash(int distance) const; 180 181 // The hash value for the highest computed distance. It must be previously 182 // computed. GetHighTopoHash()183 size_t GetHighTopoHash() const { 184 CHECK(!topo_hash_.empty()); 185 return topo_hash_.back(); 186 } 187 188 // Rehash the topmost hash, to avoid conflicts. ReHighTopoHash()189 void ReHighTopoHash() { 190 CHECK(!topo_hash_.empty()); 191 CombineHash(1, &topo_hash_.back()); 192 } 193 194 // Ordering by node order and highest available hash (it must be 195 // previously computed). 196 struct NodeOrderLess { operatorNodeOrderLess197 bool operator()(const SigNode* left, const SigNode* right) { 198 return left->topo_hash_.back() < right->topo_hash_.back(); 199 } 200 }; 201 202 private: 203 const NodeDef* node_; 204 205 // The bitmap mask with 1 bit set that represents this node in the set 206 // during the computation of the signature. 207 uint64_t node_mask_ = 0; 208 209 // The code that populates this map makes sure that there are no hash 210 // conflicts, rehashing if necessary. 211 LinkHashMap hash_to_link_; 212 213 // The enumeration of all the direct peers in the predictable order (which 214 // happens to be the order ot their link tags, but the order of the hashes 215 // would do too). It is used for the quick enumeration during the signature 216 // computation. After the signature building is completed, the entries that 217 // have the same link tag get further sorted in the order of the ranks of 218 // their nodes. 219 HashedPeerVector hashed_peers_; 220 221 // The unique rank represents the order in which the node will be included 222 // into the signature. It gets assigned in order either when the topo_hash_ of 223 // this node becomes unique in the graph, or when the nodes are completely 224 // equivalent, one of them is picked at random to assign the next rank, and 225 // then the rest of the nodes attempt to disambiguate based on that 226 // information. 227 size_t unique_rank_ = ~0; 228 // When hash_is_final_ is set, the topo_has_ vector stops growing, and the 229 // last value from it is used for all the further hashes. 230 bool hash_is_final_ = false; 231 // The hashes that include the topology of the nodes up to the distance N. The 232 // hash for distance 0 is produced from the attributes of this node itself and 233 // its general connectivity properties but no information about the 234 // neighboring nodes. The hash for distance D+1 is build from hashes at level 235 // D of this node and of all its immediate neighbors. The neighbors that are 236 // connected by equivalent links are included in a commutative way. 237 std::vector<size_t> topo_hash_; 238 // The set of nodes that got included into the computation of the 239 // last topo_hash_ entry. 240 uint64_t last_hashed_nodes_ = 0; 241 // The next set of nodes that gets used for the current topo_hash entry. 242 uint64_t next_hashed_nodes_ = 0; 243 }; 244 245 // Signature of a graph. The computation is intertwined with the private methods 246 // of SigNode, so keeping both in the same file looks more convenient. 247 struct Signature { 248 friend class test::SigBaseTest; 249 250 // Maximal size of the graphs for which the signature can be computed. 251 // Changing this constant won't magically add the support for a larger size, 252 // the rest of implementation would have to be extended. The value of 64 is 253 // driven by the size of a bitset in an uint64_t, and should be enough for our 254 // purposes, while having a high efficiency of implementation. 255 static constexpr int kMaxGraphSize = 64; 256 257 // Using the map, computes the rest of the fields of a signature. 258 // Returns an error is the graph is too big. 259 Status Compute(); 260 261 // Convert the computed signature to a string representation. 262 string ToString() const; 263 264 SigNodeMap map; // The nodes in the graph, accessible by name. 265 size_t sig_short = 0; // Hash of the signature, for the quick equality check. 266 // The full signature: hashes of the nodes in a predictable order. 267 std::vector<size_t> sig_full; 268 // The nodes in the same order as they go in the signature. 269 std::vector<SigNode*> nodes; 270 271 // For building the unordered maps. HashSignature272 size_t Hash() const { return sig_short; } 273 274 // Returns true if the graphs are equivalent. The signature must be already 275 // computed. 276 bool operator==(const Signature& other) const; 277 278 private: 279 // Populates the nodes vector from the map and initializes the state of the 280 // nodes for the signature computation. 281 void PrepareNodes(); 282 283 // Finds the nodes with the hashes that are unique and assigns the unique ids 284 // to them. If there are nodes with non-unique hashes, exactly one node from 285 // the first such sequence (in the order of hash values) will be picked and 286 // assigned a unique id. Assumes that the nodes[0...(next_node_id-1)] have 287 // been already assigned the unique ids. Advances next_node_id by at least 1. 288 void FindUniqueHashes(size_t* next_node_id_p); 289 290 // One round of the signature computation. Assumes that the 291 // nodes[0...(next_node_id-1)] have been already assigned the fixed 292 // positions, and thus computes the hashes only for the remaining nodes. 293 void ComputeOneRound(size_t next_node_id); 294 295 // Additional ordering of the hashed_peers_ links in the nodes, so that they 296 // can be compared and printed in a predictable order. 297 void OrderLinks(); 298 }; 299 300 } // end namespace graph_analyzer 301 } // end namespace grappler 302 } // end namespace tensorflow 303 304 #endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_ 305