• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
17 
18 #include <algorithm>
19 
20 #include "absl/strings/str_format.h"
21 
22 namespace tensorflow {
23 namespace grappler {
24 namespace graph_analyzer {
25 
// Compile-time switch for the "DEBUG ..." LOG(INFO) tracing statements in this
// file; every such statement is guarded by `if (debug)`.
26 static constexpr bool debug = false;
27 
28 //=== SigNode
29 
SigNode(const NodeDef * node)30 SigNode::SigNode(const NodeDef* node) : node_(node) {}
31 
CopyLinks(const GenNode & from,const TranslationMap & map)32 void SigNode::CopyLinks(const GenNode& from, const TranslationMap& map) {
33   hash_to_link_.clear();
34   hashed_peers_.clear();
35 
36   std::map<LinkTag, Link> link_map;
37   CopyLinksPass1(from, map, &link_map);
38   CopyLinksPass2(&link_map);
39 }
40 
CopyLinksPass1(const GenNode & from,const TranslationMap & map,std::map<LinkTag,Link> * link_map)41 void SigNode::CopyLinksPass1(const GenNode& from, const TranslationMap& map,
42                              std::map<LinkTag, Link>* link_map) {
43   LinkTag::Hasher link_hasher;
44 
45   for (const auto& entry : from.links()) {
46     for (const auto& target : entry.second) {
47       auto nodeit = map.find(target.node);
48       if (nodeit == map.end()) {
49         // Node is not in the subgraph, ignore.
50         continue;
51       }
52 
53       LinkTag tag(entry.first, target.port);
54       size_t hval = link_hasher(tag);
55 
56       // This instantiates the entry if it was not present.
57       Link& map_entry = (*link_map)[tag];
58       if (map_entry.peers.empty()) {
59         map_entry.tag = tag;
60         map_entry.unique_hash = hval;
61       }
62       map_entry.peers.push_back(nodeit->second);
63     }
64   }
65 }
66 
CopyLinksPass2(std::map<LinkTag,Link> * link_map)67 void SigNode::CopyLinksPass2(std::map<LinkTag, Link>* link_map) {
68   for (auto& entry : *link_map) {
69     Link* hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
70     // In case of a conflict, rehash. This should almost never happen.
71     // Because the order of iteration is predictable, the rehashed values
72     // will also be predictable.
73     while (!hl_entry_ptr->peers.empty()) {
74       CombineHash(1, &entry.second.unique_hash);
75       hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
76     }
77 
78     for (const auto& peer : entry.second.peers) {
79       hashed_peers_.emplace_back(HashedPeer(entry.second.unique_hash, peer));
80     }
81 
82     hl_entry_ptr->tag = entry.second.tag;
83     hl_entry_ptr->unique_hash = entry.second.unique_hash;
84     hl_entry_ptr->peers.swap(entry.second.peers);
85   }
86 }
87 
ComputeTopoHash0()88 void SigNode::ComputeTopoHash0() {
89   topo_hash_.clear();
90   last_hashed_nodes_ = next_hashed_nodes_ = node_mask_;
91 
92   // TODO(babkin): include the attributes too, as an option.
93   size_t hval = std::hash<string>()(opcode());
94 
95   // Getting the topology of the links in to the hash early should get more
96   // conflicts resolved early.
97   for (const auto& entry : hashed_peers_) {
98     CombineHash(entry.link_hash, &hval);
99   }
100 
101   topo_hash_.push_back(hval);
102 }
103 
ComputeTopoHash(int distance)104 void SigNode::ComputeTopoHash(int distance) {
105   // The new starting point.
106   next_hashed_nodes_ = last_hashed_nodes_;
107   if (debug) {
108     LOG(INFO) << "DEBUG    node " << name() << " mask=" << std::hex
109               << next_hashed_nodes_;
110   }
111 
112   if (hash_is_final_) {
113     return;
114   }
115 
116   const int64 topo_hash_size = topo_hash_.size();
117   CHECK(topo_hash_size == distance);
118 
119   int prev = distance - 1;
120 
121   // Start with own's local topology hash. This value is stable, so
122   // if the hashes of the surrounding nodes don't change on the following
123   // distances, the hash of this node won't change either.
124   size_t hval = topo_hash_[0];
125 
126   if (!hashed_peers_.empty()) {
127     size_t last_link_hash = hashed_peers_[0].link_hash;
128     size_t comm_hash = 0;
129 
130     for (const auto& entry : hashed_peers_) {
131       if (entry.link_hash != last_link_hash) {
132         CombineHash(last_link_hash, &hval);
133         CombineHash(comm_hash, &hval);
134         comm_hash = 0;
135         last_link_hash = entry.link_hash;
136       }
137 
138       // The links in the same vector are commutative, so combine their
139       // hashes in a commutative way.
140       CombineHashCommutative(entry.peer->GetTopoHash(prev), &comm_hash);
141       next_hashed_nodes_ |= entry.peer->last_hashed_nodes_;
142       if (debug) {
143         LOG(INFO) << "DEBUG    node " << name() << " += " << entry.peer->name()
144                   << " mask=" << std::hex << next_hashed_nodes_;
145       }
146     }
147 
148     // The last commutative group.
149     CombineHash(last_link_hash, &hval);
150     CombineHash(comm_hash, &hval);
151   }
152 
153   topo_hash_.push_back(hval);
154 }
155 
GetTopoHash(int distance) const156 size_t SigNode::GetTopoHash(int distance) const {
157   CHECK(!topo_hash_.empty());
158   const int64 topo_hash_size = topo_hash_.size();
159   if (distance >= topo_hash_size) {
160     CHECK(hash_is_final_);
161     return topo_hash_.back();
162   } else {
163     return topo_hash_[distance];
164   }
165 }
166 
operator ==(const SigNode & other) const167 bool SigNode::operator==(const SigNode& other) const {
168   // TODO(babkin): add attributes too.
169   if (opcode() != other.opcode()) {
170     return false;
171   }
172 
173   // Normally the caller is expected to compare the nodes
174   // at the same rank in different graphs, but just in case...
175   if (unique_rank_ != other.unique_rank_) {
176     return false;
177   }
178 
179   if (hashed_peers_.size() != other.hashed_peers_.size()) {
180     return false;
181   }
182 
183   for (auto it1 = hashed_peers_.begin(), it2 = other.hashed_peers_.begin();
184        it1 != hashed_peers_.end(); ++it1, ++it2) {
185     // TODO(babkin): might compare the actual values too
186     // but the hash is probably just as good.
187     if (it1->link_hash != it2->link_hash) {
188       return false;
189     }
190     if (it1->peer->unique_rank_ != it2->peer->unique_rank_) {
191       return false;
192     }
193   }
194 
195   return true;
196 }
197 
198 //=== Signature
199 
// Out-of-class definition of the static constexpr member; the initializer is
// not part of this definition (it is expected to be on the in-class
// declaration, which is not visible here).
// NOTE(review): presumably needed because Compute() passes the constant to
// absl::StrFormat, which would be an ODR-use under pre-C++17 rules - confirm.
200 constexpr int Signature::kMaxGraphSize;
201 
ToString() const202 string Signature::ToString() const {
203   string result;
204   for (size_t n = 0; n < nodes.size(); ++n) {
205     // TODO(babkin): add attributes too.
206     result += absl::StrFormat("%d:%s", n, nodes[n]->opcode());
207     for (const auto& entry : nodes[n]->hashed_peers_) {
208       const auto& link = nodes[n]->hash_to_link_[entry.link_hash];
209 
210       // The link entries are already sorted, by tags and then by the
211       // node ranks.
212       if (link.tag.local.IsInbound()) {
213         result +=
214             absl::StrFormat("[%s:%s:%d]", string(link.tag.local),
215                             string(link.tag.remote), entry.peer->unique_rank_);
216       }
217     }
218     result.push_back(',');
219   }
220   return result;
221 }
222 
Compute()223 Status Signature::Compute() {
224   if (map.size() > kMaxGraphSize) {
225     return Status(
226         error::INVALID_ARGUMENT,
227         absl::StrFormat(
228             "A graph of %d nodes is too big for signature computation, "
229             "the maximal supported node count is %d.",
230             map.size(), kMaxGraphSize));
231   }
232 
233   // The value that will be assigned next as the unique node id.
234   // This also means that all the entries in nodes at indexes less than this
235   // have been finalized and don't need to be touched any more.
236   size_t next_node_id = 0;
237 
238   sig_short = 0;
239   sig_full.resize(0);  // Keep the storage.
240 
241   // The main signature generation.
242   PrepareNodes();
243   FindUniqueHashes(&next_node_id);
244   while (next_node_id < map.size()) {
245     ComputeOneRound(next_node_id);
246     FindUniqueHashes(&next_node_id);
247   }
248 
249   OrderLinks();
250 
251   return Status::OK();
252 }
253 
PrepareNodes()254 void Signature::PrepareNodes() {
255   nodes.resize(0);  // Keep the storage.
256 
257   // Initialize the nodes.
258   int64_t mask = 1;
259   for (const auto& entry : map) {
260     SigNode* node = entry.second.get();
261     node->last_hashed_nodes_ = node->node_mask_ = mask;
262     mask <<= 1;
263     node->unique_rank_ = ~0;
264     node->hash_is_final_ = false;
265     node->ComputeTopoHash0();
266     if (node->GetHighTopoHash() <= map.size()) {
267       // Would conflict with one of the reserved values.
268       node->ReHighTopoHash();
269     }
270 
271     // The initial order is random.
272     nodes.emplace_back(node);
273   }
274 }
275 
FindUniqueHashes(size_t * next_node_id_p)276 void Signature::FindUniqueHashes(size_t* next_node_id_p) {
277   // Start by sorting by the hash value.
278   std::sort(nodes.begin() + *next_node_id_p, nodes.end(),
279             SigNode::NodeOrderLess());
280 
281   // At each call, if no nodes have unique hashes, one node that has a
282   // non-unique (shared) hash can be made unique by assigning a unique id.
283   // This node gets picked predictably by taking the last node.
284   // TODO(babkin): Technically, more than one node can be unshared,
285   // as long as their last_hashed_nodes_ overlap only by the nodes that
286   // already had the assigned ids before the current round. But it's not clear
287   // yet, how often would this beneficial, because it looks like for many
288   // subgraphs unsharing one node should be enough to untangle them. This
289   // would need more measurement before implementing.
290   bool found_unique = false;
291   for (size_t n = *next_node_id_p; n < nodes.size(); ++n) {
292     size_t cur_hash = nodes[n]->GetHighTopoHash();
293     if (n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash) {
294       // A sequence of nodes sharing the same hash. Skip over it.
295       // TODO(babkin): check here for the arbitrary hash conflicts and resolve
296       // them.
297       for (++n;
298            n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash;
299            ++n) {
300       }
301       if (found_unique || n != nodes.size() - 1) {
302         // Either some unique nodes have already been found, or this is
303         // not the last chance, keep trying to find the unique nodes.
304         continue;
305       }
306       // Here we're at the last node and haven't found any unique ones.
307       // So fall through and make this last node unique.
308     }
309 
310     found_unique = true;
311     size_t id = (*next_node_id_p)++;
312     nodes[n]->unique_rank_ = id;
313 
314     size_t last_hash = nodes[n]->GetHighTopoHash();
315     CombineHash(last_hash, &sig_short);
316     sig_full.push_back(last_hash);
317 
318     // Take the hash at 0 and mix the unique rank into it. After that it will
319     // stay fixed.
320     nodes[n]->topo_hash_.resize(1);
321     nodes[n]->topo_hash_[0] = id + 1;  // Avoid the value of 0.
322 
323     nodes[n]->hash_is_final_ = true;
324     nodes[n]->last_hashed_nodes_ = nodes[n]->node_mask_;
325     if (n != id) {
326       std::swap(nodes[id], nodes[n]);
327     }
328   }
329 }
330 
ComputeOneRound(size_t next_node_id)331 void Signature::ComputeOneRound(size_t next_node_id) {
332   // Reset the state of the nodes.
333   int debug_i = 0;
334   for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
335     auto node = *it;
336     // The hash at distance 0 never changes, so preserve it.
337     node->topo_hash_.resize(1);
338     node->last_hashed_nodes_ = node->node_mask_;
339     node->hash_is_final_ = false;
340     if (debug) {
341       LOG(INFO) << "DEBUG distance=" << 0 << " node " << debug_i++ << " "
342                 << node->name() << " mask=" << std::hex
343                 << node->last_hashed_nodes_;
344     }
345   }
346 
347   bool stop = false;
348   // The distance can reach up to nodes.size()+1, to include not only all the
349   // nodes but also all the redundant paths.
350   for (int distance = 1; !stop; ++distance) {
351     for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
352       auto node = *it;
353       if (node->hash_is_final_) {
354         continue;
355       }
356       node->ComputeTopoHash(distance);
357       if (node->GetHighTopoHash() <= nodes.size()) {
358         // Would conflict with one of the reserved values.
359         node->ReHighTopoHash();
360       }
361     }
362 
363     // Will be looking for the indications to not stop.
364     stop = true;
365 
366     debug_i = 0;
367     // The bitmasks get moved after all the hash computations are done.
368     for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
369       auto node = *it;
370       if (debug) {
371         LOG(INFO) << "DEBUG distance=" << distance << " node " << debug_i++
372                   << " " << node->name() << " oldmask=" << std::hex
373                   << node->last_hashed_nodes_ << " mask=" << std::hex
374                   << node->next_hashed_nodes_;
375       }
376       if (node->last_hashed_nodes_ == node->next_hashed_nodes_) {
377         // Stopped growing, this part of the graph must be fully
378         // surrounded by nodes that already have the unique ids.
379         node->hash_is_final_ = true;
380       } else {
381         node->last_hashed_nodes_ = node->next_hashed_nodes_;
382         stop = false;
383       }
384     }
385   }
386 }
387 
OrderLinks()388 void Signature::OrderLinks() {
389   for (const auto& node : nodes) {
390     if (node->hashed_peers_.empty()) {
391       continue;
392     }
393 
394     size_t cur_link_hash = node->hashed_peers_[0].link_hash + 1;
395     int first_idx = -1;
396 
397     int idx;
398     for (idx = 0; idx < static_cast<int64>(node->hashed_peers_.size()); ++idx) {
399       auto& entry = node->hashed_peers_[idx];
400       if (entry.link_hash == cur_link_hash) {
401         continue;
402       }
403       if (idx - first_idx > 1) {
404         // Need to sort.
405         std::sort(node->hashed_peers_.begin() + first_idx,
406                   node->hashed_peers_.begin() + idx,
407                   SigNode::HashedPeer::LessByRank());
408       }
409 
410       cur_link_hash = entry.link_hash;
411       first_idx = idx;
412     }
413     if (idx - first_idx > 1) {
414       // Sort the last bunch.
415       std::sort(node->hashed_peers_.begin() + first_idx,
416                 node->hashed_peers_.begin() + idx,
417                 SigNode::HashedPeer::LessByRank());
418     }
419   }
420 }
421 
operator ==(const Signature & other) const422 bool Signature::operator==(const Signature& other) const {
423   // Tries to find the differences as early as possible by
424   // comparing the hashes first.
425 
426   if (sig_short != other.sig_short) {
427     return false;
428   }
429   if (sig_full.size() != other.sig_full.size()) {
430     return false;
431   }
432 
433   for (auto it1 = sig_full.begin(), it2 = other.sig_full.begin();
434        it1 != sig_full.end(); ++it1, ++it2) {
435     if (*it1 != *it2) {
436       return false;
437     }
438   }
439 
440   if (nodes.size() != other.nodes.size()) {
441     return false;
442   }
443   for (auto it1 = nodes.begin(), it2 = other.nodes.begin(); it1 != nodes.end();
444        ++it1, ++it2) {
445     if (**it1 != **it2) {
446       return false;
447     }
448   }
449 
450   return true;
451 }
452 
453 }  // end namespace graph_analyzer
454 }  // end namespace grappler
455 }  // end namespace tensorflow
456