1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/hash_utils.h"
17
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/function.pb.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/framework/variant.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/test_benchmark.h"
27 #include "tensorflow/core/protobuf/config.pb.h"
28 #include "tensorflow/core/protobuf/error_codes.pb.h"
29 #include "tensorflow/core/util/work_sharder.h"
30
31 namespace tensorflow {
32 namespace data {
33 namespace {
34 using ::testing::ContainsRegex;
35
36 class DatasetHashUtilsTest : public ::testing::Test {
37 protected:
GetHash(const FunctionDefLibrary & library,const FunctionDef & fn)38 uint64 GetHash(const FunctionDefLibrary& library, const FunctionDef& fn) {
39 // Construct a node with a function as an attr.
40 GraphDef graph_def;
41 *graph_def.mutable_library() = library;
42 NodeDef* node = graph_def.add_node();
43 node->set_op("RemoteCall");
44 NameAttrList func;
45 func.set_name(fn.signature().name());
46 AddNodeAttr("f", func, node);
47 uint64 hash = 0;
48 TF_CHECK_OK(HashNode(graph_def, *node, &hash));
49 return hash;
50 }
51
CheckEqual(const FunctionDefLibrary & library,const FunctionDef & fn1,const FunctionDef & fn2)52 Status CheckEqual(const FunctionDefLibrary& library, const FunctionDef& fn1,
53 const FunctionDef& fn2) {
54 // Construct nodes with a function as an attr.
55 GraphDef graph_def;
56 *graph_def.mutable_library() = library;
57
58 NodeDef* node1 = graph_def.add_node();
59 node1->set_name("RemoteCall");
60 node1->set_op("RemoteCall");
61 NameAttrList func1;
62 func1.set_name(fn1.signature().name());
63 AddNodeAttr("f", func1, node1);
64
65 NodeDef* node2 = graph_def.add_node();
66 node1->set_name("RemoteCall2");
67 node2->set_op("RemoteCall");
68 NameAttrList func2;
69 func2.set_name(fn2.signature().name());
70 AddNodeAttr("f", func2, node2);
71
72 return CheckSubgraphsEqual(graph_def, node1, graph_def, node2);
73 }
74
GetHash(const GraphDef & graph,const NodeDef & node)75 uint64 GetHash(const GraphDef& graph, const NodeDef& node) {
76 uint64 hash = 0;
77 TF_CHECK_OK(HashNode(graph, node, &hash));
78 return hash;
79 }
80
GetHash(const Tensor & tensor)81 uint64 GetHash(const Tensor& tensor) {
82 uint64 hash = 0;
83 TF_CHECK_OK(HashTensor(tensor, &hash));
84 return hash;
85 }
86 };
87
TEST_F(DatasetHashUtilsTest, HashFunctionSameFunctionDifferentNames) {
  // Builds `o = arg * arg`, plus an `arg + arg` node named "add" that is
  // registered as the "must_execute" control return. The function is named
  // `fn_name` and its single float input is named `arg`.
  auto make_add_and_mul = [](const string& fn_name, const string& arg) {
    return FunctionDefHelper::Create(
        fn_name, {absl::StrCat(arg, ": float")}, {"o: float"}, {},
        {{{"add"}, "Add", {arg, arg}, {{"T", DT_FLOAT}}},
         {{"ret"}, "Mul", {arg, arg}, {{"T", DT_FLOAT}}}},
        /*ret_def=*/{{"o", "ret:z:0"}},
        /*control_ret_def=*/{{"must_execute", "add"}});
  };

  FunctionDefLibrary library;
  FunctionDef* original = library.add_function();
  *original = make_add_and_mul("AddAndMul", "i");
  FunctionDef* renamed = library.add_function();
  *renamed = make_add_and_mul("AddAndMul2", "input");

  // Only the function name and the input name differ. Both the hash and the
  // structural comparison are expected to treat the functions as equal.
  const uint64 original_hash = GetHash(library, *original);
  const uint64 renamed_hash = GetHash(library, *renamed);
  EXPECT_EQ(original_hash, renamed_hash);
  TF_EXPECT_OK(CheckEqual(library, *original, *renamed));
}
110
TEST_F(DatasetHashUtilsTest, HashFunctionDifferentFunctions) {
  // Builds a function named `fn_name` whose returned node "ret" applies
  // `ret_op` to (i, i). An "Add" node is attached as the "must_execute"
  // control return.
  auto make_fn = [](const string& fn_name, const string& ret_op) {
    return FunctionDefHelper::Create(
        fn_name, {"i: float"}, {"o: float"}, {},
        {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
         {{"ret"}, ret_op, {"i", "i"}, {{"T", DT_FLOAT}}}},
        /*ret_def=*/{{"o", "ret:z:0"}},
        /*control_ret_def=*/{{"must_execute", "add"}});
  };

  FunctionDefLibrary library;
  FunctionDef* mul_version = library.add_function();
  *mul_version = make_fn("AddAndMul", "Mul");
  FunctionDef* add_version = library.add_function();
  *add_version = make_fn("AddAndAdd", "Add");

  // The returned op differs ("Mul" vs "Add"). The hashes must differ, and the
  // equality check must fail with a message that mentions "Add".
  EXPECT_NE(GetHash(library, *mul_version), GetHash(library, *add_version));
  const Status status = CheckEqual(library, *mul_version, *add_version);
  EXPECT_NE(status.code(), error::OK);
  EXPECT_THAT(status.error_message(), ContainsRegex("Add"));
}
136
TEST_F(DatasetHashUtilsTest, HashFunctionDifferentInternalNodeNames) {
  // Builds `o = (x + y) * z`. `x`, `y` and `z` name the three float inputs.
  // `mul_node` names the multiply node, which is both the returned node and
  // the "must_execute" control return.
  auto make_fn = [](const string& fn_name, const string& x, const string& y,
                    const string& z, const string& mul_node) {
    return FunctionDefHelper::Create(
        fn_name,
        {absl::StrCat(x, ": float"), absl::StrCat(y, ": float"),
         absl::StrCat(z, ": float")},
        {"o: float"}, {},
        {{{"add"}, "Add", {x, y}, {{"T", DT_FLOAT}}},
         {{mul_node}, "Mul", {"add:z:0", z}, {{"T", DT_FLOAT}}}},
        /*ret_def=*/{{"o", absl::StrCat(mul_node, ":z:0")}},
        /*control_ret_def=*/{{"must_execute", mul_node}});
  };

  FunctionDefLibrary library;
  FunctionDef* first = library.add_function();
  *first = make_fn("AddAndMul", "i", "j", "k", "ret");
  FunctionDef* second = library.add_function();
  *second = make_fn("AddAndMul2", "a", "b", "c", "mul");

  // Same computation; only argument names and one internal node name differ.
  const uint64 first_hash = GetHash(library, *first);
  const uint64 second_hash = GetHash(library, *second);
  EXPECT_EQ(first_hash, second_hash);
  TF_EXPECT_OK(CheckEqual(library, *first, *second));
}
159
TEST_F(DatasetHashUtilsTest, HashGraphWithMultipleCycles) {
  // Rebuilds the same cyclic graph many times and requires every hash of the
  // output node to match the hash from the first iteration. Following input
  // edges, the graph has two cycles: A -> B -> C -> A and B -> D -> E -> B.
  //
  // The reference hash is taken on iteration 0 instead of treating 0 as
  // "unset". Otherwise a legitimate hash value of 0 would silently skip a
  // comparison.
  constexpr int kNumIterations = 1000;
  uint64 expected_hash = 0;
  for (int i = 0; i < kNumIterations; ++i) {
    GraphDef g;
    NodeDef* output_node = g.add_node();
    TF_CHECK_OK(NodeDefBuilder("O", "Add")
                    .Input("A", 0, DT_FLOAT)
                    .Input("D", 0, DT_FLOAT)
                    .Finalize(output_node));
    TF_CHECK_OK(NodeDefBuilder("A", "Abs")
                    .Input("B", 0, DT_FLOAT)
                    .Finalize(g.add_node()));
    TF_CHECK_OK(NodeDefBuilder("B", "Add")
                    .Input("C", 0, DT_FLOAT)
                    .Input("D", 0, DT_FLOAT)
                    .Finalize(g.add_node()));
    TF_CHECK_OK(NodeDefBuilder("C", "Ceil")
                    .Input("A", 0, DT_FLOAT)
                    .Finalize(g.add_node()));
    TF_CHECK_OK(NodeDefBuilder("D", "Cos")
                    .Input("E", 0, DT_FLOAT)
                    .Finalize(g.add_node()));
    TF_CHECK_OK(NodeDefBuilder("E", "Floor")
                    .Input("B", 0, DT_FLOAT)
                    .Finalize(g.add_node()));
    const uint64 t = GetHash(g, *output_node);
    if (i == 0) {
      expected_hash = t;
    } else {
      EXPECT_EQ(t, expected_hash);
    }
  }
}
193
TEST_F(DatasetHashUtilsTest, HashNodeSameGraphDifferentNames) {
  GraphDef gd;

  // Appends a Const node on CPU:0 named `name` holding `value`.
  auto add_const = [&gd](const string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };
  // Appends an Add node on CPU:0 named `name`. It reads output 0 of `lhs`
  // and `rhs` as DT_INT32.
  auto add_sum = [&gd](const string& name, const NodeDef* lhs,
                       const NodeDef* rhs) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Add")
                    .Device("CPU:0")
                    .Input(lhs->name(), 0, DT_INT32)
                    .Input(rhs->name(), 0, DT_INT32)
                    .Finalize(node));
    return node;
  };

  // First subgraph.
  NodeDef* lhs_a = add_const("graph_1/node_1", 1);
  NodeDef* rhs_a = add_const("graph_1/node_2", 2);
  NodeDef* sum_a = add_sum("graph_1/node_3", lhs_a, rhs_a);

  // Structurally identical subgraph where every node has a different name.
  NodeDef* lhs_b = add_const("graph_3/node_7", 1);
  NodeDef* rhs_b = add_const("graph_4/node_9", 2);
  NodeDef* sum_b = add_sum("graph_5/node_11", lhs_b, rhs_b);

  const uint64 hash_a = GetHash(gd, *sum_a);
  const uint64 hash_b = GetHash(gd, *sum_b);
  EXPECT_EQ(hash_a, hash_b);
  TF_EXPECT_OK(CheckSubgraphsEqual(gd, sum_a, gd, sum_b));
}
240
TEST_F(DatasetHashUtilsTest, HashNodeDifferentGraphs) {
  GraphDef gd;

  // Appends a Const node on CPU:0 named `name` holding `value`.
  auto add_const = [&gd](const string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };
  // Appends a binary `op` node on CPU:0 named `name`. It reads output 0 of
  // `lhs` and `rhs` as DT_INT32.
  auto add_binary = [&gd](const string& name, const string& op,
                          const NodeDef* lhs, const NodeDef* rhs) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, op)
                    .Device("CPU:0")
                    .Input(lhs->name(), 0, DT_INT32)
                    .Input(rhs->name(), 0, DT_INT32)
                    .Finalize(node));
    return node;
  };

  NodeDef* one = add_const("graph_1/node_1", 1);
  NodeDef* two = add_const("graph_1/node_2", 2);
  NodeDef* sum = add_binary("graph_1/node_3", "Add", one, two);
  NodeDef* product = add_binary("graph_1/node_4", "Mul", one, two);

  // Same inputs but a different op. The hashes must differ, and the mismatch
  // report must name both ops.
  const uint64 sum_hash = GetHash(gd, *sum);
  const uint64 product_hash = GetHash(gd, *product);
  EXPECT_NE(sum_hash, product_hash);
  const Status status = CheckSubgraphsEqual(gd, sum, gd, product);
  EXPECT_NE(status.code(), error::OK);
  EXPECT_THAT(status.error_message(), ContainsRegex("Add"));
  EXPECT_THAT(status.error_message(), ContainsRegex("Mul"));
}
279
TEST_F(DatasetHashUtilsTest, HashSameGraphDifferentSeeds) {
  GraphDef gd;

  // Appends a Const node on CPU:0 named `name` holding `value`.
  auto add_const = [&gd](const string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };
  // Appends a RangeDataset node whose three DT_INT64 inputs all read output 0
  // of `arg`.
  auto add_range = [&gd](const string& name, const NodeDef* arg) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "RangeDataset")
                    .Input(arg->name(), 0, DT_INT64)
                    .Input(arg->name(), 0, DT_INT64)
                    .Input(arg->name(), 0, DT_INT64)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };
  // Appends a ShuffleDataset node over `dataset`, followed by DT_INT64
  // inputs from `arg`, `seed` and `seed2`.
  auto add_shuffle = [&gd](const string& name, const NodeDef* dataset,
                           const NodeDef* arg, const NodeDef* seed,
                           const NodeDef* seed2) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "ShuffleDataset")
                    .Input(dataset->name(), 0, DT_VARIANT)
                    .Input(arg->name(), 0, DT_INT64)
                    .Input(seed->name(), 0, DT_INT64)
                    .Input(seed2->name(), 0, DT_INT64)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };

  NodeDef* one = add_const("graph_1/node_1", 1);
  NodeDef* seed = add_const("graph_1/seed", 123);
  NodeDef* seed2 = add_const("graph_1/seed2", 456);
  NodeDef* range = add_range("graph_1/range", one);
  NodeDef* shuffle = add_shuffle("graph_1/shuffle", range, one, seed, seed2);

  NodeDef* other_seed = add_const("graph_1/different_seed", 789);
  NodeDef* other_seed2 = add_const("graph_1/different_seed2", 654);
  NodeDef* other_range = add_range("graph_1/range_2", one);
  NodeDef* other_shuffle = add_shuffle("graph_1/shuffle_2", other_range, one,
                                       other_seed, other_seed2);

  // The two pipelines differ only in their seed constants. The test expects
  // both hashing and the equality check to disregard that difference.
  const uint64 shuffle_hash = GetHash(gd, *shuffle);
  const uint64 other_shuffle_hash = GetHash(gd, *other_shuffle);
  EXPECT_EQ(shuffle_hash, other_shuffle_hash);
  TF_EXPECT_OK(CheckSubgraphsEqual(gd, shuffle, gd, other_shuffle));
}
351
TEST_F(DatasetHashUtilsTest, HashNodeSameGraphDifferentColocationNames) {
  GraphDef gd;

  // First subgraph: node_3 = Add(node_1, node_2). node_1 carries a
  // colocation (`_class`) attr.
  NodeDef* n1 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Attr("_class", {"graph_1/node_2"})
                  .Device("CPU:0")
                  .Finalize(n1));

  NodeDef* n2 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "Const")
                  .Attr("value", 2)
                  .Device("CPU:0")
                  .Finalize(n2));

  NodeDef* n3 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_3", "Add")
                  .Device("CPU:0")
                  .Input(n1->name(), 0, DT_INT32)
                  .Input(n2->name(), 0, DT_INT32)
                  .Finalize(n3));

  // Second subgraph: same structure, but with different node names and a
  // different `_class` value on its first Const.
  NodeDef* n4 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_3/node_7", "Const")
                  .Attr("value", 1)
                  .Attr("_class", {"graph_3/node_9"})
                  .Device("CPU:0")
                  .Finalize(n4));

  NodeDef* n5 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_4/node_9", "Const")
                  .Attr("value", 2)
                  .Device("CPU:0")
                  .Finalize(n5));

  // n6 must consume n4/n5, not n1/n2. Otherwise both Add nodes share the
  // exact same inputs and the differing colocation attrs are never compared.
  NodeDef* n6 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_5/node_11", "Add")
                  .Device("CPU:0")
                  .Input(n4->name(), 0, DT_INT32)
                  .Input(n5->name(), 0, DT_INT32)
                  .Finalize(n6));

  uint64 hash1 = GetHash(gd, *n3);
  uint64 hash2 = GetHash(gd, *n6);

  // The test expects colocation names to be ignored by both hashing and the
  // equality check.
  EXPECT_EQ(hash1, hash2);
  TF_EXPECT_OK(CheckSubgraphsEqual(gd, n3, gd, n6));
}
401
TEST_F(DatasetHashUtilsTest, HashNodeReversedOrder) {
  GraphDef gd;

  // Appends a Const node on CPU:0 named `name` holding `value`.
  auto add_const = [&gd](const string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };
  // Appends an Add node on CPU:0 named `name` with inputs (first, second).
  auto add_sum = [&gd](const string& name, const NodeDef* first,
                       const NodeDef* second) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Add")
                    .Device("CPU:0")
                    .Input(first->name(), 0, DT_INT32)
                    .Input(second->name(), 0, DT_INT32)
                    .Finalize(node));
    return node;
  };

  NodeDef* one = add_const("graph_1/node_1", 1);
  NodeDef* two = add_const("graph_1/node_2", 2);
  NodeDef* forward = add_sum("graph_1/node_3", one, two);
  NodeDef* reversed = add_sum("graph_1/node_4", two, one);

  // `reversed` (node_4) takes the same inputs as `forward` (node_3) in
  // swapped order, so the hashes must differ. The test expects the mismatch
  // to be reported as differing attr values.
  const uint64 forward_hash = GetHash(gd, *forward);
  const uint64 reversed_hash = GetHash(gd, *reversed);
  EXPECT_NE(forward_hash, reversed_hash);
  const Status status = CheckSubgraphsEqual(gd, forward, gd, reversed);
  EXPECT_NE(status.code(), error::OK);
  EXPECT_THAT(status.error_message(),
              ContainsRegex("AttrValues are different"));
}
439
TEST_F(DatasetHashUtilsTest, HashNodeInputPortChanged) {
  GraphDef gd;

  // Appends a Const node on CPU:0 named `name` holding `value`.
  auto add_const = [&gd](const string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };
  // Appends an Add node on CPU:0 named `name`. It reads output `lhs_port` of
  // `lhs` and output `rhs_port` of `rhs`.
  auto add_sum = [&gd](const string& name, const NodeDef* lhs, int lhs_port,
                       const NodeDef* rhs, int rhs_port) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Add")
                    .Device("CPU:0")
                    .Input(lhs->name(), lhs_port, DT_INT32)
                    .Input(rhs->name(), rhs_port, DT_INT32)
                    .Finalize(node));
    return node;
  };

  NodeDef* one = add_const("graph_1/node_1", 1);
  NodeDef* two = add_const("graph_1/node_2", 2);
  NodeDef* default_ports = add_sum("graph_1/node_3", one, 0, two, 0);
  NodeDef* shifted_ports = add_sum("graph_1/node_4", one, 1, two, 2);

  // Same producer nodes, but different output ports. The hashes must differ
  // and the mismatch must be reported on the node inputs.
  const uint64 default_hash = GetHash(gd, *default_ports);
  const uint64 shifted_hash = GetHash(gd, *shifted_ports);
  EXPECT_NE(default_hash, shifted_hash);
  const Status status =
      CheckSubgraphsEqual(gd, default_ports, gd, shifted_ports);
  EXPECT_NE(status.code(), error::OK);
  EXPECT_THAT(status.error_message(), ContainsRegex("Node inputs"));
}
478
TEST_F(DatasetHashUtilsTest, HashNodeSameFunctionDifferentNames) {
  GraphDef gd;

  // Builds `o = arg * arg`, plus an `arg + arg` node named "add" that is
  // registered as the "must_execute" control return.
  auto make_add_and_mul = [](const string& fn_name, const string& arg) {
    return FunctionDefHelper::Create(
        fn_name, {absl::StrCat(arg, ": float")}, {"o: float"}, {},
        {{{"add"}, "Add", {arg, arg}, {{"T", DT_FLOAT}}},
         {{"ret"}, "Mul", {arg, arg}, {{"T", DT_FLOAT}}}},
        /*ret_def=*/{{"o", "ret:z:0"}},
        /*control_ret_def=*/{{"must_execute", "add"}});
  };
  FunctionDefLibrary* library = gd.mutable_library();
  *library->add_function() = make_add_and_mul("AddAndMul", "i");
  *library->add_function() = make_add_and_mul("AddAndMul2", "input");

  NodeDef* source = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Device("CPU:0")
                  .Finalize(source));

  std::vector<NodeDefBuilder::NodeOut> loop_args;
  loop_args.emplace_back(source->name(), 0, DT_FLOAT);
  loop_args.emplace_back(source->name(), 0, DT_FLOAT);

  // Appends a "For" node on CPU:0 whose `body` attr is a single function
  // reference to `body_fn`.
  auto add_for = [&gd, &loop_args, source](const string& name,
                                           const string& body_fn) {
    AttrValue body;
    body.mutable_func()->set_name(body_fn);
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "For")
                    .Input(source->name(), 0, DT_INT32)
                    .Input(source->name(), 0, DT_INT32)
                    .Input(source->name(), 0, DT_INT32)
                    .Input(loop_args)
                    .Attr("body", body)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };

  NodeDef* calls_original = add_for("graph_1/node_2", "AddAndMul");
  NodeDef* calls_renamed = add_for("graph_1/node_3", "AddAndMul2");

  // The bodies differ only in function and argument names.
  const uint64 original_hash = GetHash(gd, *calls_original);
  const uint64 renamed_hash = GetHash(gd, *calls_renamed);
  EXPECT_EQ(original_hash, renamed_hash);
  TF_EXPECT_OK(CheckSubgraphsEqual(gd, calls_original, gd, calls_renamed));
}
542
TEST_F(DatasetHashUtilsTest, HashNodeSameFunctionListsDifferentNames) {
  GraphDef gd;

  // Builds `o = arg * arg`, plus an `arg + arg` node named "add" that is
  // registered as the "must_execute" control return.
  auto make_add_and_mul = [](const string& fn_name, const string& arg) {
    return FunctionDefHelper::Create(
        fn_name, {absl::StrCat(arg, ": float")}, {"o: float"}, {},
        {{{"add"}, "Add", {arg, arg}, {{"T", DT_FLOAT}}},
         {{"ret"}, "Mul", {arg, arg}, {{"T", DT_FLOAT}}}},
        /*ret_def=*/{{"o", "ret:z:0"}},
        /*control_ret_def=*/{{"must_execute", "add"}});
  };
  FunctionDefLibrary* library = gd.mutable_library();
  *library->add_function() = make_add_and_mul("AddAndMul", "i");
  *library->add_function() = make_add_and_mul("AddAndMul2", "input");

  NodeDef* source = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Device("CPU:0")
                  .Finalize(source));

  std::vector<NodeDefBuilder::NodeOut> loop_args;
  loop_args.emplace_back(source->name(), 0, DT_FLOAT);
  loop_args.emplace_back(source->name(), 0, DT_FLOAT);

  // Appends a "For" node on CPU:0 whose `body` attr is a one-element
  // function *list* naming `body_fn`.
  auto add_for = [&gd, &loop_args, source](const string& name,
                                           const string& body_fn) {
    AttrValue body;
    body.mutable_list()->add_func()->set_name(body_fn);
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "For")
                    .Input(source->name(), 0, DT_INT32)
                    .Input(source->name(), 0, DT_INT32)
                    .Input(source->name(), 0, DT_INT32)
                    .Input(loop_args)
                    .Attr("body", body)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };

  NodeDef* calls_original = add_for("graph_1/node_2", "AddAndMul");
  NodeDef* calls_renamed = add_for("graph_1/node_3", "AddAndMul2");

  const uint64 original_hash = GetHash(gd, *calls_original);
  const uint64 renamed_hash = GetHash(gd, *calls_renamed);
  EXPECT_EQ(original_hash, renamed_hash);
  TF_EXPECT_OK(CheckSubgraphsEqual(gd, calls_original, gd, calls_renamed));
}
608
TEST_F(DatasetHashUtilsTest, HashNodeSameFunctionsOps) {
  GraphDef gd;

  // Builds `o = i * i`, plus an `i + i` node named "add" that is registered
  // as the "must_execute" control return, under the name `fn_name`.
  auto make_fn = [](const string& fn_name) {
    return FunctionDefHelper::Create(
        fn_name, {"i: float"}, {"o: float"}, {},
        {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
         {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
        /*ret_def=*/{{"o", "ret:z:0"}},
        /*control_ret_def=*/{{"must_execute", "add"}});
  };
  FunctionDefLibrary* library = gd.mutable_library();
  *library->add_function() = make_fn("AddAndMul");
  *library->add_function() = make_fn("AddAndMul2");
  // Built after both functions are in the library, so NodeDefBuilder can
  // resolve their names as op types.
  FunctionLibraryDefinition flib(OpRegistry::Global(), gd.library());

  NodeDef* source = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Device("CPU:0")
                  .Finalize(source));

  // Appends a node on CPU:0 named `name` whose op is the library function
  // `fn_name`. It is fed by output 0 of `source` as DT_FLOAT.
  auto add_call = [&gd, &flib, source](const string& name,
                                       const string& fn_name) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, fn_name, &flib)
                    .Input(source->name(), 0, DT_FLOAT)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };

  NodeDef* first_call = add_call("graph_1/node_2", "AddAndMul");
  NodeDef* second_call = add_call("graph_1/node_3", "AddAndMul2");

  const uint64 first_hash = GetHash(gd, *first_call);
  const uint64 second_hash = GetHash(gd, *second_call);
  EXPECT_EQ(first_hash, second_hash);
  TF_EXPECT_OK(CheckSubgraphsEqual(gd, first_call, gd, second_call));
}
656
TEST_F(DatasetHashUtilsTest, HashNodeDifferentFunctionsOps) {
  GraphDef gd;

  // Builds `o = i * i` with an `i + i` node. `must_execute` names the node
  // registered as the "must_execute" control return.
  auto make_fn = [](const string& fn_name, const string& must_execute) {
    return FunctionDefHelper::Create(
        fn_name, {"i: float"}, {"o: float"}, {},
        {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
         {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
        /*ret_def=*/{{"o", "ret:z:0"}},
        /*control_ret_def=*/{{"must_execute", must_execute}});
  };
  FunctionDefLibrary* library = gd.mutable_library();
  *library->add_function() = make_fn("AddAndMul", "add");
  *library->add_function() = make_fn("AddAndMul2", "ret");
  // Built after both functions are in the library, so NodeDefBuilder can
  // resolve their names as op types.
  FunctionLibraryDefinition flib(OpRegistry::Global(), gd.library());

  NodeDef* source = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Device("CPU:0")
                  .Finalize(source));

  // Appends a node on CPU:0 named `name` whose op is the library function
  // `fn_name`. It is fed by output 0 of `source` as DT_FLOAT.
  auto add_call = [&gd, &flib, source](const string& name,
                                       const string& fn_name) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, fn_name, &flib)
                    .Input(source->name(), 0, DT_FLOAT)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };

  NodeDef* first_call = add_call("graph_1/node_2", "AddAndMul");
  NodeDef* second_call = add_call("graph_1/node_3", "AddAndMul2");

  // The functions differ in their control return, so they must not compare
  // equal.
  const uint64 first_hash = GetHash(gd, *first_call);
  const uint64 second_hash = GetHash(gd, *second_call);
  EXPECT_NE(first_hash, second_hash);
  const Status status = CheckSubgraphsEqual(gd, first_call, gd, second_call);
  EXPECT_NE(status.code(), error::OK);
  EXPECT_THAT(
      status.error_message(),
      ContainsRegex("Functions AddAndMul and AddAndMul2 are not the same"));
}
708
TEST_F(DatasetHashUtilsTest, HashNodeDifferentFunctions) {
  GraphDef gd;

  // Builds `o = i * i` with an `i + i` node. `must_execute` names the node
  // registered as the "must_execute" control return.
  auto make_fn = [](const string& fn_name, const string& must_execute) {
    return FunctionDefHelper::Create(
        fn_name, {"i: float"}, {"o: float"}, {},
        {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
         {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
        /*ret_def=*/{{"o", "ret:z:0"}},
        /*control_ret_def=*/{{"must_execute", must_execute}});
  };
  FunctionDefLibrary* library = gd.mutable_library();
  *library->add_function() = make_fn("AddAndMul", "add");
  *library->add_function() = make_fn("AddAndMul2", "ret");

  NodeDef* source = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Device("CPU:0")
                  .Finalize(source));

  std::vector<NodeDefBuilder::NodeOut> loop_args;
  loop_args.emplace_back(source->name(), 0, DT_FLOAT);
  loop_args.emplace_back(source->name(), 0, DT_FLOAT);

  // Appends a "For" node on CPU:0 whose `body` attr is a single function
  // reference to `body_fn`.
  auto add_for = [&gd, &loop_args, source](const string& name,
                                           const string& body_fn) {
    AttrValue body;
    body.mutable_func()->set_name(body_fn);
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "For")
                    .Input(source->name(), 0, DT_INT32)
                    .Input(source->name(), 0, DT_INT32)
                    .Input(source->name(), 0, DT_INT32)
                    .Input(loop_args)
                    .Attr("body", body)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };

  NodeDef* first_loop = add_for("graph_1/node_2", "AddAndMul");
  NodeDef* second_loop = add_for("graph_1/node_3", "AddAndMul2");

  // The loop bodies differ in their control return.
  const uint64 first_hash = GetHash(gd, *first_loop);
  const uint64 second_hash = GetHash(gd, *second_loop);
  EXPECT_NE(first_hash, second_hash);
  const Status status = CheckSubgraphsEqual(gd, first_loop, gd, second_loop);
  EXPECT_NE(status.code(), error::OK);
  EXPECT_THAT(
      status.error_message(),
      ContainsRegex("Functions AddAndMul and AddAndMul2 are not the same"));
}
779
TEST_F(DatasetHashUtilsTest, HashNodeDifferentFunctionLists) {
  GraphDef gd;

  // Builds `o = i * i` with an `i + i` node. `must_execute` names the node
  // registered as the "must_execute" control return.
  auto make_fn = [](const string& fn_name, const string& must_execute) {
    return FunctionDefHelper::Create(
        fn_name, {"i: float"}, {"o: float"}, {},
        {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
         {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
        /*ret_def=*/{{"o", "ret:z:0"}},
        /*control_ret_def=*/{{"must_execute", must_execute}});
  };
  FunctionDefLibrary* library = gd.mutable_library();
  *library->add_function() = make_fn("AddAndMul", "add");
  *library->add_function() = make_fn("AddAndMul2", "ret");

  NodeDef* source = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Device("CPU:0")
                  .Finalize(source));

  std::vector<NodeDefBuilder::NodeOut> loop_args;
  loop_args.emplace_back(source->name(), 0, DT_FLOAT);
  loop_args.emplace_back(source->name(), 0, DT_FLOAT);

  // Appends a "For" node on CPU:0 whose `body` attr is a one-element
  // function *list* naming `body_fn`.
  auto add_for = [&gd, &loop_args, source](const string& name,
                                           const string& body_fn) {
    AttrValue body;
    body.mutable_list()->add_func()->set_name(body_fn);
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "For")
                    .Input(source->name(), 0, DT_INT32)
                    .Input(source->name(), 0, DT_INT32)
                    .Input(source->name(), 0, DT_INT32)
                    .Input(loop_args)
                    .Attr("body", body)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };

  NodeDef* first_loop = add_for("graph_1/node_2", "AddAndMul");
  NodeDef* second_loop = add_for("graph_1/node_3", "AddAndMul2");

  const uint64 first_hash = GetHash(gd, *first_loop);
  const uint64 second_hash = GetHash(gd, *second_loop);
  EXPECT_NE(first_hash, second_hash);
  const Status status = CheckSubgraphsEqual(gd, first_loop, gd, second_loop);
  EXPECT_NE(status.code(), error::OK);
  EXPECT_THAT(
      status.error_message(),
      ContainsRegex("Functions AddAndMul and AddAndMul2 are not the same"));
}
852
TEST_F(DatasetHashUtilsTest, HashNodeDifferentControlInputs) {
  GraphDef gd;

  // Appends a Const node on CPU:0 named `name` holding `value`.
  auto add_const = [&gd](const string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };
  // Appends an Identity node on CPU:0 reading output 0 of `data` as
  // DT_INT32, with one control input per entry of `deps`, in order.
  auto add_identity = [&gd](const string& name, const NodeDef* data,
                            std::initializer_list<const NodeDef*> deps) {
    NodeDef* node = gd.add_node();
    NodeDefBuilder builder(name, "Identity");
    builder.Device("CPU:0").Input(data->name(), 0, DT_INT32);
    for (const NodeDef* dep : deps) {
      builder.ControlInput(dep->name());
    }
    TF_CHECK_OK(builder.Finalize(node));
    return node;
  };

  NodeDef* data = add_const("graph_1/node_1", 1);
  NodeDef* dep_a = add_const("graph_1/node_2", 2);
  NodeDef* dep_b = add_const("graph_1/node_3", 10);
  NodeDef* guarded_by_a = add_identity("graph_1/node_4", data, {dep_a});
  NodeDef* guarded_by_b = add_identity("graph_1/node_5", data, {dep_b});

  // The control inputs point at different Const nodes.
  const uint64 hash_a = GetHash(gd, *guarded_by_a);
  const uint64 hash_b = GetHash(gd, *guarded_by_b);
  EXPECT_NE(hash_a, hash_b);
  const Status status = CheckSubgraphsEqual(gd, guarded_by_a, gd, guarded_by_b);
  EXPECT_NE(status.code(), error::OK);
  EXPECT_THAT(status.error_message(),
              ContainsRegex("Control dependencies are different"));
}
897
TEST_F(DatasetHashUtilsTest, HashNodeControlInputDifferentOrdering) {
  GraphDef gd;

  // Appends a Const node on CPU:0 named `name` holding `value`.
  auto add_const = [&gd](const string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };
  // Appends an Identity node on CPU:0 reading output 0 of `data` as
  // DT_INT32, with one control input per entry of `deps`, in order.
  auto add_identity = [&gd](const string& name, const NodeDef* data,
                            std::initializer_list<const NodeDef*> deps) {
    NodeDef* node = gd.add_node();
    NodeDefBuilder builder(name, "Identity");
    builder.Device("CPU:0").Input(data->name(), 0, DT_INT32);
    for (const NodeDef* dep : deps) {
      builder.ControlInput(dep->name());
    }
    TF_CHECK_OK(builder.Finalize(node));
    return node;
  };

  NodeDef* data = add_const("graph_1/node_1", 1);
  NodeDef* dep_a = add_const("graph_1/node_2", 2);
  NodeDef* dep_b = add_const("graph_1/node_3", 10);
  NodeDef* a_then_b = add_identity("graph_1/node_4", data, {dep_a, dep_b});
  NodeDef* b_then_a = add_identity("graph_1/node_5", data, {dep_b, dep_a});

  // Same set of control inputs listed in a different order. The test expects
  // the order to be irrelevant.
  const uint64 ab_hash = GetHash(gd, *a_then_b);
  const uint64 ba_hash = GetHash(gd, *b_then_a);
  EXPECT_EQ(ab_hash, ba_hash);
  TF_EXPECT_OK(CheckSubgraphsEqual(gd, a_then_b, gd, b_then_a));
}
940
TEST_F(DatasetHashUtilsTest, HashNodeDifferentGraphSamePartialGraph) {
  GraphDef gd;

  NodeDef* one = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Device("CPU:0")
                  .Finalize(one));

  NodeDef* two = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "Const")
                  .Attr("value", 2)
                  .Device("CPU:0")
                  .Finalize(two));

  NodeDef* consumer = gd.add_node();

  // (Re)defines `consumer` in place as `op(one, two)` under a fixed name.
  auto rebuild_consumer = [consumer, one, two](const string& op) {
    consumer->Clear();
    TF_CHECK_OK(NodeDefBuilder("graph_1/node_3", op)
                    .Device("CPU:0")
                    .Input(one->name(), 0, DT_INT32)
                    .Input(two->name(), 0, DT_INT32)
                    .Finalize(consumer));
  };

  rebuild_consumer("Add");
  const uint64 hash_with_add = GetHash(gd, *one);
  rebuild_consumer("Mul");
  const uint64 hash_with_mul = GetHash(gd, *one);

  // Only a node that *consumes* node_1 changed. The test expects node_1's
  // hash to be unaffected.
  EXPECT_EQ(hash_with_add, hash_with_mul);
}
977
// Adds 1000 Const nodes, where node i has a control input on every earlier
// node, and then hashes the last one. The test makes no assertions; it passes
// as long as hashing completes without timing out.
TEST_F(DatasetHashUtilsTest, HashNodeWithManyControlDependencies) {
  constexpr int kNumNodes = 1000;
  GraphDef gd;
  NodeDef* last = nullptr;

  for (int node_idx = 0; node_idx < kNumNodes; ++node_idx) {
    last = gd.add_node();
    NodeDefBuilder builder(absl::StrCat("graph_1/node_", node_idx), "Const");
    builder.Attr("value", 1);
    builder.Device("CPU:0");
    // Every previously created node becomes a control input of this one.
    for (int dep_idx = 0; dep_idx < node_idx; ++dep_idx) {
      builder.ControlInput(absl::StrCat("graph_1/node_", dep_idx));
    }
    TF_CHECK_OK(builder.Finalize(last));
  }

  GetHash(gd, *last);
}
996
// Builds a function "AddAndMul" whose body holds a "For" node whose "body" attr
// names "AddAndMul" again, so the function refers to itself. The graph's "For"
// node also lists itself as a control input. The test makes no assertions; it
// passes as long as hashing terminates without timing out or exhausting the
// stack.
TEST_F(DatasetHashUtilsTest, HashFunctionsWithControlDependencyLoop) {
  GraphDef gd;

  // Function reference used both inside AddAndMul and on the graph node.
  AttrValue body_attr;
  body_attr.mutable_func()->set_name("AddAndMul");

  std::pair<string, FunctionDefHelper::AttrValueWrapper> func_attr = {
      "body", FunctionDefHelper::AttrValueWrapper(body_attr.func())};

  *gd.mutable_library()->add_function() = FunctionDefHelper::Create(
      /*function_name=*/"AddAndMul",
      /*in_def=*/{"i: float", "j: int32"},
      /*out_def=*/{"o: float"},
      /*attr_def=*/{},
      /*node_def=*/
      {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}, {"ret"}},
       // This creates a dependency on the same function.
       {{"for"}, "For", {"j", "j", "j"}, {func_attr, {"T", DT_FLOAT}}, {"ret"}},
       {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "ret:z:0"}},
      /*control_ret_def=*/{{"must_execute", "add"}});

  NodeDef* source = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Device("CPU:0")
                  .Finalize(source));

  // Two identical float inputs, both read from the Const node.
  std::vector<NodeDefBuilder::NodeOut> loop_inputs(
      2, NodeDefBuilder::NodeOut(source->name(), 0, DT_FLOAT));

  NodeDef* loop = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "For")
                  .Input(source->name(), 0, DT_INT32)
                  .Input(source->name(), 0, DT_INT32)
                  .Input(source->name(), 0, DT_INT32)
                  .Input(loop_inputs)
                  .ControlInput("graph_1/node_2")  // The node depends on itself.
                  .Attr("body", body_attr)
                  .Device("CPU:0")
                  .Finalize(loop));

  GetHash(gd, *loop);
}
1049
// node_1 and node_2 are Const nodes that name each other as control inputs,
// which forms a cycle. node_3 adds them and also depends on both through
// control edges. The test makes no assertions; it passes as long as hashing
// node_3 terminates without timing out or exhausting the stack.
TEST_F(DatasetHashUtilsTest, HashNodeWithControlDependencyLoop) {
  GraphDef gd;

  // Appends a CPU-placed Const node carrying a single control input.
  auto add_const_with_dep = [&gd](const char* name, int value,
                                  const char* control_dep) {
    NodeDef* const_node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .ControlInput(control_dep)
                    .Finalize(const_node));
    return const_node;
  };

  NodeDef* first = add_const_with_dep("graph_1/node_1", 1, "graph_1/node_2");
  NodeDef* second = add_const_with_dep("graph_1/node_2", 2, "graph_1/node_1");

  NodeDef* sum = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_3", "Add")
                  .Device("CPU:0")
                  .Input(first->name(), 0, DT_INT32)
                  .Input(second->name(), 0, DT_INT32)
                  .ControlInput("graph_1/node_1")
                  .ControlInput("graph_1/node_2")
                  .Finalize(sum));

  GetHash(gd, *sum);
}
1080
// Builds the same control-dependency cycle in two separate graphs under
// different node names. Expects the two Add nodes to hash identically.
TEST_F(DatasetHashUtilsTest, HashNodeWithControlDependencyLoopDifferentNames) {
  // Populates `graph` with two Const nodes (`first`, `second`) that list each
  // other as control inputs, plus an Add node `sum` that consumes both as data
  // and control inputs. Returns the Add node.
  auto build_cycle = [](GraphDef* graph, const char* first, const char* second,
                        const char* sum) {
    NodeDef* lhs = graph->add_node();
    TF_CHECK_OK(NodeDefBuilder(first, "Const")
                    .Attr("value", 1)
                    .Device("CPU:0")
                    .ControlInput(second)
                    .Finalize(lhs));

    NodeDef* rhs = graph->add_node();
    TF_CHECK_OK(NodeDefBuilder(second, "Const")
                    .Attr("value", 2)
                    .Device("CPU:0")
                    .ControlInput(first)
                    .Finalize(rhs));

    NodeDef* adder = graph->add_node();
    TF_CHECK_OK(NodeDefBuilder(sum, "Add")
                    .Device("CPU:0")
                    .Input(lhs->name(), 0, DT_INT32)
                    .Input(rhs->name(), 0, DT_INT32)
                    .ControlInput(first)
                    .ControlInput(second)
                    .Finalize(adder));
    return adder;
  };

  GraphDef gd1;
  const NodeDef* root1 =
      build_cycle(&gd1, "graph_1/node_1", "graph_1/node_2", "graph_1/node_3");

  GraphDef gd2;
  const NodeDef* root2 =
      build_cycle(&gd2, "graph_1/node_4", "graph_1/node_5", "graph_1/node_6");

  EXPECT_EQ(GetHash(gd1, *root1), GetHash(gd2, *root2));
}
1134
// Covers int32 scalar and vector tensors. Tensors with identical contents must
// hash identically, and the differing contents used below must hash
// differently.
TEST_F(DatasetHashUtilsTest, HashInt32Tensor) {
  // Builds the 2-element int32 vector tensor [first, second].
  auto make_vector = [](int32 first, int32 second) {
    Tensor t(DT_INT32, TensorShape({2}));
    auto values = t.vec<int32>();
    values(0) = first;
    values(1) = second;
    return t;
  };

  Tensor scalar(42);
  Tensor same_scalar(42);
  Tensor other_scalar(43);
  EXPECT_EQ(GetHash(scalar), GetHash(same_scalar));
  EXPECT_NE(GetHash(scalar), GetHash(other_scalar));

  Tensor vec = make_vector(0, 1);
  Tensor same_vec = make_vector(0, 1);
  Tensor other_vec = make_vector(0, 2);
  EXPECT_EQ(GetHash(vec), GetHash(same_vec));
  EXPECT_NE(GetHash(vec), GetHash(other_vec));
}
1156
// Covers string scalar and vector tensors. Tensors with identical contents must
// hash identically, and the differing contents used below must hash
// differently.
TEST_F(DatasetHashUtilsTest, HashStringTensor) {
  // Builds the 2-element string vector tensor [first, second].
  auto make_vector = [](const char* first, const char* second) {
    Tensor t(DT_STRING, TensorShape({2}));
    auto values = t.vec<tstring>();
    values(0) = first;
    values(1) = second;
    return t;
  };

  Tensor scalar("hello");
  Tensor same_scalar("hello");
  Tensor other_scalar("world");
  EXPECT_EQ(GetHash(scalar), GetHash(same_scalar));
  EXPECT_NE(GetHash(scalar), GetHash(other_scalar));

  Tensor vec = make_vector("hello", "world");
  Tensor same_vec = make_vector("hello", "world");
  Tensor other_vec = make_vector("hello", "universe");
  EXPECT_EQ(GetHash(vec), GetHash(same_vec));
  EXPECT_NE(GetHash(vec), GetHash(other_vec));
}
1178
1179 // Benchmark that simulates a shallow and wide graph.
BM_ParallelFunctionCallsGraph(benchmark::State & state)1180 static void BM_ParallelFunctionCallsGraph(benchmark::State& state) {
1181 GraphDef graph_def;
1182 FunctionDefLibrary* fl = graph_def.mutable_library();
1183
1184 FunctionDef* fd = fl->add_function();
1185 *fd = FunctionDefHelper::Create(
1186 "AddAndMul", {"i: float"}, {"o: float"}, {},
1187 {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
1188 {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
1189 /*ret_def=*/{{"o", "ret:z:0"}},
1190 /*control_ret_def=*/{{"must_execute", "add"}});
1191
1192 NodeDef* input = graph_def.add_node();
1193 input->set_name("InputPlaceholder");
1194 input->set_op("Placeholder");
1195 AddNodeAttr("dtype", DT_FLOAT, input);
1196
1197 // Equivalent of a `tf.group()`.
1198 NodeDef* target = graph_def.add_node();
1199 target->set_name("Target");
1200 target->set_op("NoOp");
1201
1202 // Create 100 parallel PartitionedCalls that all depend on input. Generate a
1203 // NodeDef that has close to similar attributes that TensorFlow will generate.
1204 ConfigProto config_pb;
1205 config_pb.mutable_device_count()->insert({"CPU", 1});
1206 config_pb.mutable_device_count()->insert({"GPU", 1});
1207 config_pb.set_allow_soft_placement(true);
1208 for (int i = 0; i < 100; ++i) {
1209 NodeDef* node = graph_def.add_node();
1210 node->set_name(absl::StrCat("PartitionedCall_", i));
1211 node->set_op("PartitionedCall");
1212 *node->add_input() = input->name();
1213 AddNodeAttr("Tin", DT_FLOAT, node);
1214 AddNodeAttr("Tout", DT_FLOAT, node);
1215 AddNodeAttr("config", "", node);
1216 AddNodeAttr("config_proto", config_pb.SerializeAsString(), node);
1217 NameAttrList func;
1218 func.set_name(fd->signature().name());
1219 AddNodeAttr("f", func, node);
1220 *target->add_input() = absl::StrCat("^", node->name());
1221 }
1222
1223 uint64 hash_value;
1224 for (auto _ : state) {
1225 CHECK(HashNode(graph_def, *target, &hash_value).ok());
1226 }
1227 }
1228 BENCHMARK(BM_ParallelFunctionCallsGraph);
1229
1230 // Benchmark that simulates a narrow and deep graph.
BM_ChainedFunctionCallsGraph(benchmark::State & state)1231 static void BM_ChainedFunctionCallsGraph(benchmark::State& state) {
1232 GraphDef graph_def;
1233 FunctionDefLibrary* fl = graph_def.mutable_library();
1234
1235 FunctionDef* fd = fl->add_function();
1236 *fd = FunctionDefHelper::Create(
1237 "AddAndMul", {"i: float"}, {"o: float"}, {},
1238 {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
1239 {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
1240 /*ret_def=*/{{"o", "ret:z:0"}},
1241 /*control_ret_def=*/{{"must_execute", "add"}});
1242
1243 NodeDef* input = graph_def.add_node();
1244 input->set_name("InputPlaceholder");
1245 input->set_op("Placeholder");
1246 AddNodeAttr("dtype", DT_FLOAT, input);
1247
1248 // Create 100 chained PartitionedCalls, each depending on the previous.
1249 // Generate a NodeDef that has close to similar attributes that TensorFlow
1250 // will generate.
1251 ConfigProto config_pb;
1252 config_pb.mutable_device_count()->insert({"CPU", 1});
1253 config_pb.mutable_device_count()->insert({"GPU", 1});
1254 config_pb.set_allow_soft_placement(true);
1255 for (int i = 0; i < 100; ++i) {
1256 NodeDef* node = graph_def.add_node();
1257 node->set_name(absl::StrCat("PartitionedCall_", i));
1258 node->set_op("PartitionedCall");
1259 if (i > 0) {
1260 *node->add_input() = absl::StrCat("PartitionedCall_", i - 1);
1261 } else {
1262 *node->add_input() = input->name();
1263 }
1264 AddNodeAttr("Tin", DT_FLOAT, node);
1265 AddNodeAttr("Tout", DT_FLOAT, node);
1266 AddNodeAttr("config", "", node);
1267 AddNodeAttr("config_proto", config_pb.SerializeAsString(), node);
1268 NameAttrList func;
1269 func.set_name(fd->signature().name());
1270 AddNodeAttr("f", func, node);
1271 }
1272
1273 const NodeDef& target = graph_def.node(graph_def.node_size() - 1);
1274
1275 uint64 hash_value;
1276 for (auto _ : state) {
1277 CHECK(HashNode(graph_def, target, &hash_value).ok());
1278 }
1279 }
1280 BENCHMARK(BM_ChainedFunctionCallsGraph);
1281
1282 // Benchmark that simulates many nested function calls.
BM_ComposedFunctionCallsGraph(benchmark::State & state)1283 static void BM_ComposedFunctionCallsGraph(benchmark::State& state) {
1284 GraphDef graph_def;
1285 FunctionDefLibrary* fl = graph_def.mutable_library();
1286
1287 // AddAndMul will be the last function, all others will be calls up to this.
1288 FunctionDef* fd = fl->add_function();
1289 *fd = FunctionDefHelper::Create(
1290 "AddAndMul", {"i: float"}, {"o: float"}, {},
1291 {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
1292 {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
1293 /*ret_def=*/{{"o", "ret:z:0"}},
1294 /*control_ret_def=*/{{"must_execute", "add"}});
1295
1296 ConfigProto config_pb;
1297 config_pb.mutable_device_count()->insert({"CPU", 1});
1298 config_pb.mutable_device_count()->insert({"GPU", 1});
1299 config_pb.set_allow_soft_placement(true);
1300 for (int i = 0; i < 99; ++i) {
1301 // Get the name fo the previous function
1302 NameAttrList func;
1303 func.set_name(fd->signature().name());
1304
1305 FunctionDef* fd = fl->add_function();
1306 *fd = FunctionDefHelper::Create(
1307 /*function_name=*/absl::StrCat("F_", i),
1308 /*in_def=*/{"i: float"},
1309 /*out_def=*/{"o: float"},
1310 /*attr_def=*/{},
1311 /*node_def=*/
1312 {
1313 {
1314 {"inner_call"},
1315 "PartitionedCall",
1316 {"i"},
1317 {{"Ti", DT_FLOAT},
1318 {"Tout", DT_FLOAT},
1319 {"config", ""},
1320 {"config_proto", config_pb.SerializeAsString()},
1321 {"f", func}},
1322 },
1323 },
1324 /*ret_def=*/{{"o", "inner_call:o:0"}},
1325 /*control_ret_def=*/{{"must_execute", "inner_call"}});
1326 }
1327
1328 NodeDef* input = graph_def.add_node();
1329 input->set_name("InputPlaceholder");
1330 input->set_op("Placeholder");
1331 AddNodeAttr("dtype", DT_FLOAT, input);
1332
1333 // Create call to the outer most function.
1334 NodeDef* node = graph_def.add_node();
1335 node->set_name("PartitionedCall_start");
1336 node->set_op("PartitionedCall");
1337 *node->add_input() = input->name();
1338 AddNodeAttr("Tin", DT_FLOAT, node);
1339 AddNodeAttr("Tout", DT_FLOAT, node);
1340 AddNodeAttr("config", "", node);
1341 AddNodeAttr("config_proto", config_pb.SerializeAsString(), node);
1342 NameAttrList func;
1343 func.set_name(fd->signature().name());
1344 AddNodeAttr("f", func, node);
1345
1346 const NodeDef& target = graph_def.node(graph_def.node_size() - 1);
1347
1348 uint64 hash_value;
1349 for (auto _ : state) {
1350 CHECK(HashNode(graph_def, target, &hash_value).ok());
1351 }
1352 }
1353 BENCHMARK(BM_ComposedFunctionCallsGraph);
1354
1355 } // namespace
1356 } // namespace data
1357 } // namespace tensorflow
1358