/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"

#include <algorithm>

#include "absl/strings/str_format.h"

namespace tensorflow {
namespace grappler {
namespace graph_analyzer {

static constexpr bool debug = false;

//=== SigNode

SigNode::SigNode(const NodeDef* node) : node_(node) {}

void SigNode::CopyLinks(const GenNode& from, const TranslationMap& map) {
  hash_to_link_.clear();
  hashed_peers_.clear();

  std::map<LinkTag, Link> link_map;
  CopyLinksPass1(from, map, &link_map);
  CopyLinksPass2(&link_map);
}

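// Pass 1: group the links by LinkTag, collecting under each tag the list of
// peer nodes (translated to this subgraph's SigNodes) and remembering the
// tag's hash value for pass 2.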
void SigNode::CopyLinksPass1(const GenNode& from, const TranslationMap& map,
                             std::map<LinkTag, Link>* link_map) {
  LinkTag::Hasher link_hasher;

  for (const auto& entry : from.links()) {
    for (const auto& target : entry.second) {
      auto nodeit = map.find(target.node);
      if (nodeit == map.end()) {
        // Node is not in the subgraph, ignore.
        continue;
      }

      LinkTag tag(entry.first, target.port);
      size_t hval = link_hasher(tag);

      // This instantiates the entry if it was not present.
      Link& map_entry = (*link_map)[tag];
      if (map_entry.peers.empty()) {
        map_entry.tag = tag;
        map_entry.unique_hash = hval;
      }
      map_entry.peers.push_back(nodeit->second);
    }
  }
}

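// Pass 2: move the grouped links into hash_to_link_, keyed by the
// (conflict-resolved) hash, and record one (hash, peer) pair per link in
// hashed_peers_.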
void SigNode::CopyLinksPass2(std::map<LinkTag, Link>* link_map) {
  for (auto& entry : *link_map) {
    Link* hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
    // In case of a conflict, rehash. This should almost never happen.
    // Because the order of iteration is predictable, the rehashed values
    // will also be predictable.
    while (!hl_entry_ptr->peers.empty()) {
      CombineHash(1, &entry.second.unique_hash);
      hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
    }

    for (const auto& peer : entry.second.peers) {
      hashed_peers_.emplace_back(HashedPeer(entry.second.unique_hash, peer));
    }

    hl_entry_ptr->tag = entry.second.tag;
    hl_entry_ptr->unique_hash = entry.second.unique_hash;
    hl_entry_ptr->peers.swap(entry.second.peers);
  }
}

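// Computes the topology hash at distance 0: it depends only on the node's own
// opcode and the tags of its links, not on any neighboring nodes.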
void SigNode::ComputeTopoHash0() {
  topo_hash_.clear();
  last_hashed_nodes_ = next_hashed_nodes_ = node_mask_;

  // TODO(babkin): include the attributes too, as an option.
  size_t hval = std::hash<string>()(opcode());

  // Getting the topology of the links into the hash early should get more
  // conflicts resolved early.
  for (const auto& entry : hashed_peers_) {
    CombineHash(entry.link_hash, &hval);
  }

  topo_hash_.push_back(hval);
}

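// Extends the topology hash to the given distance by folding in the hashes
// that the peers had at distance-1, and accumulates the bitmask of all the
// nodes reachable within that distance.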
void SigNode::ComputeTopoHash(int distance) {
  // The new starting point.
  next_hashed_nodes_ = last_hashed_nodes_;
  if (debug) {
    LOG(INFO) << "DEBUG node " << name() << " mask=" << std::hex
              << next_hashed_nodes_;
  }

  if (hash_is_final_) {
    return;
  }

  const int64 topo_hash_size = topo_hash_.size();
  CHECK(topo_hash_size == distance);

  int prev = distance - 1;

  // Start with the node's own local topology hash. This value is stable, so
  // if the hashes of the surrounding nodes don't change at the following
  // distances, the hash of this node won't change either.
  size_t hval = topo_hash_[0];

  if (!hashed_peers_.empty()) {
    size_t last_link_hash = hashed_peers_[0].link_hash;
    size_t comm_hash = 0;

    for (const auto& entry : hashed_peers_) {
      if (entry.link_hash != last_link_hash) {
        CombineHash(last_link_hash, &hval);
        CombineHash(comm_hash, &hval);
        comm_hash = 0;
        last_link_hash = entry.link_hash;
      }

      // The links in the same vector are commutative, so combine their
      // hashes in a commutative way.
      CombineHashCommutative(entry.peer->GetTopoHash(prev), &comm_hash);
      next_hashed_nodes_ |= entry.peer->last_hashed_nodes_;
      if (debug) {
        LOG(INFO) << "DEBUG node " << name() << " += " << entry.peer->name()
                  << " mask=" << std::hex << next_hashed_nodes_;
      }
    }

    // The last commutative group.
    CombineHash(last_link_hash, &hval);
    CombineHash(comm_hash, &hval);
  }

  topo_hash_.push_back(hval);
}

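// Returns the topology hash for the given distance, falling back to the last
// computed value once the node's hash has been finalized.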
size_t SigNode::GetTopoHash(int distance) const {
  CHECK(!topo_hash_.empty());
  const int64 topo_hash_size = topo_hash_.size();
  if (distance >= topo_hash_size) {
    CHECK(hash_is_final_);
    return topo_hash_.back();
  } else {
    return topo_hash_[distance];
  }
}

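// Two nodes are considered equal if they have the same opcode and the same
// set of links, compared by link hash and by the rank of the peer node.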
bool SigNode::operator==(const SigNode& other) const {
  // TODO(babkin): add attributes too.
  if (opcode() != other.opcode()) {
    return false;
  }

  // Normally the caller is expected to compare the nodes
  // at the same rank in different graphs, but just in case...
  if (unique_rank_ != other.unique_rank_) {
    return false;
  }

  if (hashed_peers_.size() != other.hashed_peers_.size()) {
    return false;
  }

  for (auto it1 = hashed_peers_.begin(), it2 = other.hashed_peers_.begin();
       it1 != hashed_peers_.end(); ++it1, ++it2) {
    // TODO(babkin): might compare the actual values too
    // but the hash is probably just as good.
    if (it1->link_hash != it2->link_hash) {
      return false;
    }
    if (it1->peer->unique_rank_ != it2->peer->unique_rank_) {
      return false;
    }
  }

  return true;
}

//=== Signature

constexpr int Signature::kMaxGraphSize;

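// Prints the signature in a human-readable form: for each node, its rank and
// opcode, followed by its inbound links as [local_port:remote_port:peer_rank].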
string Signature::ToString() const {
  string result;
  for (size_t n = 0; n < nodes.size(); ++n) {
    // TODO(babkin): add attributes too.
    result += absl::StrFormat("%d:%s", n, nodes[n]->opcode());
    for (const auto& entry : nodes[n]->hashed_peers_) {
      const auto& link = nodes[n]->hash_to_link_[entry.link_hash];

      // The link entries are already sorted, by tags and then by the
      // node ranks.
      if (link.tag.local.IsInbound()) {
        result +=
            absl::StrFormat("[%s:%s:%d]", string(link.tag.local),
                            string(link.tag.remote), entry.peer->unique_rank_);
      }
    }
    result.push_back(',');
  }
  return result;
}

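// Computes the signature: assigns the unique ranks to the nodes by iteratively
// refining their topology hashes, and builds sig_short and sig_full from the
// hashes in rank order.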
Status Signature::Compute() {
  if (map.size() > kMaxGraphSize) {
    return Status(
        error::INVALID_ARGUMENT,
        absl::StrFormat(
            "A graph of %d nodes is too big for signature computation, "
            "the maximal supported node count is %d.",
            map.size(), kMaxGraphSize));
  }

  // The value that will be assigned next as the unique node id.
  // This also means that all the entries in nodes at indexes less than this
  // have been finalized and don't need to be touched any more.
  size_t next_node_id = 0;

  sig_short = 0;
  sig_full.resize(0);  // Keep the storage.

  // The main signature generation.
  PrepareNodes();
  FindUniqueHashes(&next_node_id);
  while (next_node_id < map.size()) {
    ComputeOneRound(next_node_id);
    FindUniqueHashes(&next_node_id);
  }

  OrderLinks();

  return Status::OK();
}

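// Resets the nodes to their initial state: gives each node its own bit in the
// mask, clears the ranks, and computes the distance-0 hashes.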
void Signature::PrepareNodes() {
  nodes.resize(0);  // Keep the storage.

  // Initialize the nodes.
  int64_t mask = 1;
  for (const auto& entry : map) {
    SigNode* node = entry.second.get();
    node->last_hashed_nodes_ = node->node_mask_ = mask;
    mask <<= 1;
    node->unique_rank_ = ~0;
    node->hash_is_final_ = false;
    node->ComputeTopoHash0();
    if (node->GetHighTopoHash() <= map.size()) {
      // Would conflict with one of the reserved values.
      node->ReHighTopoHash();
    }

    // The initial order is random.
    nodes.emplace_back(node);
  }
}

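// Sorts the not-yet-ranked nodes by their current hash and assigns the unique
// ranks to the nodes whose hashes are unique; if none are unique, forcibly
// assigns a rank to the last node to break the tie.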
void Signature::FindUniqueHashes(size_t* next_node_id_p) {
  // Start by sorting by the hash value.
  std::sort(nodes.begin() + *next_node_id_p, nodes.end(),
            SigNode::NodeOrderLess());

  // At each call, if no nodes have unique hashes, one node that has a
  // non-unique (shared) hash can be made unique by assigning a unique id.
  // This node gets picked predictably by taking the last node.
  // TODO(babkin): Technically, more than one node can be unshared,
  // as long as their last_hashed_nodes_ overlap only in the nodes that
  // already had the assigned ids before the current round. But it's not clear
  // yet how often this would be beneficial, because it looks like for many
  // subgraphs unsharing one node should be enough to untangle them. This
  // would need more measurement before implementing.
  bool found_unique = false;
  for (size_t n = *next_node_id_p; n < nodes.size(); ++n) {
    size_t cur_hash = nodes[n]->GetHighTopoHash();
    if (n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash) {
      // A sequence of nodes sharing the same hash. Skip over it.
      // TODO(babkin): check here for the arbitrary hash conflicts and resolve
      // them.
      for (++n;
           n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash;
           ++n) {
      }
      if (found_unique || n != nodes.size() - 1) {
        // Either some unique nodes have already been found, or this is
        // not the last chance, keep trying to find the unique nodes.
        continue;
      }
      // Here we're at the last node and haven't found any unique ones.
      // So fall through and make this last node unique.
    }

    found_unique = true;
    size_t id = (*next_node_id_p)++;
    nodes[n]->unique_rank_ = id;

    size_t last_hash = nodes[n]->GetHighTopoHash();
    CombineHash(last_hash, &sig_short);
    sig_full.push_back(last_hash);

    // Replace the hash at distance 0 with the unique rank. After that it will
    // stay fixed.
    nodes[n]->topo_hash_.resize(1);
    nodes[n]->topo_hash_[0] = id + 1;  // Avoid the value of 0.

    nodes[n]->hash_is_final_ = true;
    nodes[n]->last_hashed_nodes_ = nodes[n]->node_mask_;
    if (n != id) {
      std::swap(nodes[id], nodes[n]);
    }
  }
}

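// Runs one round of hash propagation for the nodes that don't have the unique
// ranks yet, increasing the distance until the hashes stop absorbing new
// nodes.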
void Signature::ComputeOneRound(size_t next_node_id) {
  // Reset the state of the nodes.
  int debug_i = 0;
  for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
    auto node = *it;
    // The hash at distance 0 never changes, so preserve it.
    node->topo_hash_.resize(1);
    node->last_hashed_nodes_ = node->node_mask_;
    node->hash_is_final_ = false;
    if (debug) {
      LOG(INFO) << "DEBUG distance=" << 0 << " node " << debug_i++ << " "
                << node->name() << " mask=" << std::hex
                << node->last_hashed_nodes_;
    }
  }

  bool stop = false;
  // The distance can reach up to nodes.size()+1, to include not only all the
  // nodes but also all the redundant paths.
  for (int distance = 1; !stop; ++distance) {
    for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
      auto node = *it;
      if (node->hash_is_final_) {
        continue;
      }
      node->ComputeTopoHash(distance);
      if (node->GetHighTopoHash() <= nodes.size()) {
        // Would conflict with one of the reserved values.
        node->ReHighTopoHash();
      }
    }

    // Will be looking for the indications to not stop.
    stop = true;

    debug_i = 0;
    // The bitmasks get moved after all the hash computations are done.
    for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
      auto node = *it;
      if (debug) {
        LOG(INFO) << "DEBUG distance=" << distance << " node " << debug_i++
                  << " " << node->name() << " oldmask=" << std::hex
                  << node->last_hashed_nodes_ << " mask=" << std::hex
                  << node->next_hashed_nodes_;
      }
      if (node->last_hashed_nodes_ == node->next_hashed_nodes_) {
        // Stopped growing, this part of the graph must be fully
        // surrounded by nodes that already have the unique ids.
        node->hash_is_final_ = true;
      } else {
        node->last_hashed_nodes_ = node->next_hashed_nodes_;
        stop = false;
      }
    }
  }
}

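// Orders the hashed peers of each node: within each group of links sharing the
// same link hash, sorts the peers by their unique rank, giving the links the
// predictable order that ToString() and the comparison rely on.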
void Signature::OrderLinks() {
  for (const auto& node : nodes) {
    if (node->hashed_peers_.empty()) {
      continue;
    }

    size_t cur_link_hash = node->hashed_peers_[0].link_hash + 1;
    int first_idx = -1;

    int idx;
    for (idx = 0; idx < static_cast<int64>(node->hashed_peers_.size());
         ++idx) {
      auto& entry = node->hashed_peers_[idx];
      if (entry.link_hash == cur_link_hash) {
        continue;
      }
      if (idx - first_idx > 1) {
        // Need to sort.
        std::sort(node->hashed_peers_.begin() + first_idx,
                  node->hashed_peers_.begin() + idx,
                  SigNode::HashedPeer::LessByRank());
      }

      cur_link_hash = entry.link_hash;
      first_idx = idx;
    }
    if (idx - first_idx > 1) {
      // Sort the last bunch.
      std::sort(node->hashed_peers_.begin() + first_idx,
                node->hashed_peers_.begin() + idx,
                SigNode::HashedPeer::LessByRank());
    }
  }
}

bool Signature::operator==(const Signature& other) const {
  // Tries to find the differences as early as possible by
  // comparing the hashes first.

  if (sig_short != other.sig_short) {
    return false;
  }
  if (sig_full.size() != other.sig_full.size()) {
    return false;
  }

  for (auto it1 = sig_full.begin(), it2 = other.sig_full.begin();
       it1 != sig_full.end(); ++it1, ++it2) {
    if (*it1 != *it2) {
      return false;
    }
  }

  if (nodes.size() != other.nodes.size()) {
    return false;
  }
  for (auto it1 = nodes.begin(), it2 = other.nodes.begin(); it1 != nodes.end();
       ++it1, ++it2) {
    if (**it1 != **it2) {
      return false;
    }
  }

  return true;
}

}  // end namespace graph_analyzer
}  // end namespace grappler
}  // end namespace tensorflow