1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
17
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20 #include "absl/memory/memory.h"
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
23 #include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
24 #include "tensorflow/core/grappler/utils.h"
25
26 namespace tensorflow {
27 namespace grappler {
28 namespace graph_analyzer {
29 namespace test {
30
31 using ::testing::ElementsAre;
32 using ::testing::Eq;
33 using ::testing::Gt;
34 using ::testing::Ne;
35 using ::testing::SizeIs;
36
37 //===
38
// Verifies equality and ordering of SigNode::LinkTag. Tags are named after
// their (first, second) port numbers; every port uses the same `false` flag so
// that only the port numbers differ.
TEST(SigNodeLinkTag, Compare) {
  using Port = GenNode::Port;
  SigNode::LinkTag t12(Port(false, 1), Port(false, 2));
  SigNode::LinkTag t12_dup(Port(false, 1), Port(false, 2));
  SigNode::LinkTag t21(Port(false, 2), Port(false, 1));
  SigNode::LinkTag t13(Port(false, 1), Port(false, 3));
  SigNode::LinkTag t22(Port(false, 2), Port(false, 2));

  // Equality requires both ports to match.
  EXPECT_TRUE(t12 == t12_dup);
  EXPECT_FALSE(t12 == t21);
  EXPECT_FALSE(t12 == t22);

  // Equal tags are not ordered in either direction.
  EXPECT_FALSE(t12 < t12_dup);
  EXPECT_FALSE(t12_dup < t12);

  // The first port dominates, even though the second port alone would order
  // these two the other way around.
  EXPECT_TRUE(t12 < t21);
  EXPECT_FALSE(t21 < t12);

  // With the first port tied, the second port decides.
  EXPECT_TRUE(t12 < t13);
  EXPECT_FALSE(t13 < t12);
}
59
60 //===
61
62 class SigBaseTest : public ::testing::Test, protected TestGraphs {
63 protected:
BuildSigMap(const GraphDef & graph)64 void BuildSigMap(const GraphDef& graph) {
65 gen_map_.clear();
66 sig_.map.clear();
67 CHECK(GenNode::BuildGraphInMap(graph, &gen_map_).ok());
68 Subgraph::Identity id;
69 for (const auto& entry : gen_map_) {
70 id.insert(entry.second.get());
71 }
72 Subgraph sg(id);
73 sg.ExtractForSignature(&sig_.map);
74 }
75
CopyLinksPass2(std::map<SigNode::LinkTag,SigNode::Link> * link_map,SigNode * node)76 static void CopyLinksPass2(
77 std::map<SigNode::LinkTag, SigNode::Link>* link_map, SigNode* node) {
78 node->CopyLinksPass2(link_map);
79 }
80
ComputeTopoHash0(SigNode * node)81 static void ComputeTopoHash0(SigNode* node) { node->ComputeTopoHash0(); }
82
ComputeTopoHash(int distance,SigNode * node)83 static void ComputeTopoHash(int distance, SigNode* node) {
84 node->ComputeTopoHash(distance);
85 }
86
GetTopoHash(int distance,SigNode * node)87 static size_t GetTopoHash(int distance, SigNode* node) {
88 return node->GetTopoHash(distance);
89 }
90
GetHighTopoHash(SigNode * node)91 static size_t GetHighTopoHash(SigNode* node) {
92 return node->GetHighTopoHash();
93 }
94
ReHighTopoHash(SigNode * node)95 static void ReHighTopoHash(SigNode* node) { node->ReHighTopoHash(); }
96
RefHashedPeers(SigNode * node)97 static SigNode::HashedPeerVector& RefHashedPeers(SigNode* node) {
98 return node->hashed_peers_;
99 }
RefUniqueRank(SigNode * node)100 static size_t& RefUniqueRank(SigNode* node) { return node->unique_rank_; }
RefHashIsFinal(SigNode * node)101 static bool& RefHashIsFinal(SigNode* node) { return node->hash_is_final_; }
RefTopoHash(SigNode * node)102 static std::vector<size_t>& RefTopoHash(SigNode* node) {
103 return node->topo_hash_;
104 }
RefNodeMask(SigNode * node)105 static uint64_t& RefNodeMask(SigNode* node) { return node->node_mask_; }
RefLastHashedNodes(SigNode * node)106 static uint64_t& RefLastHashedNodes(SigNode* node) {
107 return node->last_hashed_nodes_;
108 }
RefNextHashedNodes(SigNode * node)109 static uint64_t& RefNextHashedNodes(SigNode* node) {
110 return node->next_hashed_nodes_;
111 }
112
PrepareNodes(Signature * signature)113 static void PrepareNodes(Signature* signature) { signature->PrepareNodes(); }
114
FindUniqueHashes(size_t * next_node_id_p,Signature * signature)115 static void FindUniqueHashes(size_t* next_node_id_p, Signature* signature) {
116 signature->FindUniqueHashes(next_node_id_p);
117 }
118
ComputeOneRound(size_t next_node_id,Signature * signature)119 static void ComputeOneRound(size_t next_node_id, Signature* signature) {
120 signature->ComputeOneRound(next_node_id);
121 }
122
OrderLinks(Signature * signature)123 static void OrderLinks(Signature* signature) { signature->OrderLinks(); }
124
125 // These get initialized in BuildSigMap().
126 GenNodeMap gen_map_;
127 Signature sig_;
128 };
129
130 //===
131
132 class SigNodeTest : public SigBaseTest {};
133
134 // Tests that the duplicate hashes get resolved by rehashing.
TEST_F(SigNodeTest,DuplicateHash)135 TEST_F(SigNodeTest, DuplicateHash) {
136 NodeDef node1 = MakeNodeConst("node1");
137 NodeDef node2 = MakeNodeConst("node2");
138 NodeDef node3 = MakeNodeShapeN("node3", "node1", "node2");
139
140 SigNode sn1(&node1);
141 SigNode sn2(&node2);
142 SigNode sn3(&node3);
143
144 constexpr size_t kSameHash = 999;
145
146 SigNode::Link link1;
147 link1.tag = SigNode::LinkTag(GenNode::Port(true, 0), GenNode::Port(false, 0));
148 link1.unique_hash = kSameHash;
149 link1.peers.emplace_back(&sn1);
150
151 SigNode::Link link2;
152 link2.tag = SigNode::LinkTag(GenNode::Port(true, 1), GenNode::Port(false, 0));
153 link2.unique_hash = kSameHash;
154 link2.peers.emplace_back(&sn2);
155
156 SigNode::Link link3;
157 link3.tag = SigNode::LinkTag(GenNode::Port(true, 2), GenNode::Port(false, 0));
158 link3.unique_hash = kSameHash;
159 link3.peers.emplace_back(&sn3);
160
161 std::map<SigNode::LinkTag, SigNode::Link> link_map;
162 link_map[link1.tag] = link1;
163 link_map[link2.tag] = link2;
164 link_map[link3.tag] = link3;
165
166 CopyLinksPass2(&link_map, &sn3);
167 auto& hl = sn3.hash_to_link();
168 EXPECT_THAT(hl, SizeIs(3));
169
170 // Check that the hashes are self_consistent, and put the entries into
171 // another map with a known order.
172 std::map<SigNode::LinkTag, SigNode::Link> rehashed;
173 auto hlit = hl.begin();
174 ASSERT_THAT(hlit, Ne(hl.end()));
175 EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
176 rehashed[hlit->second.tag] = hlit->second;
177 ++hlit;
178 ASSERT_THAT(hlit, Ne(hl.end()));
179 EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
180 rehashed[hlit->second.tag] = hlit->second;
181 ++hlit;
182 ASSERT_THAT(hlit, Ne(hl.end()));
183 EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
184 rehashed[hlit->second.tag] = hlit->second;
185
186 // Just in case.
187 ASSERT_THAT(rehashed, SizeIs(3));
188
189 auto rhit = rehashed.begin();
190 ASSERT_THAT(rhit, Ne(rehashed.end()));
191 EXPECT_TRUE(rhit->second.tag == link1.tag);
192 EXPECT_THAT(rhit->second.unique_hash, Eq(kSameHash));
193 EXPECT_THAT(rhit->second.peers, ElementsAre(&sn1));
194
195 ++rhit;
196 ASSERT_THAT(rhit, Ne(rehashed.end()));
197 EXPECT_TRUE(rhit->second.tag == link2.tag);
198 // This hash must be rehashed.
199 EXPECT_THAT(rhit->second.unique_hash, Ne(kSameHash));
200 size_t hash2 = rhit->second.unique_hash;
201 EXPECT_THAT(rhit->second.peers, ElementsAre(&sn2));
202
203 ++rhit;
204 ASSERT_THAT(rhit, Ne(rehashed.end()));
205 EXPECT_TRUE(rhit->second.tag == link3.tag);
206 // This hash must be rehashed.
207 EXPECT_THAT(rhit->second.unique_hash, Ne(kSameHash));
208 EXPECT_THAT(rhit->second.unique_hash, Ne(hash2));
209 size_t hash3 = rhit->second.unique_hash;
210 EXPECT_THAT(rhit->second.peers, ElementsAre(&sn3));
211
212 auto& peers = sn3.hashed_peers();
213 EXPECT_THAT(peers, SizeIs(3));
214
215 auto peerit = peers.begin();
216 ASSERT_THAT(peerit, Ne(peers.end()));
217 EXPECT_THAT(peerit->link_hash, Eq(kSameHash));
218 EXPECT_THAT(peerit->peer, Eq(&sn1));
219
220 ++peerit;
221 ASSERT_THAT(peerit, Ne(peers.end()));
222 EXPECT_THAT(peerit->link_hash, Eq(hash2));
223 EXPECT_THAT(peerit->peer, Eq(&sn2));
224
225 ++peerit;
226 ASSERT_THAT(peerit, Ne(peers.end()));
227 EXPECT_THAT(peerit->link_hash, Eq(hash3));
228 EXPECT_THAT(peerit->peer, Eq(&sn3));
229 }
230
231 // The full CopyLinks() is tested in (SubgraphTest, ExtractForSignature).
232
// Checks GetTopoHash() for in-range distances. It also checks that, once the
// hash is marked final, a distance past the stored ones (2 here) resolves to
// the last stored hash, which is also the high hash.
TEST_F(SigNodeTest, GetTopoHash) {
  NodeDef node_def = MakeNodeConst("node1");
  SigNode node(&node_def);

  // Fake two levels of hash values.
  RefTopoHash(&node) = {123, 456};

  EXPECT_THAT(GetTopoHash(0, &node), Eq(123));
  EXPECT_THAT(GetTopoHash(1, &node), Eq(456));

  RefHashIsFinal(&node) = true;

  const size_t want[] = {123, 456, 456};
  for (int distance = 0; distance < 3; ++distance) {
    EXPECT_THAT(GetTopoHash(distance, &node), Eq(want[distance]))
        << " at distance " << distance;
  }

  EXPECT_THAT(GetHighTopoHash(&node), Eq(456));
}
252
// Checks that ReHighTopoHash() replaces only the highest (last) topology hash,
// by combining it with 1, and leaves the lower levels untouched.
TEST_F(SigNodeTest, ReTopoHash) {
  NodeDef node_def = MakeNodeConst("node1");
  SigNode node(&node_def);

  // Fake two levels of hash values.
  RefTopoHash(&node) = {123, 456};

  EXPECT_THAT(GetTopoHash(0, &node), Eq(123));
  EXPECT_THAT(GetTopoHash(1, &node), Eq(456));

  ReHighTopoHash(&node);

  size_t rehashed_high = 456;
  CombineHash(1, &rehashed_high);

  EXPECT_THAT(GetTopoHash(0, &node), Eq(123));
  EXPECT_THAT(GetTopoHash(1, &node), Eq(rehashed_high));
}
272
// Checks ComputeTopoHash0(). It must drop the stale hashing state and reset
// both hashed-node sets to the node's own mask. It must then seed a single
// topology hash from the opcode, combined in order with every peer's link hash.
TEST_F(SigNodeTest, ComputeTopoHash0) {
  NodeDef node1 = MakeNodeConst("node1");
  SigNode sn1(&node1);

  // Fake a topology.
  RefUniqueRank(&sn1) = 10;
  RefNodeMask(&sn1) = 0x02;

  // Fake a stale state that is expected to be overwritten.
  RefTopoHash(&sn1) = {123, 456};
  RefLastHashedNodes(&sn1) = 0xFF;
  RefNextHashedNodes(&sn1) = 0xFF;

  constexpr size_t kPeerLinkHashes[] = {1, 1, 2, 3, 3};
  for (size_t link_hash : kPeerLinkHashes) {
    RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(link_hash, nullptr));
  }

  // Run the test.
  ComputeTopoHash0(&sn1);

  EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x02));
  EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x02));
  EXPECT_THAT(RefTopoHash(&sn1), SizeIs(1));

  size_t exp_hval = std::hash<string>()(sn1.opcode());
  for (size_t link_hash : kPeerLinkHashes) {
    CombineHash(link_hash, &exp_hval);
  }

  EXPECT_THAT(GetTopoHash(0, &sn1), Eq(exp_hval));
}
310
// Checks one non-final ComputeTopoHash() step at distance 2 on sn1. The new
// hash starts from sn1's 0th hash. For each run of peers that share a link
// hash, it then mixes in that link hash followed by the commutative combination
// of those peers' distance-1 hashes. The next-hashed set becomes the union of
// sn1's and its peers' last-hashed sets.
TEST_F(SigNodeTest, ComputeTopoHashNotFinal) {
  NodeDef node1 = MakeNodeConst("node1");
  SigNode sn1(&node1);
  NodeDef node2 = MakeNodeConst("node2");
  SigNode sn2(&node2);
  NodeDef node3 = MakeNodeConst("node3");
  SigNode sn3(&node3);

  // Fake a topology: every rank is 0 and each node owns one mask bit.
  SigNode* const all_nodes[] = {&sn1, &sn2, &sn3};
  for (int i = 0; i < 3; ++i) {
    RefUniqueRank(all_nodes[i]) = 0;
    RefNodeMask(all_nodes[i]) = uint64_t{1} << i;
  }

  // Peers of sn1 grouped by link hash: 10 -> {sn2, sn3}, 20 -> {sn2},
  // 30 -> {sn3, sn2}.
  RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn2));
  RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn3));
  RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(20, &sn2));
  RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn3));
  RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn2));

  // Fake a state: two levels of hashes per node.
  RefTopoHash(&sn1) = {123, 321};
  RefTopoHash(&sn2) = {456, 654};
  RefTopoHash(&sn3) = {789, 987};

  // Deliberately unrealistic: these sets leave out the nodes' own mask bits.
  // That is the point: only the previous node sets feed the computation, never
  // the masks themselves.
  RefLastHashedNodes(&sn1) = 0x8;
  RefLastHashedNodes(&sn2) = 0x10;
  RefLastHashedNodes(&sn3) = 0x20;

  // A scratch value to get overwritten.
  RefNextHashedNodes(&sn1) = 0x100;

  ComputeTopoHash(2, &sn1);

  EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x8));   // Unchanged.
  EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x38));  // 0x8 | 0x10 | 0x20.

  // Recompute the hash from the explicit numbers above, starting from sn1's
  // 0th hash.
  size_t exp_hash = 123;
  auto mix_link_group = [&exp_hash](size_t link_hash,
                                    std::initializer_list<size_t> peer_hashes) {
    size_t comm_hash = 0;
    for (size_t peer_hash : peer_hashes) {
      CombineHashCommutative(peer_hash, &comm_hash);
    }
    CombineHash(link_hash, &exp_hash);
    CombineHash(comm_hash, &exp_hash);
  };
  mix_link_group(10, {654, 987});
  mix_link_group(20, {654});
  mix_link_group(30, {654, 987});

  EXPECT_THAT(GetTopoHash(2, &sn1), Eq(exp_hash));
  EXPECT_THAT(RefTopoHash(&sn1), SizeIs(3));
}
386
// Same configuration as ComputeTopoHashNotFinal, except that sn1's hash is
// already final. ComputeTopoHash() must then append nothing, and the
// next-hashed set must simply equal the last-hashed one. A query past the
// stored levels resolves to the last stored hash.
TEST_F(SigNodeTest, ComputeTopoHashFinal) {
  NodeDef node1 = MakeNodeConst("node1");
  SigNode sn1(&node1);
  NodeDef node2 = MakeNodeConst("node2");
  SigNode sn2(&node2);
  NodeDef node3 = MakeNodeConst("node3");
  SigNode sn3(&node3);

  // Fake a topology - same as for ComputeTopoHashNotFinal.
  SigNode* const all_nodes[] = {&sn1, &sn2, &sn3};
  for (int i = 0; i < 3; ++i) {
    RefUniqueRank(all_nodes[i]) = 0;
    RefNodeMask(all_nodes[i]) = uint64_t{1} << i;
  }

  RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn2));
  RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn3));
  RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(20, &sn2));
  RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn3));
  RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn2));

  // Fake a state - mostly same as for ComputeTopoHashNotFinal.
  RefTopoHash(&sn1) = {123, 321};
  RefTopoHash(&sn2) = {456, 654};
  RefTopoHash(&sn3) = {789, 987};

  // Deliberately unrealistic sets that leave out the nodes' own mask bits (see
  // ComputeTopoHashNotFinal).
  RefLastHashedNodes(&sn1) = 0x8;
  RefLastHashedNodes(&sn2) = 0x10;
  RefLastHashedNodes(&sn3) = 0x20;

  // A scratch value to get overwritten.
  RefNextHashedNodes(&sn1) = 0x100;

  // The only difference in configuration.
  RefHashIsFinal(&sn1) = true;

  ComputeTopoHash(2, &sn1);

  EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x8));  // Unchanged.
  EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x8));  // Same as the last set.
  EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));       // Nothing appended.
  EXPECT_THAT(GetTopoHash(2, &sn1), Eq(321));      // Last stored level.
}
440
// Two SigNodes over NodeDefs with the same op (but different names) compare
// equal. Changing the op of one of the NodeDefs, which the SigNode refers to by
// pointer, makes them unequal.
TEST_F(SigNodeTest, EqualsOpcode) {
  NodeDef def_a = MakeNodeConst("node1");
  NodeDef def_b = MakeNodeConst("node2");
  SigNode sig_a(&def_a);
  SigNode sig_b(&def_b);

  EXPECT_FALSE(sig_a != sig_b);
  EXPECT_TRUE(sig_a == sig_b);

  def_b.set_op("Mul");

  EXPECT_FALSE(sig_a == sig_b);
  EXPECT_TRUE(sig_a != sig_b);
}
456
// Two otherwise identical SigNodes stop comparing equal once their unique
// ranks differ.
TEST_F(SigNodeTest, EqualsRank) {
  NodeDef def_a = MakeNodeConst("node1");
  NodeDef def_b = MakeNodeConst("node2");
  SigNode sig_a(&def_a);
  SigNode sig_b(&def_b);

  EXPECT_FALSE(sig_a != sig_b);
  EXPECT_TRUE(sig_a == sig_b);

  RefUniqueRank(&sig_a) = 1;
  RefUniqueRank(&sig_b) = 2;

  EXPECT_FALSE(sig_a == sig_b);
  EXPECT_TRUE(sig_a != sig_b);
}
473
474 // Checks that if the nodes have a different number of links,
475 // they will be considered unequal.
TEST_F(SigNodeTest,EqualsLinkSize)476 TEST_F(SigNodeTest, EqualsLinkSize) {
477 GraphDef graph1;
478 (*graph1.add_node()) = MakeNodeConst("node1");
479 (*graph1.add_node()) = MakeNodeMul("node2", "node1", "node1");
480
481 GenNodeMap gen_map1;
482 ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), Eq(Status::OK()));
483
484 Subgraph::Identity id1;
485 id1.insert(gen_map1["node1"].get());
486 id1.insert(gen_map1["node2"].get());
487 Subgraph sg1(id1);
488
489 SigNodeMap sig_map1;
490 sg1.ExtractForSignature(&sig_map1);
491
492 GraphDef graph2;
493 (*graph2.add_node()) = MakeNodeConst("node1");
494 // The difference between graph1 and graph2: one more input.
495 auto node22 = graph2.add_node();
496 *node22 = MakeNodeMul("node2", "node1", "node1");
497 node22->add_input("node2");
498
499 GenNodeMap gen_map2;
500 ASSERT_THAT(GenNode::BuildGraphInMap(graph2, &gen_map2), Eq(Status::OK()));
501
502 Subgraph::Identity id2;
503 id2.insert(gen_map2["node1"].get());
504 id2.insert(gen_map2["node2"].get());
505 Subgraph sg2(id2);
506
507 SigNodeMap sig_map2;
508 sg2.ExtractForSignature(&sig_map2);
509
510 EXPECT_TRUE(*sig_map1["node1"] == *sig_map2["node1"]);
511 EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]);
512 EXPECT_FALSE(*sig_map2["node2"] == *sig_map1["node2"]);
513 }
514
// SigNodes built from two copies of the same graph compare equal. Perturbing
// either a link hash or the unique rank of a linked peer breaks the equality.
TEST_F(SigNodeTest, EqualsLinks) {
  // Both signature maps are built from this one graph.
  GraphDef graph;
  *graph.add_node() = MakeNodeConst("node1");
  *graph.add_node() = MakeNodeMul("node2", "node1", "node1");

  GenNodeMap gen_map1;
  ASSERT_THAT(GenNode::BuildGraphInMap(graph, &gen_map1), Eq(Status::OK()));

  Subgraph::Identity id1;
  for (const char* name : {"node1", "node2"}) {
    id1.insert(gen_map1[name].get());
  }
  Subgraph sg1(id1);

  SigNodeMap sig_map1;
  sg1.ExtractForSignature(&sig_map1);

  GenNodeMap gen_map2;
  ASSERT_THAT(GenNode::BuildGraphInMap(graph, &gen_map2), Eq(Status::OK()));

  Subgraph::Identity id2;
  for (const char* name : {"node1", "node2"}) {
    id2.insert(gen_map2[name].get());
  }
  Subgraph sg2(id2);

  SigNodeMap sig_map2;
  sg2.ExtractForSignature(&sig_map2);

  SigNode& left_mul = *sig_map1["node2"];
  SigNode& right_mul = *sig_map2["node2"];

  EXPECT_TRUE(*sig_map1["node1"] == *sig_map2["node1"]);
  EXPECT_TRUE(left_mul == right_mul);

  // Bump the link hash of one of the nodes, then put it back.
  ++RefHashedPeers(&right_mul)[0].link_hash;
  EXPECT_FALSE(left_mul == right_mul);
  --RefHashedPeers(&right_mul)[0].link_hash;
  EXPECT_TRUE(left_mul == right_mul);

  // Alter the unique rank of a referenced node.
  ++RefUniqueRank(sig_map2["node1"].get());
  EXPECT_FALSE(left_mul == right_mul);
}
561
562 //===
563
564 class SignatureTest : public SigBaseTest {
565 protected:
566 // Initializeds the state used to generate the permutations of a given size.
InitPermutation(size_t size,std::vector<size_t> * plain_permutation,std::vector<size_t> * countdown)567 static void InitPermutation(size_t size,
568 std::vector<size_t>* plain_permutation,
569 std::vector<size_t>* countdown) {
570 plain_permutation->clear();
571 countdown->clear();
572 for (size_t i = 0; i < size; ++i) {
573 plain_permutation->emplace_back(i);
574 countdown->emplace_back(size - 1 - i);
575 }
576 }
577
578 // Builds a permutation guided by the count-down value.
BuildPermutation(const std::vector<size_t> & plain_permutation,const std::vector<size_t> & countdown,std::vector<size_t> * result)579 static void BuildPermutation(const std::vector<size_t>& plain_permutation,
580 const std::vector<size_t>& countdown,
581 std::vector<size_t>* result) {
582 *result = plain_permutation;
583 for (int i = 0; i < result->size(); ++i) {
584 std::swap((*result)[i], (*result)[i + countdown[i]]);
585 }
586 }
587
588 // Returns false when the count-down is finished.
CountDown(std::vector<size_t> * countdown)589 static bool CountDown(std::vector<size_t>* countdown) {
590 // The last position always contains 0, so skip it.
591 int pos;
592 for (pos = countdown->size() - 2; pos >= 0; --pos) {
593 if ((*countdown)[pos] > 0) {
594 --(*countdown)[pos];
595 break;
596 }
597 (*countdown)[pos] = (countdown->size() - 1 - pos);
598 }
599
600 return pos >= 0;
601 }
602
603 // Permutes the nodes every which way and checks that all the signatures
604 // produced are the same. This is reasonable for the graphs up to the
605 // size 5, maybe 6 at the stretch. After that the number of permutation grows
606 // huge and the test becomes very slow.
TestGraphEveryWay(const GraphDef & graph)607 void TestGraphEveryWay(const GraphDef& graph) {
608 size_t graph_size = graph.node_size();
609
610 gen_map_.clear();
611 sig_.map.clear();
612 Status result = GenNode::BuildGraphInMap(graph, &gen_map_);
613 ASSERT_THAT(result, Eq(Status::OK()));
614 Subgraph::Identity id;
615 for (const auto& entry : gen_map_) {
616 id.insert(entry.second.get());
617 }
618 Subgraph sg(id);
619 sg.ExtractForSignature(&sig_.map);
620
621 std::vector<size_t> plain_permutation;
622 std::vector<size_t> countdown;
623 InitPermutation(graph_size, &plain_permutation, &countdown);
624
625 std::set<string> signatures;
626 std::vector<size_t> permutation;
627 do {
628 BuildPermutation(plain_permutation, countdown, &permutation);
629
630 constexpr bool kDebugPermutation = false;
631 if (kDebugPermutation) {
632 string p;
633 for (int i = 0; i < permutation.size(); ++i) {
634 p.push_back('0' + permutation[i]);
635 }
636 LOG(INFO) << "Permutation: " << p;
637 }
638
639 std::vector<std::unique_ptr<SigNode>> hold(graph_size);
640 int idx;
641
642 // Permute the nodes.
643 sig_.nodes.clear();
644 idx = 0;
645 if (kDebugPermutation) {
646 LOG(INFO) << " nodes before permutation:";
647 }
648 for (auto& entry : sig_.map) {
649 if (kDebugPermutation) {
650 LOG(INFO) << " " << entry.second.get();
651 }
652 hold[idx++] = std::move(entry.second);
653 }
654 idx = 0;
655 if (kDebugPermutation) {
656 LOG(INFO) << " nodes after permutation:";
657 }
658 for (auto& entry : sig_.map) {
659 entry.second = std::move(hold[permutation[idx++]]);
660 if (kDebugPermutation) {
661 LOG(INFO) << " " << entry.second.get();
662 }
663 // This is used to order the links per permutation.
664 sig_.nodes.emplace_back(entry.second.get());
665 RefUniqueRank(entry.second.get()) = idx;
666 }
667 // Order the links with the same tags per permutation.
668 OrderLinks(&sig_);
669
670 // The test as such.
671 ASSERT_THAT(sig_.Compute(), Eq(Status::OK()));
672
673 signatures.insert(sig_.ToString());
674
675 EXPECT_THAT(sig_.sig_full, SizeIs(graph_size));
676 size_t hval = 0;
677 for (size_t ih : sig_.sig_full) {
678 // The space 1..graph_size is reserved.
679 EXPECT_THAT(ih, Gt(graph_size));
680 CombineHash(ih, &hval);
681 }
682 EXPECT_THAT(sig_.sig_short, Eq(hval));
683
684 // Un-permute the nodes for the next iteration.
685 idx = 0;
686 for (auto& entry : sig_.map) {
687 hold[permutation[idx++]] = std::move(entry.second);
688 }
689 idx = 0;
690 if (kDebugPermutation) {
691 LOG(INFO) << " nodes after un-permutation:";
692 }
693 for (auto& entry : sig_.map) {
694 entry.second = std::move(hold[idx++]);
695 if (kDebugPermutation) {
696 LOG(INFO) << " " << entry.second.get();
697 }
698 }
699 } while (CountDown(&countdown));
700
701 for (const auto& s : signatures) {
702 LOG(INFO) << "Signature: " << s;
703 }
704
705 // All the permutations should produce the same signature.
706 EXPECT_THAT(signatures, SizeIs(1));
707 }
708 };
709
// Checks that PrepareNodes() fills sig_.nodes from sig_.map and initializes
// each node with one mask bit per node (in map order), a unique rank of ~0, a
// non-final hash and exactly one topology hash.
TEST_F(SignatureTest, PrepareNodes) {
  NodeDef node1 = MakeNodeConst("node1");
  sig_.map["node1"] = absl::make_unique<SigNode>(&node1);
  NodeDef node2 = MakeNodeConst("node2");
  sig_.map["node2"] = absl::make_unique<SigNode>(&node2);
  NodeDef node3 = MakeNodeConst("node3");
  sig_.map["node3"] = absl::make_unique<SigNode>(&node3);

  PrepareNodes(&sig_);

  ASSERT_THAT(sig_.nodes, SizeIs(3));

  int idx = 0;
  for (const auto& entry : sig_.map) {
    SigNode* node = entry.second.get();
    // The mask is a uint64_t, so build the expected bit in that type too.
    EXPECT_THAT(RefNodeMask(node), Eq(uint64_t{1} << idx))
        << " at index " << idx;
    EXPECT_THAT(RefUniqueRank(node), Eq(static_cast<size_t>(~0)))
        << " at index " << idx;
    // Use an explicit Eq() matcher, like every other assertion in this file,
    // rather than passing a bare `false` as the matcher.
    EXPECT_THAT(RefHashIsFinal(node), Eq(false)) << " at index " << idx;
    EXPECT_THAT(RefTopoHash(node), SizeIs(1)) << " at index " << idx;
    ++idx;
  }
}
735
// All the high (last) hashes differ, so FindUniqueHashes() finalizes every node
// from position `next` onward. Those nodes are sorted by high hash, each ends
// up with a single small topology hash, and their high hashes make up the
// signature. The node before `next` is left untouched.
TEST_F(SignatureTest, FindUniqueHashesAllDifferent) {
  NodeDef node1 = MakeNodeConst("node1");
  SigNode sn1(&node1);
  NodeDef node2 = MakeNodeConst("node2");
  SigNode sn2(&node2);
  NodeDef node3 = MakeNodeConst("node3");
  SigNode sn3(&node3);
  NodeDef node4 = MakeNodeConst("node4");
  SigNode sn4(&node4);

  // The high hashes run backwards relative to the node numbering.
  SigNode* const nodes[] = {&sn1, &sn2, &sn3, &sn4};
  const size_t low_hash[] = {100, 200, 300, 400};
  const size_t high_hash[] = {900, 800, 700, 600};
  for (int i = 0; i < 4; ++i) {
    RefTopoHash(nodes[i]).emplace_back(low_hash[i]);
    RefTopoHash(nodes[i]).emplace_back(high_hash[i]);
    sig_.nodes.emplace_back(nodes[i]);
  }

  size_t next = 1;  // Skips over sn1.

  FindUniqueHashes(&next, &sig_);
  EXPECT_THAT(next, Eq(4));

  // sn1 stays in place; the nodes after it get sorted by the high hash.
  EXPECT_THAT(sig_.nodes, ElementsAre(&sn1, &sn4, &sn3, &sn2));

  // Nodes that get finalized are marked as such.
  EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
  for (SigNode* finalized : {&sn2, &sn3, &sn4}) {
    EXPECT_THAT(RefHashIsFinal(finalized), Eq(true));
  }

  EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
  ASSERT_THAT(RefTopoHash(&sn2), SizeIs(1));
  ASSERT_THAT(RefTopoHash(&sn3), SizeIs(1));
  ASSERT_THAT(RefTopoHash(&sn4), SizeIs(1));

  EXPECT_THAT(RefTopoHash(&sn2)[0], Eq(4));
  EXPECT_THAT(RefTopoHash(&sn3)[0], Eq(3));
  EXPECT_THAT(RefTopoHash(&sn4)[0], Eq(2));

  EXPECT_THAT(sig_.sig_full, ElementsAre(600, 700, 800));

  size_t exp_short_hash = 0;
  for (size_t h : {600, 700, 800}) {
    CombineHash(h, &exp_short_hash);
  }
  EXPECT_THAT(sig_.sig_short, Eq(exp_short_hash));
}
798
// Only sn3 has a high hash nobody else shares. FindUniqueHashes() must
// finalize just that node, moving it to the front with topology hash 1, and
// leave all the duplicated nodes non-final.
TEST_F(SignatureTest, FindUniqueHashesDuplicatesExceptOne) {
  NodeDef node1 = MakeNodeConst("node1");
  SigNode sn1(&node1);
  NodeDef node2 = MakeNodeConst("node2");
  SigNode sn2(&node2);
  NodeDef node3 = MakeNodeConst("node3");
  SigNode sn3(&node3);
  NodeDef node4 = MakeNodeConst("node4");
  SigNode sn4(&node4);
  NodeDef node5 = MakeNodeConst("node5");
  SigNode sn5(&node5);

  // High hashes: 600 twice, 700 once (sn3), 800 twice.
  SigNode* const nodes[] = {&sn1, &sn2, &sn3, &sn4, &sn5};
  const size_t low_hash[] = {100, 200, 300, 400, 500};
  const size_t high_hash[] = {600, 600, 700, 800, 800};
  for (int i = 0; i < 5; ++i) {
    RefTopoHash(nodes[i]).emplace_back(low_hash[i]);
    RefTopoHash(nodes[i]).emplace_back(high_hash[i]);
    sig_.nodes.emplace_back(nodes[i]);
  }

  size_t next = 0;

  FindUniqueHashes(&next, &sig_);
  EXPECT_THAT(next, Eq(1));

  // The unique node goes first. The rest are assumed to keep a stable order,
  // except that sn1 trades places with sn3.
  EXPECT_THAT(sig_.nodes, ElementsAre(&sn3, &sn2, &sn1, &sn4, &sn5));

  // Only the unique node is finalized, and only its hash collapses.
  const bool want_final[] = {false, false, true, false, false};
  for (int i = 0; i < 5; ++i) {
    EXPECT_THAT(RefHashIsFinal(nodes[i]), Eq(want_final[i]))
        << " for sn" << (i + 1);
    EXPECT_THAT(RefTopoHash(nodes[i]), SizeIs(want_final[i] ? 1 : 2))
        << " for sn" << (i + 1);
  }

  EXPECT_THAT(RefTopoHash(&sn3)[0], Eq(1));
}
861
// No node has a unique high hash. FindUniqueHashes() must still finalize
// exactly one node: the last copy of the last duplicate (sn5). That node moves
// to the front with topology hash 1.
TEST_F(SignatureTest, FindUniqueHashesDuplicates) {
  NodeDef node1 = MakeNodeConst("node1");
  SigNode sn1(&node1);
  NodeDef node2 = MakeNodeConst("node2");
  SigNode sn2(&node2);
  NodeDef node3 = MakeNodeConst("node3");
  SigNode sn3(&node3);
  NodeDef node4 = MakeNodeConst("node4");
  SigNode sn4(&node4);
  NodeDef node5 = MakeNodeConst("node5");
  SigNode sn5(&node5);

  // High hashes: 600 twice, 700 three times.
  SigNode* const nodes[] = {&sn1, &sn2, &sn3, &sn4, &sn5};
  const size_t low_hash[] = {100, 200, 300, 400, 500};
  const size_t high_hash[] = {600, 600, 700, 700, 700};
  for (int i = 0; i < 5; ++i) {
    RefTopoHash(nodes[i]).emplace_back(low_hash[i]);
    RefTopoHash(nodes[i]).emplace_back(high_hash[i]);
    sig_.nodes.emplace_back(nodes[i]);
  }

  size_t next = 0;

  FindUniqueHashes(&next, &sig_);
  EXPECT_THAT(next, Eq(1));

  // The last copy of the last duplicate wins and goes first. The rest are
  // assumed to keep a stable order, except that sn1 trades places with sn5.
  EXPECT_THAT(sig_.nodes, ElementsAre(&sn5, &sn2, &sn3, &sn4, &sn1));

  const bool want_final[] = {false, false, false, false, true};
  for (int i = 0; i < 5; ++i) {
    EXPECT_THAT(RefHashIsFinal(nodes[i]), Eq(want_final[i]))
        << " for sn" << (i + 1);
    EXPECT_THAT(RefTopoHash(nodes[i]), SizeIs(want_final[i] ? 1 : 2))
        << " for sn" << (i + 1);
  }

  EXPECT_THAT(RefTopoHash(&sn5)[0], Eq(1));
}
924
925 // On a circular topology.
TEST_F(SignatureTest,ComputeOneRoundCircular)926 TEST_F(SignatureTest, ComputeOneRoundCircular) {
927 BuildSigMap(graph_circular_onedir_);
928 PrepareNodes(&sig_);
929
930 ASSERT_THAT(sig_.nodes, SizeIs(5));
931
932 // This skips FindUniqueHashes() which would pick one node, so that
933 // all the nodes are equivalent for ComputeOneRound().
934
935 ComputeOneRound(0, &sig_);
936
937 // All the nodes are the same, so the computed hashes will also be the same.
938 size_t hval = GetHighTopoHash(sig_.nodes[0]);
939 for (int i = 0; i < 5; ++i) {
940 EXPECT_THAT(GetHighTopoHash(sig_.nodes[i]), Eq(hval)) << " at index " << i;
941 EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i;
942 EXPECT_THAT(RefLastHashedNodes(sig_.nodes[i]), Eq(0x1F))
943 << " at index " << i;
944 EXPECT_THAT(RefNextHashedNodes(sig_.nodes[i]), Eq(0x1F))
945 << " at index " << i;
946 // The sets of hashed nodes go like this:
947 // Step 0: self.
948 // Step 1: self, previous (-1) and next (+1) node.
949 // Step 2: self, (-1), (-2), (+1), (+2): all 5 nodes in the graph
950 // Step 3: still all 5 nodes in the graph
951 EXPECT_THAT(RefTopoHash(sig_.nodes[i]), SizeIs(4)) << " at index " << i;
952 }
953 }
954
955 // On a linear topology.
TEST_F(SignatureTest,ComputeOneRoundLinear)956 TEST_F(SignatureTest, ComputeOneRoundLinear) {
957 BuildSigMap(graph_linear_);
958 PrepareNodes(&sig_);
959
960 ASSERT_THAT(sig_.nodes, SizeIs(5));
961
962 // This skips FindUniqueHashes() which would pick one node, so that
963 // all the nodes are equivalent for ComputeOneRound().
964
965 ComputeOneRound(0, &sig_);
966
967 std::vector<size_t> hash_size;
968 for (int i = 0; i < 5; ++i) {
969 EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i;
970 EXPECT_THAT(RefLastHashedNodes(sig_.nodes[i]), Eq(0x1F))
971 << " at index " << i;
972 EXPECT_THAT(RefNextHashedNodes(sig_.nodes[i]), Eq(0x1F))
973 << " at index " << i;
974 hash_size.emplace_back(RefTopoHash(sig_.nodes[i]).size());
975 }
976
977 // The sets of hashed nodes for the central node go like this:
978 // Step 0: self.
979 // Step 1: self, previous (-1) and next (+1) node.
980 // Step 2: self, (-1), (-2), (+1), (+2): all 5 nodes in the graph
981 // Step 3: still all 5 nodes in the graph
982 //
983 // The nodes one step closer to the ends require one more step. The end nodes
984 // require one more step yet.
985 std::sort(hash_size.begin(), hash_size.end());
986 EXPECT_THAT(hash_size, ElementsAre(4, 5, 5, 6, 6));
987 }
988
989 // On a linear topology where the central node has been already marked as unique
990 // (yeah, not a very realistic case but tests the situations when the
991 // disconnected subgraphs get created).
TEST_F(SignatureTest,ComputeOneRoundSplitLinear)992 TEST_F(SignatureTest, ComputeOneRoundSplitLinear) {
993 BuildSigMap(graph_linear_);
994 PrepareNodes(&sig_);
995
996 ASSERT_THAT(sig_.nodes, SizeIs(5));
997
998 // This test relies on the order of SigNodeMap imposed on sig_.nodes.
999
1000 // The middle node gets separated by moving it to the front.
1001 std::swap(sig_.nodes[0], sig_.nodes[2]);
1002 ASSERT_THAT(RefNodeMask(sig_.nodes[0]), Eq(0x04));
1003 ASSERT_THAT(RefLastHashedNodes(sig_.nodes[0]), Eq(0x04));
1004 ASSERT_THAT(RefNextHashedNodes(sig_.nodes[0]), Eq(0x04));
1005 RefHashIsFinal(sig_.nodes[0]) = true;
1006
1007 ComputeOneRound(1, &sig_);
1008
1009 // These should stay unchanged.
1010 EXPECT_THAT(RefLastHashedNodes(sig_.nodes[0]), Eq(0x04));
1011 EXPECT_THAT(RefNextHashedNodes(sig_.nodes[0]), Eq(0x04));
1012
1013 std::vector<size_t> hash_size;
1014 for (int i = 1; i < 5; ++i) {
1015 EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i;
1016 hash_size.emplace_back(RefTopoHash(sig_.nodes[i]).size());
1017 }
1018
1019 std::sort(hash_size.begin(), hash_size.end());
1020 // The end nodes take 4 steps, closer to the center 3 steps.
1021 EXPECT_THAT(hash_size, ElementsAre(3, 3, 4, 4));
1022
1023 EXPECT_THAT(RefLastHashedNodes(sig_.nodes[1]), Eq(0x07));
1024 EXPECT_THAT(RefNextHashedNodes(sig_.nodes[1]), Eq(0x07));
1025 EXPECT_THAT(RefLastHashedNodes(sig_.nodes[2]), Eq(0x07));
1026 EXPECT_THAT(RefNextHashedNodes(sig_.nodes[2]), Eq(0x07));
1027
1028 EXPECT_THAT(RefLastHashedNodes(sig_.nodes[3]), Eq(0x1C));
1029 EXPECT_THAT(RefNextHashedNodes(sig_.nodes[3]), Eq(0x1C));
1030 EXPECT_THAT(RefLastHashedNodes(sig_.nodes[4]), Eq(0x1C));
1031 EXPECT_THAT(RefNextHashedNodes(sig_.nodes[4]), Eq(0x1C));
1032 }
1033
// OrderLinks() on graph_for_link_order_. The links of each node, as printed by
// Signature::ToString(), are compared before and after the call. The expected
// strings show OrderLinks() rearranging each node's link list.
TEST_F(SignatureTest, OrderLinks) {
  gen_map_.clear();
  sig_.map.clear();
  ASSERT_THAT(GenNode::BuildGraphInMap(graph_for_link_order_, &gen_map_),
              Eq(Status::OK()));

  // Use every node of the built graph as the subgraph.
  Subgraph::Identity whole_graph;
  for (const auto& name_and_node : gen_map_) {
    whole_graph.insert(name_and_node.second.get());
  }
  Subgraph subgraph(whole_graph);
  subgraph.ExtractForSignature(&sig_.map);

  // Fake a signature: append the nodes in reverse map order and give each one
  // its position in sig_.nodes as its unique rank.
  for (auto rit = sig_.map.rbegin(); rit != sig_.map.rend(); ++rit) {
    SigNode* node = rit->second.get();
    RefUniqueRank(node) = sig_.nodes.size();
    sig_.nodes.emplace_back(node);
  }

  // Link order as it comes from the original graph.
  // clang-format off
  const string expected_before =
      "0:Mul[i0:o0:5][i0:o0:4][i0:o1:4][i0:o2:3][i0:o2:2][i0:o3:2],"
      "1:Mul[i0:o0:5][i0:o0:4][i0:o0:3][i0:o0:2],"
      "2:Const,"
      "3:Const,"
      "4:Const,"
      "5:Const,";
  // clang-format on
  EXPECT_THAT(sig_.ToString(), Eq(expected_before));

  OrderLinks(&sig_);

  // clang-format off
  const string expected_after =
      "0:Mul[i0:o0:4][i0:o0:5][i0:o1:4][i0:o2:2][i0:o2:3][i0:o3:2],"
      "1:Mul[i0:o0:2][i0:o0:3][i0:o0:4][i0:o0:5],"
      "2:Const,"
      "3:Const,"
      "4:Const,"
      "5:Const,";
  // clang-format on
  EXPECT_THAT(sig_.ToString(), Eq(expected_after));
}
1080
// Signature::Compute() must refuse a subgraph that has one node more than
// Signature::kMaxGraphSize. The expected error text is derived from the limit
// rather than hard-coding "65"/"64", so it stays consistent with the loop bound
// below if the limit is ever changed.
TEST_F(SignatureTest, GraphTooBig) {
  // Read the limit into a local. absl::StrFormat() takes its arguments by
  // reference, which would otherwise ODR-use the static constant member.
  const auto max_nodes = Signature::kMaxGraphSize;

  // Build max_nodes + 1 Const nodes (indices 0..max_nodes inclusive).
  GraphDef graph;
  for (int i = 0; i <= max_nodes; ++i) {
    (*graph.add_node()) = MakeNodeConst(absl::StrFormat("node%d", i));
  }

  ASSERT_THAT(GenNode::BuildGraphInMap(graph, &gen_map_), Eq(Status::OK()));

  // Use every node in gen_map_ as the subgraph.
  Subgraph::Identity id;
  for (const auto& entry : gen_map_) {
    id.insert(entry.second.get());
  }
  Subgraph sg(id);
  sg.ExtractForSignature(&sig_.map);

  // With kMaxGraphSize == 64 this is exactly the previously hard-coded text.
  const string expected_message = absl::StrFormat(
      "A graph of %d nodes is too big for signature computation, the maximal "
      "supported node count is %d.",
      max_nodes + 1, max_nodes);
  ASSERT_THAT(sig_.Compute(),
              Eq(Status(error::INVALID_ARGUMENT, expected_message)));
}
1102
// Signature::ToString() on the one-directional circular graph. The ranks are
// faked so that they follow the initial node order.
TEST_F(SignatureTest, ToString) {
  BuildSigMap(graph_circular_onedir_);
  PrepareNodes(&sig_);
  ASSERT_THAT(sig_.nodes, SizeIs(5));

  // Skip the real computation: each node gets its current index as its unique
  // rank and has its hash marked final.
  int rank = 0;
  for (SigNode* node : sig_.nodes) {
    RefUniqueRank(node) = rank++;
    RefHashIsFinal(node) = true;
  }

  // clang-format off
  const string expected =
      "0:Mul[i0:o0:4][i0:o0:4],"
      "1:Mul[i0:o0:0][i0:o0:0],"
      "2:Mul[i0:o0:1][i0:o0:1],"
      "3:Mul[i0:o0:2][i0:o0:2],"
      "4:Mul[i0:o0:3][i0:o0:3],";
  // clang-format on
  ASSERT_THAT(sig_.ToString(), Eq(expected));
}
1127
1128 // This is a test of the permutation logic itself.
TEST_F(SignatureTest,Permutation)1129 TEST_F(SignatureTest, Permutation) {
1130 std::vector<size_t> plain_permutation;
1131 std::vector<size_t> countdown;
1132 InitPermutation(5, &plain_permutation, &countdown);
1133
1134 std::set<string> results;
1135
1136 std::vector<size_t> permutation;
1137 do {
1138 BuildPermutation(plain_permutation, countdown, &permutation);
1139 EXPECT_THAT(permutation, SizeIs(5));
1140
1141 string p;
1142 for (int i = 0; i < permutation.size(); ++i) {
1143 p.push_back('0' + permutation[i]);
1144 }
1145 LOG(INFO) << "Permutation: " << p;
1146 results.insert(p);
1147 } while (CountDown(&countdown));
1148
1149 EXPECT_THAT(results, SizeIs(5 * 4 * 3 * 2 * 1));
1150 }
1151
// Full signature computation on the one-directional circular graph. The checks
// live in the fixture helper TestGraphEveryWay(), which is defined outside this
// view (presumably it covers every node ordering; confirm there).
TEST_F(SignatureTest, ComputeCircularOneDir) {
  const auto& graph = graph_circular_onedir_;
  TestGraphEveryWay(graph);
}
1155
// Full signature computation on the bidirectional circular graph. The checks
// live in the fixture helper TestGraphEveryWay(), which is defined outside this
// view.
TEST_F(SignatureTest, ComputeCircularBiDir) {
  const auto& graph = graph_circular_bidir_;
  TestGraphEveryWay(graph);
}
1159
// Full signature computation on the linear graph. The checks live in the
// fixture helper TestGraphEveryWay(), which is defined outside this view.
TEST_F(SignatureTest, ComputeLinear) {
  const auto& graph = graph_linear_;
  TestGraphEveryWay(graph);
}
1161
// Full signature computation on the multi-input graph. The checks live in the
// fixture helper TestGraphEveryWay(), which is defined outside this view.
TEST_F(SignatureTest, ComputeMultiInput) {
  const auto& graph = graph_multi_input_;
  TestGraphEveryWay(graph);
}
1165
// Full signature computation on the "all or none" graph. The checks live in the
// fixture helper TestGraphEveryWay(), which is defined outside this view.
TEST_F(SignatureTest, ComputeAllOrNone) {
  const auto& graph = graph_all_or_none_;
  TestGraphEveryWay(graph);
}
1169
// Full signature computation on the small cross-shaped graph. The checks live
// in the fixture helper TestGraphEveryWay(), which is defined outside this view.
TEST_F(SignatureTest, ComputeCross) {
  const auto& graph = graph_small_cross_;
  TestGraphEveryWay(graph);
}
1171
// Signature::operator== must take the short hash, the full hash, the node order
// and the node count into account. Two signatures computed independently from
// identical subgraphs start out equal. Each of those fields is then perturbed
// (the signatures become unequal) and restored (they become equal again).
TEST_F(SignatureTest, Equals) {
  // Start with 2 copies of the same graph.
  GenNodeMap gen_map1;
  ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &gen_map1),
              Eq(Status::OK()));

  // at() rather than operator[]: a missing node must not be silently
  // default-inserted and then added to the identity as a null pointer.
  Subgraph::Identity id1;
  id1.insert(gen_map1.at("node1").get());
  id1.insert(gen_map1.at("node2").get());
  Subgraph sg1(id1);

  Signature sig1;
  sg1.ExtractForSignature(&sig1.map);
  ASSERT_THAT(sig1.Compute(), Eq(Status::OK()));

  GenNodeMap gen_map2;
  ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &gen_map2),
              Eq(Status::OK()));

  Subgraph::Identity id2;
  id2.insert(gen_map2.at("node1").get());
  id2.insert(gen_map2.at("node2").get());
  Subgraph sg2(id2);

  Signature sig2;
  sg2.ExtractForSignature(&sig2.map);
  ASSERT_THAT(sig2.Compute(), Eq(Status::OK()));

  // The perturbations below index sig_full[0] and nodes[0..1]. Make sure
  // those elements exist instead of reading out of bounds.
  ASSERT_THAT(sig2.nodes, SizeIs(2));
  ASSERT_THAT(sig2.sig_full, SizeIs(Gt(0u)));

  EXPECT_TRUE(sig1 == sig2);

  // Change the short hash.
  ++sig2.sig_short;
  EXPECT_FALSE(sig1 == sig2);

  // Restore back.
  --sig2.sig_short;
  EXPECT_TRUE(sig1 == sig2);

  // Change the full hash.
  ++sig2.sig_full[0];
  EXPECT_FALSE(sig1 == sig2);

  // Restore back.
  --sig2.sig_full[0];
  EXPECT_TRUE(sig1 == sig2);

  // Make the nodes different.
  std::swap(sig2.nodes[0], sig2.nodes[1]);
  EXPECT_FALSE(sig1 == sig2);

  // Restore back.
  std::swap(sig2.nodes[0], sig2.nodes[1]);
  EXPECT_TRUE(sig1 == sig2);

  // Different number of nodes. Copy the pointer first so that the argument
  // does not refer into the vector while it grows.
  auto* duplicate_node = sig2.nodes[0];
  sig2.nodes.emplace_back(duplicate_node);
  EXPECT_FALSE(sig1 == sig2);
  EXPECT_FALSE(sig2 == sig1);
}
1231
1232 } // end namespace test
1233 } // end namespace graph_analyzer
1234 } // end namespace grappler
1235 } // end namespace tensorflow
1236