• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
18 
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 namespace graph_analyzer {
33 
34 namespace test {
35 class SigBaseTest;
36 }  // end namespace test
37 
38 class SigNode;
39 
40 // To find nodes by name. Having the map ordered makes the tests easier,
41 // and it isn't used in production code often enough to get any win from
42 // using an unordered map.
43 using SigNodeMap = std::map<string, std::unique_ptr<SigNode>>;
44 
45 // One node in the graph, in the form convenient for generation of the signature
46 // of the graph, and comparison of two (sub)graphs for equivalence. It refers to
47 // the original NodeDef protobuf for most information and adds the extra
48 // enrichment.
49 //
50 // The graph building is 2-stage: first match a SigNode with each NodeDef and
51 // collect them into a map that finds them by name, then process the map,
52 // deep-parse the underlying NodeDefs and connect the SigNodes together.
53 class SigNode {
54  public:
55   friend struct Signature;
56 
57   // Will keep the pointer to the underlying NodeDef, so that
58   // underlying object must not be deleted while SigNode is alive.
59   explicit SigNode(const NodeDef* node);
60 
61   // Access wrappers.
name()62   const string& name() const { return node_->name(); }
opcode()63   const string& opcode() const { return node_->op(); }
node_def()64   const NodeDef* node_def() const { return node_; }
65 
66   // For extraction of subgraphs into a separate SigNodeMap, copies the links
67   // that point inside the subgraph from a full-graph SigNode to a subgraph
68   // SigNode. The translation map defines the subgraph and gives the mapping
69   // from the nodes in the full graph to the matching nodes in subgraph.
70   using TranslationMap =
71       std::unordered_map<const GenNode* /*full_graph*/, SigNode* /*subgraph*/>;
72   void CopyLinks(const GenNode& from, const TranslationMap& map);
73 
74   // A link is an edge of the graph that connects 2 nodes. Each of the connected
75   // nodes has its own perspective on the link, seeing its local port, remote
76   // port and the remote node. The direction of the link is encoded in the
77   // ports, one port is always incoming and another one outgoing.
78   //
79   // The link tag here contains both ports of the link viewed from the
80   // perspective of this node; consisting of both the local port (i.e. at this
81   // node) and remote port (i.e. on the other node), the local one going first.
82   struct LinkTag {
83     struct Hasher {
operatorLinkTag::Hasher84       size_t operator()(const LinkTag& tag) const noexcept {
85         size_t hval = port_hasher(tag.local);
86         CombineHash(port_hasher(tag.remote), &hval);
87         return hval;
88       }
89       GenNode::Port::Hasher port_hasher;
90     };
91 
LinkTagLinkTag92     LinkTag(GenNode::Port a_local, GenNode::Port a_remote)
93         : local(a_local), remote(a_remote) {}
94 
95     // The default constructor is used for the default values in maps.
96     // (false, 99) is an arbitrary value that makes the uninitialized
97     // links easy to tell when debugging (they should never happen).
LinkTagLinkTag98     LinkTag() : local(false, 99), remote(false, 99) {}
99 
100     // Port of the link on the local node.
101     GenNode::Port local;
102     // Port of the link on the remote node.
103     GenNode::Port remote;
104 
105     bool operator==(const LinkTag& other) const {
106       return local == other.local && remote == other.remote;
107     }
108     bool operator<(const LinkTag& other) const {
109       return local < other.local ||
110              (local == other.local && remote < other.remote);
111     }
112   };
113 
114   // Since the signature logic doesn't differentiate between the links
115   // with the same tag (other than by the "peer" nodes on their other ends),
116   // all the links with the same tag are grouped into a single structure.
117   struct Link {
118     LinkTag tag;
119     size_t unique_hash;  // Hash of the tag after conflict resolution.
120     // The remote node(s) on the other side on the link(s).
121     using PeerVector = std::vector<SigNode*>;
122     PeerVector peers;
123   };
124 
125   // A way to look up the link description by its hash.
126   using LinkHashMap = std::map<size_t, Link>;
hash_to_link()127   const LinkHashMap& hash_to_link() const { return hash_to_link_; }
128 
129   // The enumeration of all the peer nodes in a predictable order.
130   // Before the signature generation, only the link values determine the
131   // order, after the signature generation the entries at the same
132   // links get further sorted by their peer node ranks.
133   struct HashedPeer {
HashedPeerHashedPeer134     HashedPeer(size_t l, SigNode* p) : link_hash(l), peer(p) {}
135 
136     struct LessByRank {
operatorHashedPeer::LessByRank137       bool operator()(const SigNode::HashedPeer& left,
138                       const SigNode::HashedPeer& right) {
139         return left.peer->unique_rank_ < right.peer->unique_rank_;
140       }
141     };
142 
143     size_t link_hash;
144     SigNode* peer;
145   };
146   using HashedPeerVector = std::vector<HashedPeer>;
hashed_peers()147   const HashedPeerVector& hashed_peers() const { return hashed_peers_; }
148 
149   // Compares two nodes in two different graphs for equivalence (two nodes in
150   // the same graph would never be equivalent). Expects that the signatures of
151   // the graphs have already been computed, so unique_rank_ is filled in and
152   // the hashed_peers_ properly ordered.
153   bool operator==(const SigNode& other) const;
154 
155   bool operator!=(const SigNode& other) const { return !(*this == other); }
156 
157  private:
158   friend class test::SigBaseTest;
159 
160   // The CopyLinks code is split into 2 parts for testability.
161   // The first pass builds a map ordered by LinkTag for predictability.
162   void CopyLinksPass1(const GenNode& from, const TranslationMap& map,
163                       std::map<LinkTag, Link>* link_map);
164   // The second pass converts to the map by hash value,
165   // resolves any hash conflicts, and builds the hashed peer vector.
166   void CopyLinksPass2(std::map<LinkTag, Link>* link_map);
167 
168   // Computes the topological hash at distance 0. Resets the topo_hash_ vector
169   // and hashed_nodes_;
170   void ComputeTopoHash0();
171 
172   // Compute the topological has at the given distance. The hashes for all the
173   // lower distances must be already computed for all the nodes in the graph.
174   // Also computes next_hashed_nodes_ from last_hashed_nodes_.
175   void ComputeTopoHash(int distance);
176 
177   // Get the hash value for a particular distance. It must be previously
178   // computed.
179   size_t GetTopoHash(int distance) const;
180 
181   // The hash value for the highest computed distance. It must be previously
182   // computed.
GetHighTopoHash()183   size_t GetHighTopoHash() const {
184     CHECK(!topo_hash_.empty());
185     return topo_hash_.back();
186   }
187 
188   // Rehash the topmost hash, to avoid conflicts.
ReHighTopoHash()189   void ReHighTopoHash() {
190     CHECK(!topo_hash_.empty());
191     CombineHash(1, &topo_hash_.back());
192   }
193 
194   // Ordering by node order and highest available hash (it must be
195   // previously computed).
196   struct NodeOrderLess {
operatorNodeOrderLess197     bool operator()(const SigNode* left, const SigNode* right) {
198       return left->topo_hash_.back() < right->topo_hash_.back();
199     }
200   };
201 
202  private:
203   const NodeDef* node_;
204 
205   // The bitmap mask with 1 bit set that represents this node in the set
206   // during the computation of the signature.
207   uint64_t node_mask_ = 0;
208 
209   // The code that populates this map makes sure that there are no hash
210   // conflicts, rehashing if necessary.
211   LinkHashMap hash_to_link_;
212 
213   // The enumeration of all the direct peers in the predictable order (which
214   // happens to be the order ot their link tags, but the order of the hashes
215   // would do too). It is used for the quick enumeration during the signature
216   // computation. After the signature building is completed, the entries that
217   // have the same link tag get further sorted in the order of the ranks of
218   // their nodes.
219   HashedPeerVector hashed_peers_;
220 
221   // The unique rank represents the order in which the node will be included
222   // into the signature. It gets assigned in order either when the topo_hash_ of
223   // this node becomes unique in the graph, or when the nodes are completely
224   // equivalent, one of them is picked at random to assign the next rank, and
225   // then the rest of the nodes attempt to disambiguate based on that
226   // information.
227   size_t unique_rank_ = ~0;
228   // When hash_is_final_ is set, the topo_has_ vector stops growing, and the
229   // last value from it is used for all the further hashes.
230   bool hash_is_final_ = false;
231   // The hashes that include the topology of the nodes up to the distance N. The
232   // hash for distance 0 is produced from the attributes of this node itself and
233   // its general connectivity properties but no information about the
234   // neighboring nodes. The hash for distance D+1 is build from hashes at level
235   // D of this node and of all its immediate neighbors. The neighbors that are
236   // connected by equivalent links are included in a commutative way.
237   std::vector<size_t> topo_hash_;
238   // The set of nodes that got included into the computation of the
239   // last topo_hash_ entry.
240   uint64_t last_hashed_nodes_ = 0;
241   // The next set of nodes that gets used for the current topo_hash entry.
242   uint64_t next_hashed_nodes_ = 0;
243 };
244 
245 // Signature of a graph. The computation is intertwined with the private methods
246 // of SigNode, so keeping both in the same file looks more convenient.
247 struct Signature {
248   friend class test::SigBaseTest;
249 
250   // Maximal size of the graphs for which the signature can be computed.
251   // Changing this constant won't magically add the support for a larger size,
252   // the rest of implementation would have to be extended. The value of 64 is
253   // driven by the size of a bitset in an uint64_t, and should be enough for our
254   // purposes, while having a high efficiency of implementation.
255   static constexpr int kMaxGraphSize = 64;
256 
257   // Using the map, computes the rest of the fields of a signature.
258   // Returns an error is the graph is too big.
259   Status Compute();
260 
261   // Convert the computed signature to a string representation.
262   string ToString() const;
263 
264   SigNodeMap map;        // The nodes in the graph, accessible by name.
265   size_t sig_short = 0;  // Hash of the signature, for the quick equality check.
266   // The full signature: hashes of the nodes in a predictable order.
267   std::vector<size_t> sig_full;
268   // The nodes in the same order as they go in the signature.
269   std::vector<SigNode*> nodes;
270 
271   // For building the unordered maps.
HashSignature272   size_t Hash() const { return sig_short; }
273 
274   // Returns true if the graphs are equivalent. The signature must be already
275   // computed.
276   bool operator==(const Signature& other) const;
277 
278  private:
279   // Populates the nodes vector from the map and initializes the state of the
280   // nodes for the signature computation.
281   void PrepareNodes();
282 
283   // Finds the nodes with the hashes that are unique and assigns the unique ids
284   // to them. If there are nodes with non-unique hashes, exactly one node from
285   // the first such sequence (in the order of hash values) will be picked and
286   // assigned a unique id. Assumes that the nodes[0...(next_node_id-1)] have
287   // been already assigned the unique ids. Advances next_node_id by at least 1.
288   void FindUniqueHashes(size_t* next_node_id_p);
289 
290   // One round of the signature computation. Assumes that the
291   // nodes[0...(next_node_id-1)] have been already assigned the fixed
292   // positions, and thus computes the hashes only for the remaining nodes.
293   void ComputeOneRound(size_t next_node_id);
294 
295   // Additional ordering of the hashed_peers_ links in the nodes, so that they
296   // can be compared and printed in a predictable order.
297   void OrderLinks();
298 };
299 
300 }  // end namespace graph_analyzer
301 }  // end namespace grappler
302 }  // end namespace tensorflow
303 
304 #endif  // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
305